Merge pull request #3054 from Fredrik81/progress-bar

Hyperopt: Progressbar during hyperopt
This commit is contained in:
Matthias 2020-04-12 09:32:52 +02:00 committed by GitHub
commit 18a6c98a82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 55 deletions

View File

@ -52,8 +52,8 @@ def start_hyperopt_list(args: Dict[str, Any]) -> None:
if not export_csv: if not export_csv:
try: try:
Hyperopt.print_result_table(config, trials, total_epochs, print(Hyperopt.get_result_table(config, trials, total_epochs,
not filteroptions['only_best'], print_colorized, 0) not filteroptions['only_best'], print_colorized, 0))
except KeyboardInterrupt: except KeyboardInterrupt:
print('User interrupted..') print('User interrupted..')

View File

@ -7,7 +7,6 @@ This module contains the hyperopt logic
import locale import locale
import logging import logging
import random import random
import sys
import warnings import warnings
from math import ceil from math import ceil
from collections import OrderedDict from collections import OrderedDict
@ -18,10 +17,10 @@ from typing import Any, Dict, List, Optional
import rapidjson import rapidjson
from colorama import Fore, Style from colorama import Fore, Style
from colorama import init as colorama_init
from joblib import (Parallel, cpu_count, delayed, dump, load, from joblib import (Parallel, cpu_count, delayed, dump, load,
wrap_non_picklable_objects) wrap_non_picklable_objects)
from pandas import DataFrame, json_normalize, isna from pandas import DataFrame, json_normalize, isna
import progressbar
import tabulate import tabulate
from os import path from os import path
import io import io
@ -43,7 +42,8 @@ with warnings.catch_warnings():
from skopt import Optimizer from skopt import Optimizer
from skopt.space import Dimension from skopt.space import Dimension
progressbar.streams.wrap_stderr()
progressbar.streams.wrap_stdout()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -266,21 +266,33 @@ class Hyperopt:
Log results if it is better than any previous evaluation Log results if it is better than any previous evaluation
""" """
is_best = results['is_best'] is_best = results['is_best']
if not self.print_all:
# Print '\n' after each 100th epoch to separate dots from the log messages.
# Otherwise output is messy on a terminal.
print('.', end='' if results['current_epoch'] % 100 != 0 else None) # type: ignore
sys.stdout.flush()
if self.print_all or is_best: if self.print_all or is_best:
if not self.print_all: print(
# Separate the results explanation string from dots self.get_result_table(
print("\n") self.config, results, self.total_epochs,
self.print_result_table(self.config, results, self.total_epochs,
self.print_all, self.print_colorized, self.print_all, self.print_colorized,
self.hyperopt_table_header) self.hyperopt_table_header
)
)
self.hyperopt_table_header = 2 self.hyperopt_table_header = 2
def get_results(self, results) -> str:
"""
Log results if it is better than any previous evaluation
"""
output = ''
is_best = results['is_best']
if self.print_all or is_best:
output = self.get_result_table(
self.config, results, self.total_epochs,
self.print_all, self.print_colorized,
self.hyperopt_table_header
)
self.hyperopt_table_header = 2
return output
@staticmethod @staticmethod
def print_results_explanation(results, total_epochs, highlight_best: bool, def print_results_explanation(results, total_epochs, highlight_best: bool,
print_colorized: bool) -> None: print_colorized: bool) -> None:
@ -304,13 +316,13 @@ class Hyperopt:
f"Objective: {results['loss']:.5f}") f"Objective: {results['loss']:.5f}")
@staticmethod @staticmethod
def print_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool, def get_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
print_colorized: bool, remove_header: int) -> None: print_colorized: bool, remove_header: int) -> str:
""" """
Log result table Log result table
""" """
if not results: if not results:
return return ''
tabulate.PRESERVE_WHITESPACE = True tabulate.PRESERVE_WHITESPACE = True
@ -381,7 +393,7 @@ class Hyperopt:
trials.to_dict(orient='list'), tablefmt='psql', trials.to_dict(orient='list'), tablefmt='psql',
headers='keys', stralign="right" headers='keys', stralign="right"
) )
print(table) return table
@staticmethod @staticmethod
def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool, def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool,
@ -654,13 +666,36 @@ class Hyperopt:
self.dimensions: List[Dimension] = self.hyperopt_space() self.dimensions: List[Dimension] = self.hyperopt_space()
self.opt = self.get_optimizer(self.dimensions, config_jobs) self.opt = self.get_optimizer(self.dimensions, config_jobs)
if self.print_colorized:
colorama_init(autoreset=True)
try: try:
with Parallel(n_jobs=config_jobs) as parallel: with Parallel(n_jobs=config_jobs) as parallel:
jobs = parallel._effective_n_jobs() jobs = parallel._effective_n_jobs()
logger.info(f'Effective number of parallel workers used: {jobs}') logger.info(f'Effective number of parallel workers used: {jobs}')
# Define progressbar
if self.print_colorized:
widgets = [
' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
' (', progressbar.Percentage(), ')] ',
progressbar.Bar(marker=progressbar.AnimatedMarker(
fill='\N{FULL BLOCK}',
fill_wrap=Fore.GREEN + '{}' + Fore.RESET,
marker_wrap=Style.BRIGHT + '{}' + Style.RESET_ALL,
)),
' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
]
else:
widgets = [
' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
' (', progressbar.Percentage(), ')] ',
progressbar.Bar(marker=progressbar.AnimatedMarker(
fill='\N{FULL BLOCK}',
)),
' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
]
with progressbar.ProgressBar(
maxval=self.total_epochs, redirect_stdout=True, redirect_stderr=True,
widgets=widgets
) as pbar:
EVALS = ceil(self.total_epochs / jobs) EVALS = ceil(self.total_epochs / jobs)
for i in range(EVALS): for i in range(EVALS):
# Correct the number of epochs to be processed for the last # Correct the number of epochs to be processed for the last
@ -673,11 +708,13 @@ class Hyperopt:
self.opt.tell(asked, [v['loss'] for v in f_val]) self.opt.tell(asked, [v['loss'] for v in f_val])
self.fix_optimizer_models_list() self.fix_optimizer_models_list()
# Calculate progressbar outputs
for j, val in enumerate(f_val): for j, val in enumerate(f_val):
# Use human-friendly indexes here (starting from 1) # Use human-friendly indexes here (starting from 1)
current = i * jobs + j + 1 current = i * jobs + j + 1
val['current_epoch'] = current val['current_epoch'] = current
val['is_initial_point'] = current <= INITIAL_POINTS val['is_initial_point'] = current <= INITIAL_POINTS
logger.debug(f"Optimizer epoch evaluated: {val}") logger.debug(f"Optimizer epoch evaluated: {val}")
is_best = self.is_best_loss(val, self.current_best_loss) is_best = self.is_best_loss(val, self.current_best_loss)
@ -686,15 +723,18 @@ class Hyperopt:
# evaluations can take different time. Here they are aligned in the # evaluations can take different time. Here they are aligned in the
# order they will be shown to the user. # order they will be shown to the user.
val['is_best'] = is_best val['is_best'] = is_best
self.print_results(val) self.print_results(val)
if is_best: if is_best:
self.current_best_loss = val['loss'] self.current_best_loss = val['loss']
self.trials.append(val) self.trials.append(val)
# Save results after each best epoch and every 100 epochs # Save results after each best epoch and every 100 epochs
if is_best or current % 100 == 0: if is_best or current % 100 == 0:
self.save_trials() self.save_trials()
pbar.update(current)
except KeyboardInterrupt: except KeyboardInterrupt:
print('User interrupted..') print('User interrupted..')

View File

@ -7,3 +7,4 @@ scikit-learn==0.22.2.post1
scikit-optimize==0.7.4 scikit-optimize==0.7.4
filelock==3.0.12 filelock==3.0.12
joblib==0.14.1 joblib==0.14.1
progressbar2==3.50.1

View File

@ -24,6 +24,7 @@ hyperopt = [
'scikit-optimize', 'scikit-optimize',
'filelock', 'filelock',
'joblib', 'joblib',
'progressbar2',
] ]
develop = [ develop = [