Switch to TQDM
This commit is contained in:
parent
3a8b68c0fd
commit
81cbb92556
@ -52,8 +52,8 @@ def start_hyperopt_list(args: Dict[str, Any]) -> None:
|
||||
|
||||
if not export_csv:
|
||||
try:
|
||||
Hyperopt.print_result_table(config, trials, total_epochs,
|
||||
not filteroptions['only_best'], print_colorized, 0)
|
||||
print(Hyperopt.get_result_table(config, trials, total_epochs,
|
||||
not filteroptions['only_best'], print_colorized, 0))
|
||||
except KeyboardInterrupt:
|
||||
print('User interrupted..')
|
||||
|
||||
|
@ -21,7 +21,7 @@ from colorama import init as colorama_init
|
||||
from joblib import (Parallel, cpu_count, delayed, dump, load,
|
||||
wrap_non_picklable_objects)
|
||||
from pandas import DataFrame, json_normalize, isna
|
||||
import progressbar
|
||||
from tqdm import tqdm
|
||||
import tabulate
|
||||
from os import path, popen
|
||||
import io
|
||||
@ -275,11 +275,37 @@ class Hyperopt:
|
||||
if not self.print_all:
|
||||
# Separate the results explanation string from dots
|
||||
print("\n")
|
||||
self.print_result_table(self.config, results, self.total_epochs,
|
||||
self.print_all, self.print_colorized,
|
||||
self.hyperopt_table_header)
|
||||
print(self.get_result_table(
|
||||
self.config, results, self.total_epochs,
|
||||
self.print_all, self.print_colorized,
|
||||
self.hyperopt_table_header
|
||||
)
|
||||
)
|
||||
self.hyperopt_table_header = 2
|
||||
|
||||
def get_results(self, results) -> str:
|
||||
"""
|
||||
Log results if it is better than any previous evaluation
|
||||
"""
|
||||
output = ''
|
||||
is_best = results['is_best']
|
||||
# if not self.print_all:
|
||||
# Print '\n' after each 100th epoch to separate dots from the log messages.
|
||||
# Otherwise output is messy on a terminal.
|
||||
# return '.', end='' if results['current_epoch'] % 100 != 0 else None # type: ignore
|
||||
|
||||
if self.print_all or is_best:
|
||||
# if not self.print_all:
|
||||
# Separate the results explanation string from dots
|
||||
# print("\n")
|
||||
output = self.get_result_table(
|
||||
self.config, results, self.total_epochs,
|
||||
self.print_all, self.print_colorized,
|
||||
self.hyperopt_table_header
|
||||
)
|
||||
self.hyperopt_table_header = 2
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def print_results_explanation(results, total_epochs, highlight_best: bool,
|
||||
print_colorized: bool) -> None:
|
||||
@ -303,13 +329,13 @@ class Hyperopt:
|
||||
f"Objective: {results['loss']:.5f}")
|
||||
|
||||
@staticmethod
|
||||
def print_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
||||
print_colorized: bool, remove_header: int) -> None:
|
||||
def get_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
||||
print_colorized: bool, remove_header: int) -> str:
|
||||
"""
|
||||
Log result table
|
||||
"""
|
||||
if not results:
|
||||
return
|
||||
return ''
|
||||
|
||||
tabulate.PRESERVE_WHITESPACE = True
|
||||
|
||||
@ -380,7 +406,7 @@ class Hyperopt:
|
||||
trials.to_dict(orient='list'), tablefmt='psql',
|
||||
headers='keys', stralign="right"
|
||||
)
|
||||
print(table)
|
||||
return table
|
||||
|
||||
@staticmethod
|
||||
def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
||||
@ -661,6 +687,7 @@ class Hyperopt:
|
||||
|
||||
try:
|
||||
with Parallel(n_jobs=config_jobs) as parallel:
|
||||
"""
|
||||
self.progress_bar = progressbar.ProgressBar(
|
||||
min_value=0,
|
||||
max_value=self.total_epochs,
|
||||
@ -668,9 +695,17 @@ class Hyperopt:
|
||||
line_breaks=True,
|
||||
enable_colors=self.print_colorized
|
||||
)
|
||||
self.progress_bar.start()
|
||||
"""
|
||||
jobs = parallel._effective_n_jobs()
|
||||
logger.info(f'Effective number of parallel workers used: {jobs}')
|
||||
|
||||
# Define progressbar
|
||||
self.progress_bar = tqdm(
|
||||
total=self.total_epochs, ncols=108, unit=' Epoch',
|
||||
bar_format='Epoch {n_fmt}/{total_fmt} ({percentage:3.0f}%)|{bar}|'
|
||||
' [{elapsed}<{remaining} {rate_fmt}{postfix}]'
|
||||
)
|
||||
|
||||
EVALS = ceil(self.total_epochs / jobs)
|
||||
for i in range(EVALS):
|
||||
# Correct the number of epochs to be processed for the last
|
||||
@ -684,8 +719,7 @@ class Hyperopt:
|
||||
self.fix_optimizer_models_list()
|
||||
|
||||
# Calculate progressbar outputs
|
||||
pbar_line = ceil(self._get_height() / 2)
|
||||
|
||||
# pbar_line = ceil(self._get_height() / 2)
|
||||
for j, val in enumerate(f_val):
|
||||
# Use human-friendly indexes here (starting from 1)
|
||||
current = i * jobs + j + 1
|
||||
@ -699,20 +733,25 @@ class Hyperopt:
|
||||
# evaluations can take different time. Here they are aligned in the
|
||||
# order they will be shown to the user.
|
||||
val['is_best'] = is_best
|
||||
|
||||
self.print_results(val)
|
||||
|
||||
# print(current)
|
||||
output = self.get_results(val)
|
||||
self.progress_bar.write(output)
|
||||
# self.progress_bar.write(str(len(output.split('\n')[0])))
|
||||
self.progress_bar.ncols = len(output.split('\n')[0])
|
||||
self.progress_bar.update(1)
|
||||
"""
|
||||
if pbar_line <= current:
|
||||
self.progress_bar.update(current)
|
||||
pbar_line = current + ceil(self._get_height() / 2)
|
||||
|
||||
"""
|
||||
if is_best:
|
||||
self.current_best_loss = val['loss']
|
||||
self.trials.append(val)
|
||||
# Save results after each best epoch and every 100 epochs
|
||||
if is_best or current % 100 == 0:
|
||||
self.save_trials()
|
||||
self.progress_bar.finish()
|
||||
self.progress_bar.ncols = 108
|
||||
self.progress_bar.close()
|
||||
|
||||
# self.progress_bar.update(current)
|
||||
except KeyboardInterrupt:
|
||||
|
Loading…
Reference in New Issue
Block a user