refactor environment inheritence tree to accommodate flexible action types/counts. fix bug in train profit handling

This commit is contained in:
robcaulk
2022-08-28 19:21:57 +02:00
parent 8c313b431d
commit 7766350c15
8 changed files with 339 additions and 440 deletions

View File

@@ -1,15 +1,14 @@
import logging
from pathlib import Path
from typing import Any, Dict
import torch as th
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Positions
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
from pathlib import Path
# from pandas import DataFrame
# from stable_baselines3.common.callbacks import EvalCallback
# from stable_baselines3.common.monitor import Monitor
import numpy as np
import torch as th
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
logger = logging.getLogger(__name__)
@@ -53,7 +52,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
return model
class MyRLEnv(BaseReinforcementLearningModel.MyRLEnv):
class MyRLEnv(Base5ActionRLEnv):
"""
User can override any function in BaseRLEnv and gym.Env. Here the user
sets a custom reward based on profit and trade duration.

View File

@@ -1,15 +1,16 @@
import logging
from pathlib import Path
from typing import Any, Dict # , Tuple
# import numpy.typing as npt
import torch as th
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.BaseReinforcementLearningModel import (BaseReinforcementLearningModel,
make_env)
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from pathlib import Path
logger = logging.getLogger(__name__)
@@ -26,7 +27,7 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
# model arch
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[256, 256])
net_arch=[256, 256, 128])
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
@@ -64,9 +65,9 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
test_df = data_dictionary["test_features"]
env_id = "train_env"
num_cpu = int(self.freqai_info["rl_config"]["thread_count"] / 2)
num_cpu = int(self.freqai_info["rl_config"]["thread_count"])
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
self.reward_params, self.CONV_WIDTH,
self.reward_params, self.CONV_WIDTH, monitor=True,
config=self.config) for i
in range(num_cpu)])