Switch to TQDM

This commit is contained in:
Fredrik81 2020-03-11 22:30:36 +01:00
parent 3a8b68c0fd
commit 81cbb92556
2 changed files with 57 additions and 18 deletions

View File

@ -52,8 +52,8 @@ def start_hyperopt_list(args: Dict[str, Any]) -> None:
if not export_csv:
try:
Hyperopt.print_result_table(config, trials, total_epochs,
not filteroptions['only_best'], print_colorized, 0)
print(Hyperopt.get_result_table(config, trials, total_epochs,
not filteroptions['only_best'], print_colorized, 0))
except KeyboardInterrupt:
print('User interrupted..')

View File

@ -21,7 +21,7 @@ from colorama import init as colorama_init
from joblib import (Parallel, cpu_count, delayed, dump, load,
wrap_non_picklable_objects)
from pandas import DataFrame, json_normalize, isna
import progressbar
from tqdm import tqdm
import tabulate
from os import path, popen
import io
@ -275,11 +275,37 @@ class Hyperopt:
if not self.print_all:
# Separate the results explanation string from dots
print("\n")
self.print_result_table(self.config, results, self.total_epochs,
self.print_all, self.print_colorized,
self.hyperopt_table_header)
print(self.get_result_table(
self.config, results, self.total_epochs,
self.print_all, self.print_colorized,
self.hyperopt_table_header
)
)
self.hyperopt_table_header = 2
def get_results(self, results) -> str:
"""
Log results if it is better than any previous evaluation
"""
output = ''
is_best = results['is_best']
# if not self.print_all:
# Print '\n' after each 100th epoch to separate dots from the log messages.
# Otherwise output is messy on a terminal.
# return '.', end='' if results['current_epoch'] % 100 != 0 else None # type: ignore
if self.print_all or is_best:
# if not self.print_all:
# Separate the results explanation string from dots
# print("\n")
output = self.get_result_table(
self.config, results, self.total_epochs,
self.print_all, self.print_colorized,
self.hyperopt_table_header
)
self.hyperopt_table_header = 2
return output
@staticmethod
def print_results_explanation(results, total_epochs, highlight_best: bool,
print_colorized: bool) -> None:
@ -303,13 +329,13 @@ class Hyperopt:
f"Objective: {results['loss']:.5f}")
@staticmethod
def print_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
print_colorized: bool, remove_header: int) -> None:
def get_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
print_colorized: bool, remove_header: int) -> str:
"""
Log result table
"""
if not results:
return
return ''
tabulate.PRESERVE_WHITESPACE = True
@ -380,7 +406,7 @@ class Hyperopt:
trials.to_dict(orient='list'), tablefmt='psql',
headers='keys', stralign="right"
)
print(table)
return table
@staticmethod
def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool,
@ -661,6 +687,7 @@ class Hyperopt:
try:
with Parallel(n_jobs=config_jobs) as parallel:
"""
self.progress_bar = progressbar.ProgressBar(
min_value=0,
max_value=self.total_epochs,
@ -668,9 +695,17 @@ class Hyperopt:
line_breaks=True,
enable_colors=self.print_colorized
)
self.progress_bar.start()
"""
jobs = parallel._effective_n_jobs()
logger.info(f'Effective number of parallel workers used: {jobs}')
# Define progressbar
self.progress_bar = tqdm(
total=self.total_epochs, ncols=108, unit=' Epoch',
bar_format='Epoch {n_fmt}/{total_fmt} ({percentage:3.0f}%)|{bar}|'
' [{elapsed}<{remaining} {rate_fmt}{postfix}]'
)
EVALS = ceil(self.total_epochs / jobs)
for i in range(EVALS):
# Correct the number of epochs to be processed for the last
@ -684,8 +719,7 @@ class Hyperopt:
self.fix_optimizer_models_list()
# Calculate progressbar outputs
pbar_line = ceil(self._get_height() / 2)
# pbar_line = ceil(self._get_height() / 2)
for j, val in enumerate(f_val):
# Use human-friendly indexes here (starting from 1)
current = i * jobs + j + 1
@ -699,20 +733,25 @@ class Hyperopt:
# evaluations can take different time. Here they are aligned in the
# order they will be shown to the user.
val['is_best'] = is_best
self.print_results(val)
# print(current)
output = self.get_results(val)
self.progress_bar.write(output)
# self.progress_bar.write(str(len(output.split('\n')[0])))
self.progress_bar.ncols = len(output.split('\n')[0])
self.progress_bar.update(1)
"""
if pbar_line <= current:
self.progress_bar.update(current)
pbar_line = current + ceil(self._get_height() / 2)
"""
if is_best:
self.current_best_loss = val['loss']
self.trials.append(val)
# Save results after each best epoch and every 100 epochs
if is_best or current % 100 == 0:
self.save_trials()
self.progress_bar.finish()
self.progress_bar.ncols = 108
self.progress_bar.close()
# self.progress_bar.update(current)
except KeyboardInterrupt: