Merge pull request #3054 from Fredrik81/progress-bar
Hyperopt: Progressbar during hyperopt
This commit is contained in:
commit
18a6c98a82
@ -52,8 +52,8 @@ def start_hyperopt_list(args: Dict[str, Any]) -> None:
|
|||||||
|
|
||||||
if not export_csv:
|
if not export_csv:
|
||||||
try:
|
try:
|
||||||
Hyperopt.print_result_table(config, trials, total_epochs,
|
print(Hyperopt.get_result_table(config, trials, total_epochs,
|
||||||
not filteroptions['only_best'], print_colorized, 0)
|
not filteroptions['only_best'], print_colorized, 0))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('User interrupted..')
|
print('User interrupted..')
|
||||||
|
|
||||||
|
@ -7,7 +7,6 @@ This module contains the hyperopt logic
|
|||||||
import locale
|
import locale
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@ -18,10 +17,10 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import rapidjson
|
import rapidjson
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
from colorama import init as colorama_init
|
|
||||||
from joblib import (Parallel, cpu_count, delayed, dump, load,
|
from joblib import (Parallel, cpu_count, delayed, dump, load,
|
||||||
wrap_non_picklable_objects)
|
wrap_non_picklable_objects)
|
||||||
from pandas import DataFrame, json_normalize, isna
|
from pandas import DataFrame, json_normalize, isna
|
||||||
|
import progressbar
|
||||||
import tabulate
|
import tabulate
|
||||||
from os import path
|
from os import path
|
||||||
import io
|
import io
|
||||||
@ -43,7 +42,8 @@ with warnings.catch_warnings():
|
|||||||
from skopt import Optimizer
|
from skopt import Optimizer
|
||||||
from skopt.space import Dimension
|
from skopt.space import Dimension
|
||||||
|
|
||||||
|
progressbar.streams.wrap_stderr()
|
||||||
|
progressbar.streams.wrap_stdout()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -266,21 +266,33 @@ class Hyperopt:
|
|||||||
Log results if it is better than any previous evaluation
|
Log results if it is better than any previous evaluation
|
||||||
"""
|
"""
|
||||||
is_best = results['is_best']
|
is_best = results['is_best']
|
||||||
if not self.print_all:
|
|
||||||
# Print '\n' after each 100th epoch to separate dots from the log messages.
|
|
||||||
# Otherwise output is messy on a terminal.
|
|
||||||
print('.', end='' if results['current_epoch'] % 100 != 0 else None) # type: ignore
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
if self.print_all or is_best:
|
if self.print_all or is_best:
|
||||||
if not self.print_all:
|
print(
|
||||||
# Separate the results explanation string from dots
|
self.get_result_table(
|
||||||
print("\n")
|
self.config, results, self.total_epochs,
|
||||||
self.print_result_table(self.config, results, self.total_epochs,
|
|
||||||
self.print_all, self.print_colorized,
|
self.print_all, self.print_colorized,
|
||||||
self.hyperopt_table_header)
|
self.hyperopt_table_header
|
||||||
|
)
|
||||||
|
)
|
||||||
self.hyperopt_table_header = 2
|
self.hyperopt_table_header = 2
|
||||||
|
|
||||||
|
def get_results(self, results) -> str:
|
||||||
|
"""
|
||||||
|
Log results if it is better than any previous evaluation
|
||||||
|
"""
|
||||||
|
output = ''
|
||||||
|
is_best = results['is_best']
|
||||||
|
|
||||||
|
if self.print_all or is_best:
|
||||||
|
output = self.get_result_table(
|
||||||
|
self.config, results, self.total_epochs,
|
||||||
|
self.print_all, self.print_colorized,
|
||||||
|
self.hyperopt_table_header
|
||||||
|
)
|
||||||
|
self.hyperopt_table_header = 2
|
||||||
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def print_results_explanation(results, total_epochs, highlight_best: bool,
|
def print_results_explanation(results, total_epochs, highlight_best: bool,
|
||||||
print_colorized: bool) -> None:
|
print_colorized: bool) -> None:
|
||||||
@ -304,13 +316,13 @@ class Hyperopt:
|
|||||||
f"Objective: {results['loss']:.5f}")
|
f"Objective: {results['loss']:.5f}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def print_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
def get_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
||||||
print_colorized: bool, remove_header: int) -> None:
|
print_colorized: bool, remove_header: int) -> str:
|
||||||
"""
|
"""
|
||||||
Log result table
|
Log result table
|
||||||
"""
|
"""
|
||||||
if not results:
|
if not results:
|
||||||
return
|
return ''
|
||||||
|
|
||||||
tabulate.PRESERVE_WHITESPACE = True
|
tabulate.PRESERVE_WHITESPACE = True
|
||||||
|
|
||||||
@ -381,7 +393,7 @@ class Hyperopt:
|
|||||||
trials.to_dict(orient='list'), tablefmt='psql',
|
trials.to_dict(orient='list'), tablefmt='psql',
|
||||||
headers='keys', stralign="right"
|
headers='keys', stralign="right"
|
||||||
)
|
)
|
||||||
print(table)
|
return table
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool,
|
||||||
@ -654,13 +666,36 @@ class Hyperopt:
|
|||||||
self.dimensions: List[Dimension] = self.hyperopt_space()
|
self.dimensions: List[Dimension] = self.hyperopt_space()
|
||||||
self.opt = self.get_optimizer(self.dimensions, config_jobs)
|
self.opt = self.get_optimizer(self.dimensions, config_jobs)
|
||||||
|
|
||||||
if self.print_colorized:
|
|
||||||
colorama_init(autoreset=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Parallel(n_jobs=config_jobs) as parallel:
|
with Parallel(n_jobs=config_jobs) as parallel:
|
||||||
jobs = parallel._effective_n_jobs()
|
jobs = parallel._effective_n_jobs()
|
||||||
logger.info(f'Effective number of parallel workers used: {jobs}')
|
logger.info(f'Effective number of parallel workers used: {jobs}')
|
||||||
|
|
||||||
|
# Define progressbar
|
||||||
|
if self.print_colorized:
|
||||||
|
widgets = [
|
||||||
|
' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
|
||||||
|
' (', progressbar.Percentage(), ')] ',
|
||||||
|
progressbar.Bar(marker=progressbar.AnimatedMarker(
|
||||||
|
fill='\N{FULL BLOCK}',
|
||||||
|
fill_wrap=Fore.GREEN + '{}' + Fore.RESET,
|
||||||
|
marker_wrap=Style.BRIGHT + '{}' + Style.RESET_ALL,
|
||||||
|
)),
|
||||||
|
' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
widgets = [
|
||||||
|
' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
|
||||||
|
' (', progressbar.Percentage(), ')] ',
|
||||||
|
progressbar.Bar(marker=progressbar.AnimatedMarker(
|
||||||
|
fill='\N{FULL BLOCK}',
|
||||||
|
)),
|
||||||
|
' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
|
||||||
|
]
|
||||||
|
with progressbar.ProgressBar(
|
||||||
|
maxval=self.total_epochs, redirect_stdout=True, redirect_stderr=True,
|
||||||
|
widgets=widgets
|
||||||
|
) as pbar:
|
||||||
EVALS = ceil(self.total_epochs / jobs)
|
EVALS = ceil(self.total_epochs / jobs)
|
||||||
for i in range(EVALS):
|
for i in range(EVALS):
|
||||||
# Correct the number of epochs to be processed for the last
|
# Correct the number of epochs to be processed for the last
|
||||||
@ -673,11 +708,13 @@ class Hyperopt:
|
|||||||
self.opt.tell(asked, [v['loss'] for v in f_val])
|
self.opt.tell(asked, [v['loss'] for v in f_val])
|
||||||
self.fix_optimizer_models_list()
|
self.fix_optimizer_models_list()
|
||||||
|
|
||||||
|
# Calculate progressbar outputs
|
||||||
for j, val in enumerate(f_val):
|
for j, val in enumerate(f_val):
|
||||||
# Use human-friendly indexes here (starting from 1)
|
# Use human-friendly indexes here (starting from 1)
|
||||||
current = i * jobs + j + 1
|
current = i * jobs + j + 1
|
||||||
val['current_epoch'] = current
|
val['current_epoch'] = current
|
||||||
val['is_initial_point'] = current <= INITIAL_POINTS
|
val['is_initial_point'] = current <= INITIAL_POINTS
|
||||||
|
|
||||||
logger.debug(f"Optimizer epoch evaluated: {val}")
|
logger.debug(f"Optimizer epoch evaluated: {val}")
|
||||||
|
|
||||||
is_best = self.is_best_loss(val, self.current_best_loss)
|
is_best = self.is_best_loss(val, self.current_best_loss)
|
||||||
@ -686,15 +723,18 @@ class Hyperopt:
|
|||||||
# evaluations can take different time. Here they are aligned in the
|
# evaluations can take different time. Here they are aligned in the
|
||||||
# order they will be shown to the user.
|
# order they will be shown to the user.
|
||||||
val['is_best'] = is_best
|
val['is_best'] = is_best
|
||||||
|
|
||||||
self.print_results(val)
|
self.print_results(val)
|
||||||
|
|
||||||
if is_best:
|
if is_best:
|
||||||
self.current_best_loss = val['loss']
|
self.current_best_loss = val['loss']
|
||||||
self.trials.append(val)
|
self.trials.append(val)
|
||||||
|
|
||||||
# Save results after each best epoch and every 100 epochs
|
# Save results after each best epoch and every 100 epochs
|
||||||
if is_best or current % 100 == 0:
|
if is_best or current % 100 == 0:
|
||||||
self.save_trials()
|
self.save_trials()
|
||||||
|
|
||||||
|
pbar.update(current)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('User interrupted..')
|
print('User interrupted..')
|
||||||
|
|
||||||
|
@ -7,3 +7,4 @@ scikit-learn==0.22.2.post1
|
|||||||
scikit-optimize==0.7.4
|
scikit-optimize==0.7.4
|
||||||
filelock==3.0.12
|
filelock==3.0.12
|
||||||
joblib==0.14.1
|
joblib==0.14.1
|
||||||
|
progressbar2==3.50.1
|
||||||
|
Loading…
Reference in New Issue
Block a user