refactor environment inheritance tree to accommodate flexible action types/counts. fix bug in train profit handling

robcaulk
2022-08-28 19:21:57 +02:00
parent 8c313b431d
commit 7766350c15
8 changed files with 339 additions and 440 deletions
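The refactor splits the environment logic so that the parent class no longer assumes a fixed number of actions. Based on the imports changed in the diff below (`Positions` now comes from `BaseEnvironment`, while `Base5ActionRLEnv` keeps its `Actions` enum), the hierarchy could look roughly like the following sketch. The method names `set_action_space` and `calculate_reward`, and all bodies, are illustrative assumptions, not the repository's exact API.

```python
# Hypothetical sketch of the refactored hierarchy. Only the class/enum names
# that appear in the diff are taken from the commit; everything else is assumed.
from abc import abstractmethod
from enum import Enum

import gym
from gym import spaces


class Positions(Enum):
    Short = 0
    Long = 1
    Neutral = 0.5


class BaseEnvironment(gym.Env):
    """Action-count-agnostic parent: shared trading state, abstract action space."""

    def __init__(self, window_size: int = 10):
        self.window_size = window_size
        self._position = Positions.Neutral
        self.set_action_space()  # delegated to the concrete subclass

    @abstractmethod
    def set_action_space(self):
        """Each concrete environment declares its own action count."""

    @abstractmethod
    def calculate_reward(self, action: int) -> float:
        """Reward shaping is left to the concrete environment."""


class Actions(Enum):
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4


class Base5ActionRLEnv(BaseEnvironment):
    """Concrete child that fixes the action count at five."""

    def set_action_space(self):
        self.action_space = spaces.Discrete(len(Actions))

    def calculate_reward(self, action: int) -> float:
        return 0.0  # placeholder reward for the sketch
```

With this layout, adding an environment with a different action count only requires a new `Actions` enum and `set_action_space`; the shared position and trade bookkeeping stays in the parent.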


@@ -1,25 +1,28 @@
 import logging
-from typing import Any, Dict, Tuple
+from abc import abstractmethod
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Callable, Dict, Tuple
+import gym
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+import torch as th
+import torch.multiprocessing
 from pandas import DataFrame
-from abc import abstractmethod
+from stable_baselines3.common.callbacks import EvalCallback
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.utils import set_random_seed
 from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
-from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv, Actions, Positions
+from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
+from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
 from freqtrade.persistence import Trade
-import torch.multiprocessing
-from stable_baselines3.common.callbacks import EvalCallback
-from stable_baselines3.common.monitor import Monitor
-import torch as th
-from typing import Callable
-from datetime import datetime, timezone
-from stable_baselines3.common.utils import set_random_seed
-import gym
-from pathlib import Path
 logger = logging.getLogger(__name__)
 torch.multiprocessing.set_sharing_strategy('file_system')
@@ -37,8 +40,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         super().__init__(config=kwargs['config'])
         th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
-        self.train_env: Base5ActionRLEnv = None
-        self.eval_env: Base5ActionRLEnv = None
+        self.train_env: BaseEnvironment = None
+        self.eval_env: BaseEnvironment = None
         self.eval_callback: EvalCallback = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
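Widening the attribute annotations from `Base5ActionRLEnv` to `BaseEnvironment` is what lets the same model wrapper drive environments with any action count. Continuing the sketch above, a minimal illustration follows; `Base3ActionRLEnv`, `ModelWrapper`, and the `Optional` wrapping are assumptions for demonstration only.

```python
from typing import Optional

# Hypothetical three-action variant, only to show why the parent type is used.
class Base3ActionRLEnv(BaseEnvironment):
    def set_action_space(self):
        self.action_space = spaces.Discrete(3)

    def calculate_reward(self, action: int) -> float:
        return 0.0


class ModelWrapper:
    def __init__(self):
        # Annotated with the base class, so any action-count variant fits;
        # Optional[...] reflects that the environments are attached later.
        self.train_env: Optional[BaseEnvironment] = None
        self.eval_env: Optional[BaseEnvironment] = None


wrapper = ModelWrapper()
wrapper.train_env = Base5ActionRLEnv()  # five-action variant
wrapper.eval_env = Base3ActionRLEnv()   # hypothetical three-action variant
```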
@@ -194,7 +197,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         def _predict(window):
             market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
             observations = dataframe.iloc[window.index]
-            observations['current_profit'] = current_profit
+            observations['current_profit_pct'] = current_profit
             observations['position'] = market_side
             observations['trade_duration'] = trade_duration
             res, _ = model.predict(observations, deterministic=True)
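The one-line change renames the live-profit column to `current_profit_pct`, presumably so the inference-time feature name matches the column the training environment emits (the "train profit handling" bug from the commit message). A minimal sketch of the pattern, assuming `model` is a trained stable-baselines3 policy and the state tuple matches the hunk above; the helper function itself is hypothetical.

```python
import pandas as pd


def predict_with_state(model, window_features: pd.DataFrame,
                       market_side: float, current_profit: float,
                       trade_duration: int):
    """Hypothetical helper: append live trade state to the feature window
    before asking the policy for an action."""
    observations = window_features.copy()
    # Column names must match what the training environment produced,
    # hence 'current_profit_pct' rather than 'current_profit'.
    observations['current_profit_pct'] = current_profit
    observations['position'] = market_side
    observations['trade_duration'] = trade_duration
    action, _ = model.predict(observations, deterministic=True)
    return action
```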
@@ -306,7 +309,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         return


-def make_env(MyRLEnv: Base5ActionRLEnv, env_id: str, rank: int,
+def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
              reward_params: Dict[str, int], window_size: int, monitor: bool = False,
              config: Dict[str, Any] = {}) -> Callable:
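With the signature widened, `make_env` can build any `BaseEnvironment` subclass. The function body is not shown in this hunk, so the following is a hedged sketch of how a factory with this signature is commonly written and consumed with stable-baselines3; the environment constructor kwargs, `num_cpu`, and the commented usage are assumptions.

```python
from typing import Any, Callable, Dict

import gym
from pandas import DataFrame
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env(MyRLEnv, env_id: str, rank: int,
             seed: int, train_df: DataFrame, price: DataFrame,
             reward_params: Dict[str, int], window_size: int,
             monitor: bool = False, config: Dict[str, Any] = {}) -> Callable:
    """Sketch: return a zero-argument factory that builds one environment.
    The constructor kwargs below are assumptions, not the repository's exact API."""

    def _init() -> gym.Env:
        env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
                      reward_kwargs=reward_params, id=env_id,
                      seed=seed + rank, config=config)
        if monitor:
            env = Monitor(env)  # record episode stats for the EvalCallback
        return env

    set_random_seed(seed)
    return _init


# Typical consumption: one factory per worker rank, wrapped in a vectorized env.
# train_env = SubprocVecEnv([
#     make_env(Base5ActionRLEnv, "train_env", i, seed=42, train_df=df,
#              price=prices, reward_params=params, window_size=10, monitor=True)
#     for i in range(num_cpu)
# ])
```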