Merge pull request #3054 from Fredrik81/progress-bar
Hyperopt: Progressbar during hyperopt
This commit is contained in:
		| @@ -52,8 +52,8 @@ def start_hyperopt_list(args: Dict[str, Any]) -> None: | |||||||
|  |  | ||||||
|     if not export_csv: |     if not export_csv: | ||||||
|         try: |         try: | ||||||
|             Hyperopt.print_result_table(config, trials, total_epochs, |             print(Hyperopt.get_result_table(config, trials, total_epochs, | ||||||
|                                         not filteroptions['only_best'], print_colorized, 0) |                                             not filteroptions['only_best'], print_colorized, 0)) | ||||||
|         except KeyboardInterrupt: |         except KeyboardInterrupt: | ||||||
|             print('User interrupted..') |             print('User interrupted..') | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,7 +7,6 @@ This module contains the hyperopt logic | |||||||
| import locale | import locale | ||||||
| import logging | import logging | ||||||
| import random | import random | ||||||
| import sys |  | ||||||
| import warnings | import warnings | ||||||
| from math import ceil | from math import ceil | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| @@ -18,10 +17,10 @@ from typing import Any, Dict, List, Optional | |||||||
|  |  | ||||||
| import rapidjson | import rapidjson | ||||||
| from colorama import Fore, Style | from colorama import Fore, Style | ||||||
| from colorama import init as colorama_init |  | ||||||
| from joblib import (Parallel, cpu_count, delayed, dump, load, | from joblib import (Parallel, cpu_count, delayed, dump, load, | ||||||
|                     wrap_non_picklable_objects) |                     wrap_non_picklable_objects) | ||||||
| from pandas import DataFrame, json_normalize, isna | from pandas import DataFrame, json_normalize, isna | ||||||
|  | import progressbar | ||||||
| import tabulate | import tabulate | ||||||
| from os import path | from os import path | ||||||
| import io | import io | ||||||
| @@ -43,7 +42,8 @@ with warnings.catch_warnings(): | |||||||
|     from skopt import Optimizer |     from skopt import Optimizer | ||||||
|     from skopt.space import Dimension |     from skopt.space import Dimension | ||||||
|  |  | ||||||
|  | progressbar.streams.wrap_stderr() | ||||||
|  | progressbar.streams.wrap_stdout() | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -266,21 +266,33 @@ class Hyperopt: | |||||||
|         Log results if it is better than any previous evaluation |         Log results if it is better than any previous evaluation | ||||||
|         """ |         """ | ||||||
|         is_best = results['is_best'] |         is_best = results['is_best'] | ||||||
|         if not self.print_all: |  | ||||||
|             # Print '\n' after each 100th epoch to separate dots from the log messages. |  | ||||||
|             # Otherwise output is messy on a terminal. |  | ||||||
|             print('.', end='' if results['current_epoch'] % 100 != 0 else None)  # type: ignore |  | ||||||
|             sys.stdout.flush() |  | ||||||
|  |  | ||||||
|         if self.print_all or is_best: |         if self.print_all or is_best: | ||||||
|             if not self.print_all: |             print( | ||||||
|                 # Separate the results explanation string from dots |                 self.get_result_table( | ||||||
|                 print("\n") |                     self.config, results, self.total_epochs, | ||||||
|             self.print_result_table(self.config, results, self.total_epochs, |  | ||||||
|                     self.print_all, self.print_colorized, |                     self.print_all, self.print_colorized, | ||||||
|                                     self.hyperopt_table_header) |                     self.hyperopt_table_header | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|             self.hyperopt_table_header = 2 |             self.hyperopt_table_header = 2 | ||||||
|  |  | ||||||
|  |     def get_results(self, results) -> str: | ||||||
|  |         """ | ||||||
|  |         Log results if it is better than any previous evaluation | ||||||
|  |         """ | ||||||
|  |         output = '' | ||||||
|  |         is_best = results['is_best'] | ||||||
|  |  | ||||||
|  |         if self.print_all or is_best: | ||||||
|  |             output = self.get_result_table( | ||||||
|  |                 self.config, results, self.total_epochs, | ||||||
|  |                 self.print_all, self.print_colorized, | ||||||
|  |                 self.hyperopt_table_header | ||||||
|  |             ) | ||||||
|  |             self.hyperopt_table_header = 2 | ||||||
|  |         return output | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def print_results_explanation(results, total_epochs, highlight_best: bool, |     def print_results_explanation(results, total_epochs, highlight_best: bool, | ||||||
|                                   print_colorized: bool) -> None: |                                   print_colorized: bool) -> None: | ||||||
| @@ -304,13 +316,13 @@ class Hyperopt: | |||||||
|                 f"Objective: {results['loss']:.5f}") |                 f"Objective: {results['loss']:.5f}") | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def print_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool, |     def get_result_table(config: dict, results: list, total_epochs: int, highlight_best: bool, | ||||||
|                            print_colorized: bool, remove_header: int) -> None: |                          print_colorized: bool, remove_header: int) -> str: | ||||||
|         """ |         """ | ||||||
|         Log result table |         Log result table | ||||||
|         """ |         """ | ||||||
|         if not results: |         if not results: | ||||||
|             return |             return '' | ||||||
|  |  | ||||||
|         tabulate.PRESERVE_WHITESPACE = True |         tabulate.PRESERVE_WHITESPACE = True | ||||||
|  |  | ||||||
| @@ -381,7 +393,7 @@ class Hyperopt: | |||||||
|                 trials.to_dict(orient='list'), tablefmt='psql', |                 trials.to_dict(orient='list'), tablefmt='psql', | ||||||
|                 headers='keys', stralign="right" |                 headers='keys', stralign="right" | ||||||
|             ) |             ) | ||||||
|         print(table) |         return table | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool, |     def export_csv_file(config: dict, results: list, total_epochs: int, highlight_best: bool, | ||||||
| @@ -654,13 +666,36 @@ class Hyperopt: | |||||||
|         self.dimensions: List[Dimension] = self.hyperopt_space() |         self.dimensions: List[Dimension] = self.hyperopt_space() | ||||||
|         self.opt = self.get_optimizer(self.dimensions, config_jobs) |         self.opt = self.get_optimizer(self.dimensions, config_jobs) | ||||||
|  |  | ||||||
|         if self.print_colorized: |  | ||||||
|             colorama_init(autoreset=True) |  | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             with Parallel(n_jobs=config_jobs) as parallel: |             with Parallel(n_jobs=config_jobs) as parallel: | ||||||
|                 jobs = parallel._effective_n_jobs() |                 jobs = parallel._effective_n_jobs() | ||||||
|                 logger.info(f'Effective number of parallel workers used: {jobs}') |                 logger.info(f'Effective number of parallel workers used: {jobs}') | ||||||
|  |  | ||||||
|  |                 # Define progressbar | ||||||
|  |                 if self.print_colorized: | ||||||
|  |                     widgets = [ | ||||||
|  |                         ' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs), | ||||||
|  |                         ' (', progressbar.Percentage(), ')] ', | ||||||
|  |                         progressbar.Bar(marker=progressbar.AnimatedMarker( | ||||||
|  |                             fill='\N{FULL BLOCK}', | ||||||
|  |                             fill_wrap=Fore.GREEN + '{}' + Fore.RESET, | ||||||
|  |                             marker_wrap=Style.BRIGHT + '{}' + Style.RESET_ALL, | ||||||
|  |                         )), | ||||||
|  |                         ' [', progressbar.ETA(), ', ', progressbar.Timer(), ']', | ||||||
|  |                     ] | ||||||
|  |                 else: | ||||||
|  |                     widgets = [ | ||||||
|  |                         ' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs), | ||||||
|  |                         ' (', progressbar.Percentage(), ')] ', | ||||||
|  |                         progressbar.Bar(marker=progressbar.AnimatedMarker( | ||||||
|  |                             fill='\N{FULL BLOCK}', | ||||||
|  |                         )), | ||||||
|  |                         ' [', progressbar.ETA(), ', ', progressbar.Timer(), ']', | ||||||
|  |                     ] | ||||||
|  |                 with progressbar.ProgressBar( | ||||||
|  |                          maxval=self.total_epochs, redirect_stdout=True, redirect_stderr=True, | ||||||
|  |                          widgets=widgets | ||||||
|  |                      ) as pbar: | ||||||
|                     EVALS = ceil(self.total_epochs / jobs) |                     EVALS = ceil(self.total_epochs / jobs) | ||||||
|                     for i in range(EVALS): |                     for i in range(EVALS): | ||||||
|                         # Correct the number of epochs to be processed for the last |                         # Correct the number of epochs to be processed for the last | ||||||
| @@ -673,11 +708,13 @@ class Hyperopt: | |||||||
|                         self.opt.tell(asked, [v['loss'] for v in f_val]) |                         self.opt.tell(asked, [v['loss'] for v in f_val]) | ||||||
|                         self.fix_optimizer_models_list() |                         self.fix_optimizer_models_list() | ||||||
|  |  | ||||||
|  |                         # Calculate progressbar outputs | ||||||
|                         for j, val in enumerate(f_val): |                         for j, val in enumerate(f_val): | ||||||
|                             # Use human-friendly indexes here (starting from 1) |                             # Use human-friendly indexes here (starting from 1) | ||||||
|                             current = i * jobs + j + 1 |                             current = i * jobs + j + 1 | ||||||
|                             val['current_epoch'] = current |                             val['current_epoch'] = current | ||||||
|                             val['is_initial_point'] = current <= INITIAL_POINTS |                             val['is_initial_point'] = current <= INITIAL_POINTS | ||||||
|  |  | ||||||
|                             logger.debug(f"Optimizer epoch evaluated: {val}") |                             logger.debug(f"Optimizer epoch evaluated: {val}") | ||||||
|  |  | ||||||
|                             is_best = self.is_best_loss(val, self.current_best_loss) |                             is_best = self.is_best_loss(val, self.current_best_loss) | ||||||
| @@ -686,15 +723,18 @@ class Hyperopt: | |||||||
|                             # evaluations can take different time. Here they are aligned in the |                             # evaluations can take different time. Here they are aligned in the | ||||||
|                             # order they will be shown to the user. |                             # order they will be shown to the user. | ||||||
|                             val['is_best'] = is_best |                             val['is_best'] = is_best | ||||||
|  |  | ||||||
|                             self.print_results(val) |                             self.print_results(val) | ||||||
|  |  | ||||||
|                             if is_best: |                             if is_best: | ||||||
|                                 self.current_best_loss = val['loss'] |                                 self.current_best_loss = val['loss'] | ||||||
|                             self.trials.append(val) |                             self.trials.append(val) | ||||||
|  |  | ||||||
|                             # Save results after each best epoch and every 100 epochs |                             # Save results after each best epoch and every 100 epochs | ||||||
|                             if is_best or current % 100 == 0: |                             if is_best or current % 100 == 0: | ||||||
|                                 self.save_trials() |                                 self.save_trials() | ||||||
|  |  | ||||||
|  |                             pbar.update(current) | ||||||
|  |  | ||||||
|         except KeyboardInterrupt: |         except KeyboardInterrupt: | ||||||
|             print('User interrupted..') |             print('User interrupted..') | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,3 +7,4 @@ scikit-learn==0.22.2.post1 | |||||||
| scikit-optimize==0.7.4 | scikit-optimize==0.7.4 | ||||||
| filelock==3.0.12 | filelock==3.0.12 | ||||||
| joblib==0.14.1 | joblib==0.14.1 | ||||||
|  | progressbar2==3.50.1 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user