add continual retraining feature, handly mypy typing reqs, improve docstrings

This commit is contained in:
robcaulk
2022-08-24 12:54:02 +02:00
parent b708134c1a
commit c0cee5df07
11 changed files with 387 additions and 362 deletions

View File

@@ -1,7 +1,6 @@
import logging
from typing import Any, Dict # , Tuple
from typing import Any, Dict
# import numpy.typing as npt
import torch as th
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
@@ -22,12 +21,18 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[256, 256, 128])
net_arch=[512, 512, 256])
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
tensorboard_log=Path(dk.data_path / "tensorboard"),
**self.freqai_info['model_training_parameters']
)
if dk.pair not in self.dd.model_dictionary or not self.continual_retraining:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
tensorboard_log=Path(dk.data_path / "tensorboard"),
**self.freqai_info['model_training_parameters']
)
else:
logger.info('Continual training activated - starting training from previously '
'trained agent.')
model = self.dd.model_dictionary[dk.pair]
model.set_env(self.train_env)
model.learn(
total_timesteps=int(total_timesteps),