add continual retraining feature, handle mypy typing reqs, improve docstrings
@@ -1,7 +1,6 @@
 import logging
-from typing import Any, Dict  # , Tuple
+from typing import Any, Dict
 
-# import numpy.typing as npt
 import torch as th
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
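
Note: the torch import above feeds the policy_kwargs built in the second hunk below. For reference, a minimal sketch of how a stable-baselines3 agent consumes those kwargs; PPO and "CartPole-v1" are stand-ins for self.MODELCLASS and the FreqAI training environment, not what this diff actually wires up.

# Sketch only: assumes stable-baselines3 is installed. PPO and
# "CartPole-v1" stand in for self.MODELCLASS and the FreqAI train env.
import torch as th
from stable_baselines3 import PPO

policy_kwargs = dict(activation_fn=th.nn.ReLU,  # ReLU activations
                     net_arch=[512, 512, 256])  # widened MLP from this diff

model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=0)
model.learn(total_timesteps=1_000)
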
@@ -22,12 +21,18 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
 
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[256, 256, 128])
+                             net_arch=[512, 512, 256])
 
-        model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
-                                tensorboard_log=Path(dk.data_path / "tensorboard"),
-                                **self.freqai_info['model_training_parameters']
-                                )
+        if dk.pair not in self.dd.model_dictionary or not self.continual_retraining:
+            model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
+                                    tensorboard_log=Path(dk.data_path / "tensorboard"),
+                                    **self.freqai_info['model_training_parameters']
+                                    )
+        else:
+            logger.info('Continual training activated - starting training from previously '
+                        'trained agent.')
+            model = self.dd.model_dictionary[dk.pair]
+            model.set_env(self.train_env)
 
         model.learn(
             total_timesteps=int(total_timesteps),
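
For context, the warm-start pattern introduced above can be reproduced with stable-baselines3 alone. A minimal sketch follows, under stated assumptions: the dict model_store plays the role of self.dd.model_dictionary, "BTC/USDT" is a hypothetical pair key, and PPO/gym stand in for self.MODELCLASS and the FreqAI training environment.

# Minimal sketch of the continual-retraining branch above, assuming
# stable-baselines3 and gym. model_store mimics self.dd.model_dictionary.
import gym
from stable_baselines3 import PPO

model_store = {}           # pair -> previously trained agent
pair = "BTC/USDT"          # hypothetical pair key
continual_retraining = True

env = gym.make("CartPole-v1")

if pair not in model_store or not continual_retraining:
    # First training window (or feature disabled): build a fresh agent.
    model = PPO("MlpPolicy", env, verbose=0)
else:
    # Warm start: reuse the stored agent's weights, attach the new
    # environment, and continue training from there.
    model = model_store[pair]
    model.set_env(env)

model.learn(total_timesteps=10_000)
model_store[pair] = model  # persist for the next retraining cycle

One design note: reusing the stored model carries over the optimizer state as well as the network weights, so each retraining window continues from the previous agent instead of a random initialization; that is presumably why the diff only constructs self.MODELCLASS when no agent exists for the pair or the feature is disabled.
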