reinforce training with state info, reinforce prediction with state info, restructure config to accommodate all parameters from any user imported model type. Set 5Act to default env on TDQN. Clean example config.
This commit is contained in:
@@ -6,6 +6,7 @@ import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
from pandas import DataFrame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +36,8 @@ class Base3ActionRLEnv(gym.Env):
|
||||
|
||||
metadata = {'render.modes': ['human']}
|
||||
|
||||
def __init__(self, df, prices, reward_kwargs, window_size=10, starting_point=True,
|
||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||
id: str = 'baseenv-1', seed: int = 1):
|
||||
assert df.ndim == 2
|
||||
|
||||
|
@@ -6,6 +6,7 @@ import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
from pandas import DataFrame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,7 +40,8 @@ class Base5ActionRLEnv(gym.Env):
|
||||
"""
|
||||
metadata = {'render.modes': ['human']}
|
||||
|
||||
def __init__(self, df, prices, reward_kwargs, window_size=10, starting_point=True,
|
||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||
id: str = 'baseenv-1', seed: int = 1):
|
||||
assert df.ndim == 2
|
||||
|
||||
@@ -56,7 +58,7 @@ class Base5ActionRLEnv(gym.Env):
|
||||
self.fee = 0.0015
|
||||
|
||||
# # spaces
|
||||
self.shape = (window_size, self.signal_features.shape[1])
|
||||
self.shape = (window_size, self.signal_features.shape[1] + 2)
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
|
||||
@@ -161,19 +163,26 @@ class Base5ActionRLEnv(gym.Env):
|
||||
self._done = True
|
||||
|
||||
self._position_history.append(self._position)
|
||||
observation = self._get_observation()
|
||||
|
||||
info = dict(
|
||||
tick=self._current_tick,
|
||||
total_reward=self.total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value
|
||||
)
|
||||
|
||||
observation = self._get_observation()
|
||||
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
|
||||
def _get_observation(self):
|
||||
return self.signal_features[(self._current_tick - self.window_size):self._current_tick]
|
||||
features_and_state = self.signal_features[(
|
||||
self._current_tick - self.window_size):self._current_tick]
|
||||
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
||||
features_and_state['position'] = self._position.value
|
||||
return features_and_state
|
||||
|
||||
def get_unrealized_profit(self):
|
||||
|
||||
|
@@ -13,7 +13,7 @@ from freqtrade.persistence import Trade
|
||||
import torch.multiprocessing
|
||||
import torch as th
|
||||
logger = logging.getLogger(__name__)
|
||||
th.set_num_threads(8)
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
|
||||
|
||||
@@ -22,6 +22,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
User created Reinforcement Learning Model prediction model.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(config=kwargs['config'])
|
||||
th.set_num_threads(self.freqai_info.get('data_kitchen_thread_count', 4))
|
||||
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
||||
|
||||
def train(
|
||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
||||
) -> Any:
|
||||
@@ -62,12 +67,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
model = self.fit_rl(data_dictionary, pair, dk, prices_train, prices_test)
|
||||
|
||||
if pair not in self.dd.historic_predictions:
|
||||
self.set_initial_historic_predictions(
|
||||
data_dictionary['train_features'], model, dk, pair)
|
||||
|
||||
self.dd.save_historic_predictions_to_disk()
|
||||
|
||||
logger.info(f"--------------------done training {pair}--------------------")
|
||||
|
||||
return model
|
||||
@@ -127,7 +126,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
# optional additional data cleaning/analysis
|
||||
self.data_cleaning_predict(dk, filtered_dataframe)
|
||||
|
||||
pred_df = self.rl_model_predict(dk.data_dictionary["prediction_features"], dk, self.model)
|
||||
pred_df = self.rl_model_predict(
|
||||
dk.data_dictionary["prediction_features"], dk, self.model)
|
||||
pred_df.fillna(0, inplace=True)
|
||||
|
||||
return (pred_df, dk.do_predict)
|
||||
@@ -135,10 +135,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
def rl_model_predict(self, dataframe: DataFrame,
|
||||
dk: FreqaiDataKitchen, model: Any) -> DataFrame:
|
||||
|
||||
output = pd.DataFrame(np.full((len(dataframe), 1), 2), columns=dk.label_list)
|
||||
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
|
||||
|
||||
def _predict(window):
|
||||
market_side, current_profit, total_profit = self.get_state_info(dk.pair)
|
||||
observations = dataframe.iloc[window.index]
|
||||
observations['current_profit'] = current_profit
|
||||
observations['position'] = market_side
|
||||
res, _ = model.predict(observations, deterministic=True)
|
||||
return res
|
||||
|
||||
@@ -174,29 +177,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
return prices_train, prices_test
|
||||
|
||||
def set_initial_historic_predictions(
|
||||
self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str
|
||||
) -> None:
|
||||
|
||||
pred_df = self.rl_model_predict(df, dk, model)
|
||||
pred_df.fillna(0, inplace=True)
|
||||
self.dd.historic_predictions[pair] = pred_df
|
||||
hist_preds_df = self.dd.historic_predictions[pair]
|
||||
|
||||
for label in hist_preds_df.columns:
|
||||
if hist_preds_df[label].dtype == object:
|
||||
continue
|
||||
hist_preds_df[f'{label}_mean'] = 0
|
||||
hist_preds_df[f'{label}_std'] = 0
|
||||
|
||||
hist_preds_df['do_predict'] = 0
|
||||
|
||||
if self.freqai_info['feature_parameters'].get('DI_threshold', 0) > 0:
|
||||
hist_preds_df['DI_values'] = 0
|
||||
|
||||
for return_str in dk.data['extra_returns_per_train']:
|
||||
hist_preds_df[return_str] = 0
|
||||
|
||||
# TODO take care of this appendage. Right now it needs to be called because FreqAI enforces it.
|
||||
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
|
||||
# all the other existing fit() functions to include dk argument. For now we instantiate and
|
||||
|
Reference in New Issue
Block a user