refactor environment inheritance tree to accommodate flexible action types/counts. fix bug in train profit handling
@@ -1,25 +1,28 @@
 import logging
-from typing import Any, Dict, Tuple
+from abc import abstractmethod
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Callable, Dict, Tuple
+
+import gym
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+import torch as th
+import torch.multiprocessing
 from pandas import DataFrame
-from abc import abstractmethod
+from stable_baselines3.common.callbacks import EvalCallback
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.utils import set_random_seed
+
 from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
-from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv, Actions, Positions
+from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
+from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
 from freqtrade.persistence import Trade
-import torch.multiprocessing
-from stable_baselines3.common.callbacks import EvalCallback
-from stable_baselines3.common.monitor import Monitor
-import torch as th
-from typing import Callable
-from datetime import datetime, timezone
-from stable_baselines3.common.utils import set_random_seed
-import gym
-from pathlib import Path
+

 logger = logging.getLogger(__name__)

 torch.multiprocessing.set_sharing_strategy('file_system')
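Note on the import swap above: the model layer now programs against a new BaseEnvironment parent (which also takes over Positions), while Base5ActionRLEnv remains one concrete action flavour. The diff only shows the imports, so the following sketch of the intended hierarchy is an assumption; SketchBaseEnv, set_action_space and the enum values are illustrative names, not the freqtrade API.

from abc import abstractmethod
from enum import Enum

import gym
from gym import spaces


class SketchBaseEnv(gym.Env):
    """Action-agnostic plumbing (data windows, positions, stepping) lives here."""

    def __init__(self, window_size: int = 10):
        self.window_size = window_size
        self.set_action_space()  # subclass decides how many actions exist

    @abstractmethod
    def set_action_space(self):
        raise NotImplementedError


class SketchFiveActionEnv(SketchBaseEnv):
    """One concrete flavour; a 3- or 4-action env would subclass the same base."""

    class Actions(Enum):
        Neutral = 0
        Long_enter = 1
        Long_exit = 2
        Short_enter = 3
        Short_exit = 4

    def set_action_space(self):
        self.action_space = spaces.Discrete(len(self.Actions))


env = SketchFiveActionEnv()
assert env.action_space.n == 5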
@@ -37,8 +40,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         super().__init__(config=kwargs['config'])
         th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
-        self.train_env: Base5ActionRLEnv = None
-        self.eval_env: Base5ActionRLEnv = None
+        self.train_env: BaseEnvironment = None
+        self.eval_env: BaseEnvironment = None
         self.eval_callback: EvalCallback = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
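The widened annotations are the payoff of the refactor: train_env and eval_env no longer hard-code the five-action environment, so any BaseEnvironment subclass can be assigned (strictly speaking Optional[BaseEnvironment] would be the precise type, since both start life as None). A toy illustration with stand-in class names:

from typing import Optional


class EnvBase:                 # stand-in for BaseEnvironment
    n_actions: int = 0


class ThreeActionEnv(EnvBase):
    n_actions = 3


class FiveActionEnv(EnvBase):
    n_actions = 5


train_env: Optional[EnvBase] = None
train_env = ThreeActionEnv()   # accepted
train_env = FiveActionEnv()    # also accepted: the model layer is agnostic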
@@ -194,7 +197,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         def _predict(window):
             market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
             observations = dataframe.iloc[window.index]
-            observations['current_profit'] = current_profit
+            observations['current_profit_pct'] = current_profit
             observations['position'] = market_side
             observations['trade_duration'] = trade_duration
             res, _ = model.predict(observations, deterministic=True)
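This rename is the "fix bug in train profit handling" half of the commit: the prediction path wrote the profit state under current_profit while the training environment presumably exposes it as current_profit_pct, so the live observation frame did not line up with the features the model was trained on. Below is a self-contained sketch of that failure mode plus a cheap guard; apart from the three column names set in the hunk, the names are illustrative.

import pandas as pd

train_cols = ['%-feature_a', 'current_profit_pct', 'position', 'trade_duration']

observations = pd.DataFrame({'%-feature_a': [0.1]})
observations['current_profit_pct'] = 0.02   # must match the training-time name
observations['position'] = 1.0
observations['trade_duration'] = 15

# a cheap parity check before calling model.predict(...)
missing = set(train_cols) - set(observations.columns)
assert not missing, f"predict-time frame is missing train features: {missing}"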
@@ -306,7 +309,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             return


-def make_env(MyRLEnv: Base5ActionRLEnv, env_id: str, rank: int,
+def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
              reward_params: Dict[str, int], window_size: int, monitor: bool = False,
              config: Dict[str, Any] = {}) -> Callable:
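make_env returns a Callable because stable-baselines3's vectorised wrappers (SubprocVecEnv/DummyVecEnv) expect one zero-argument factory per worker. The body is outside this hunk, so the following is only a sketch of the standard pattern the signature implies, using a stock gym env as a stand-in for MyRLEnv.

from typing import Callable

import gym
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed


def make_env_sketch(env_id: str, rank: int, seed: int,
                    monitor: bool = False) -> Callable:
    def _init() -> gym.Env:
        env = gym.make(env_id)   # real code would build MyRLEnv(train_df, price, ...)
        env.seed(seed + rank)    # distinct seed per worker (gym<0.26 API)
        return Monitor(env) if monitor else env
    set_random_seed(seed)
    return _init


# one factory per parallel worker, as a vectorised wrapper expects:
env_fns = [make_env_sketch("CartPole-v1", rank=i, seed=42) for i in range(4)]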