Switch to TQDM
This commit is contained in:
parent
3a8b68c0fd
commit
81cbb92556
@ -52,8 +52,8 @@ def start_hyperopt_list(args: Dict[str, Any]) -> None:
|
|||||||
|
|
||||||
if not export_csv:
|
if not export_csv:
|
||||||
try:
|
try:
|
||||||
Hyperopt.print_result_table(config, trials, total_epochs,
|
print(Hyperopt.get_result_table(config, trials, total_epochs,
|
||||||
not filteroptions['only_best'], print_colorized, 0)
|
not filteroptions['only_best'], print_colorized, 0))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('User interrupted..')
|
print('User interrupted..')
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ from colorama import init as colorama_init
|
|||||||
from joblib import (Parallel, cpu_count, delayed, dump, load,
|
from joblib import (Parallel, cpu_count, delayed, dump, load,
|
||||||
wrap_non_picklable_objects)
|
wrap_non_picklable_objects)
|
||||||
from pandas import DataFrame, json_normalize, isna
|
from pandas import DataFrame, json_normalize, isna
|
||||||
import progressbar
|
from tqdm import tqdm
|
||||||
import tabulate
|
import tabulate
|
||||||
from os import path, popen
|
from os import path, popen
|
||||||
import io
|
import io
|
||||||
@ -275,11 +275,37 @@ class Hyperopt:
|
|||||||
if not self.print_all:
|
if not self.print_all:
|
||||||
# Separate the results explanation string from dots
|
# Separate the results explanation string from dots
|
||||||
print("\n")
|
print("\n")
|
||||||
self.print_result_table(self.config, results, self.total_epochs,
|
print(self.get_result_table(
|
||||||
self.print_all, self.print_colorized,
|
self.config, results, self.total_epochs,
|
||||||
self.hyperopt_table_header)
|
self.print_all, self.print_colorized,
|
||||||
|
self.hyperopt_table_header
|
||||||
|
)
|
||||||
|
)
|
||||||
self.hyperopt_table_header = 2
|
self.hyperopt_table_header = 2
|
||||||
|
|
||||||
|
def get_results(self, results) -> str:
|
||||||
|
"""
|
||||||
|
Log results if it is better than any previous evaluation
|
||||||
|
"""
|
||||||
|
output = ''
|
||||||
|
is_best = results['is_best']
|
||||||
|
# if not self.print_all:
|
||||||
|
# Print '\n' after each 100th epoch to separate dots from the log messages.
|
||||||
|
# Otherwise output is messy on a terminal.
|
||||||
|
# return '.', end='' if results['current_epoch'] % 100 != 0 else None # type: ignore
|
||||||
|
|
||||||
|
if self.print_all or is_best:
|
||||||
|
# if not self.print_all:
|
||||||
|
# Separate the results explanation string from dots
|
||||||
|
# print("\n")
|
||||||
|
output = self.get_result_table(
|
||||||
|
self.config, results, self.total_epochs,
|
||||||
|
self.print_all, self.print_colorized,
|
||||||
|
self.hyperopt_table_header
|
||||||
|
)
|
||||||
|
self.hyperopt_table_header = 2
|
||||||
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def print_results_explanation(results, total_epochs, highlight_best: bool,
|
def print_results_explanation(results, total_epochs, highlight_best: bool,
|
||||||
print_colorized: bool) -> None:
|
print_colorized: bool) -> None:
|
||||||
@ -303,13 +329,13 @@ class Hyperopt:
|
|||||||
f"Objective: {results['loss']:.5f}")
|
f"Objective: {results['loss']:.5f}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def print_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
def get_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
||||||
print_colorized: bool, remove_header: int) -> None:
|
print_colorized: bool, remove_header: int) -> str:
|
||||||
"""
|
"""
|
||||||
Log result table
|
Log result table
|
||||||
"""
|
"""
|
||||||
if not results:
|
if not results:
|
||||||
return
|
return ''
|
||||||
|
|
||||||
tabulate.PRESERVE_WHITESPACE = True
|
tabulate.PRESERVE_WHITESPACE = True
|
||||||
|
|
||||||
@ -380,7 +406,7 @@ class Hyperopt:
|
|||||||
trials.to_dict(orient='list'), tablefmt='psql',
|
trials.to_dict(orient='list'), tablefmt='psql',
|
||||||
headers='keys', stralign="right"
|
headers='keys', stralign="right"
|
||||||
)
|
)
|
||||||
print(table)
|
return table
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
||||||
@ -661,6 +687,7 @@ class Hyperopt:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with Parallel(n_jobs=config_jobs) as parallel:
|
with Parallel(n_jobs=config_jobs) as parallel:
|
||||||
|
"""
|
||||||
self.progress_bar = progressbar.ProgressBar(
|
self.progress_bar = progressbar.ProgressBar(
|
||||||
min_value=0,
|
min_value=0,
|
||||||
max_value=self.total_epochs,
|
max_value=self.total_epochs,
|
||||||
@ -668,9 +695,17 @@ class Hyperopt:
|
|||||||
line_breaks=True,
|
line_breaks=True,
|
||||||
enable_colors=self.print_colorized
|
enable_colors=self.print_colorized
|
||||||
)
|
)
|
||||||
self.progress_bar.start()
|
"""
|
||||||
jobs = parallel._effective_n_jobs()
|
jobs = parallel._effective_n_jobs()
|
||||||
logger.info(f'Effective number of parallel workers used: {jobs}')
|
logger.info(f'Effective number of parallel workers used: {jobs}')
|
||||||
|
|
||||||
|
# Define progressbar
|
||||||
|
self.progress_bar = tqdm(
|
||||||
|
total=self.total_epochs, ncols=108, unit=' Epoch',
|
||||||
|
bar_format='Epoch {n_fmt}/{total_fmt} ({percentage:3.0f}%)|{bar}|'
|
||||||
|
' [{elapsed}<{remaining} {rate_fmt}{postfix}]'
|
||||||
|
)
|
||||||
|
|
||||||
EVALS = ceil(self.total_epochs / jobs)
|
EVALS = ceil(self.total_epochs / jobs)
|
||||||
for i in range(EVALS):
|
for i in range(EVALS):
|
||||||
# Correct the number of epochs to be processed for the last
|
# Correct the number of epochs to be processed for the last
|
||||||
@ -684,8 +719,7 @@ class Hyperopt:
|
|||||||
self.fix_optimizer_models_list()
|
self.fix_optimizer_models_list()
|
||||||
|
|
||||||
# Calculate progressbar outputs
|
# Calculate progressbar outputs
|
||||||
pbar_line = ceil(self._get_height() / 2)
|
# pbar_line = ceil(self._get_height() / 2)
|
||||||
|
|
||||||
for j, val in enumerate(f_val):
|
for j, val in enumerate(f_val):
|
||||||
# Use human-friendly indexes here (starting from 1)
|
# Use human-friendly indexes here (starting from 1)
|
||||||
current = i * jobs + j + 1
|
current = i * jobs + j + 1
|
||||||
@ -699,20 +733,25 @@ class Hyperopt:
|
|||||||
# evaluations can take different time. Here they are aligned in the
|
# evaluations can take different time. Here they are aligned in the
|
||||||
# order they will be shown to the user.
|
# order they will be shown to the user.
|
||||||
val['is_best'] = is_best
|
val['is_best'] = is_best
|
||||||
|
# print(current)
|
||||||
self.print_results(val)
|
output = self.get_results(val)
|
||||||
|
self.progress_bar.write(output)
|
||||||
|
# self.progress_bar.write(str(len(output.split('\n')[0])))
|
||||||
|
self.progress_bar.ncols = len(output.split('\n')[0])
|
||||||
|
self.progress_bar.update(1)
|
||||||
|
"""
|
||||||
if pbar_line <= current:
|
if pbar_line <= current:
|
||||||
self.progress_bar.update(current)
|
self.progress_bar.update(current)
|
||||||
pbar_line = current + ceil(self._get_height() / 2)
|
pbar_line = current + ceil(self._get_height() / 2)
|
||||||
|
"""
|
||||||
if is_best:
|
if is_best:
|
||||||
self.current_best_loss = val['loss']
|
self.current_best_loss = val['loss']
|
||||||
self.trials.append(val)
|
self.trials.append(val)
|
||||||
# Save results after each best epoch and every 100 epochs
|
# Save results after each best epoch and every 100 epochs
|
||||||
if is_best or current % 100 == 0:
|
if is_best or current % 100 == 0:
|
||||||
self.save_trials()
|
self.save_trials()
|
||||||
self.progress_bar.finish()
|
self.progress_bar.ncols = 108
|
||||||
|
self.progress_bar.close()
|
||||||
|
|
||||||
# self.progress_bar.update(current)
|
# self.progress_bar.update(current)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
Loading…
Reference in New Issue
Block a user