Replace hyperopt progressbar with rich progressbar
This commit is contained in:
parent
299e788891
commit
bfd9e35e34
@ -13,13 +13,13 @@ from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import progressbar
|
||||
import rapidjson
|
||||
from colorama import Fore, Style
|
||||
from colorama import init as colorama_init
|
||||
from joblib import Parallel, cpu_count, delayed, dump, load, wrap_non_picklable_objects
|
||||
from joblib.externals import cloudpickle
|
||||
from pandas import DataFrame
|
||||
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, TaskProgressColumn, TextColumn,
|
||||
TimeElapsedColumn, TimeRemainingColumn)
|
||||
|
||||
from freqtrade.constants import DATETIME_PRINT_FORMAT, FTHYPT_FILEVERSION, LAST_BT_RESULT_FN, Config
|
||||
from freqtrade.data.converter import trim_dataframes
|
||||
@ -44,8 +44,6 @@ with warnings.catch_warnings():
|
||||
from skopt import Optimizer
|
||||
from skopt.space import Dimension
|
||||
|
||||
progressbar.streams.wrap_stderr()
|
||||
progressbar.streams.wrap_stdout()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -520,29 +518,6 @@ class Hyperopt:
|
||||
else:
|
||||
return self.opt.ask(n_points=n_points), [False for _ in range(n_points)]
|
||||
|
||||
def get_progressbar_widgets(self):
|
||||
if self.print_colorized:
|
||||
widgets = [
|
||||
' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
|
||||
' (', progressbar.Percentage(), ')] ',
|
||||
progressbar.Bar(marker=progressbar.AnimatedMarker(
|
||||
fill='\N{FULL BLOCK}',
|
||||
fill_wrap=Fore.GREEN + '{}' + Fore.RESET,
|
||||
marker_wrap=Style.BRIGHT + '{}' + Style.RESET_ALL,
|
||||
)),
|
||||
' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
|
||||
]
|
||||
else:
|
||||
widgets = [
|
||||
' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
|
||||
' (', progressbar.Percentage(), ')] ',
|
||||
progressbar.Bar(marker=progressbar.AnimatedMarker(
|
||||
fill='\N{FULL BLOCK}',
|
||||
)),
|
||||
' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
|
||||
]
|
||||
return widgets
|
||||
|
||||
def evaluate_result(self, val: Dict[str, Any], current: int, is_random: bool):
|
||||
"""
|
||||
Evaluate results returned from generate_optimizer
|
||||
@ -602,11 +577,18 @@ class Hyperopt:
|
||||
logger.info(f'Effective number of parallel workers used: {jobs}')
|
||||
|
||||
# Define progressbar
|
||||
widgets = self.get_progressbar_widgets()
|
||||
with progressbar.ProgressBar(
|
||||
max_value=self.total_epochs, redirect_stdout=False, redirect_stderr=False,
|
||||
widgets=widgets
|
||||
with Progress(
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(bar_width=None),
|
||||
MofNCompleteColumn(),
|
||||
TaskProgressColumn(),
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
expand=True,
|
||||
) as pbar:
|
||||
task = pbar.add_task("Epochs", total=self.total_epochs)
|
||||
|
||||
start = 0
|
||||
|
||||
if self.analyze_per_epoch:
|
||||
@ -616,7 +598,7 @@ class Hyperopt:
|
||||
f_val0 = self.generate_optimizer(asked[0])
|
||||
self.opt.tell(asked, [f_val0['loss']])
|
||||
self.evaluate_result(f_val0, 1, is_random[0])
|
||||
pbar.update(1)
|
||||
pbar.update(task, advance=1)
|
||||
start += 1
|
||||
|
||||
evals = ceil((self.total_epochs - start) / jobs)
|
||||
@ -630,14 +612,12 @@ class Hyperopt:
|
||||
f_val = self.run_optimizer_parallel(parallel, asked)
|
||||
self.opt.tell(asked, [v['loss'] for v in f_val])
|
||||
|
||||
# Calculate progressbar outputs
|
||||
for j, val in enumerate(f_val):
|
||||
# Use human-friendly indexes here (starting from 1)
|
||||
current = i * jobs + j + 1 + start
|
||||
|
||||
self.evaluate_result(val, current, is_random[j])
|
||||
|
||||
pbar.update(current)
|
||||
pbar.update(task, advance=1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print('User interrupted..')
|
||||
|
Loading…
Reference in New Issue
Block a user