Working base for reinforcement learning model
This commit is contained in:
parent
a6d78a8615
commit
05ed1b544f
@ -123,8 +123,8 @@ Mandatory parameters are marked as **Required**, which means that they are requi
|
||||
| `learning_rate` | Boosting learning rate during regression. <br> **Datatype:** Float.
|
||||
| `n_jobs`, `thread_count`, `task_type` | Set the number of threads for parallel processing and the `task_type` (`gpu` or `cpu`). Different model libraries use different parameter names. <br> **Datatype:** Float.
|
||||
| | **Extraneous parameters**
|
||||
| `keras` | If your model makes use of Keras (typical for Tensorflow-based prediction models), activate this flag so that the model save/loading follows Keras standards. <br> **Datatype:** Boolean. Default: `False`.
|
||||
| `conv_width` | The width of a convolutional neural network input tensor. This replaces the need for shifting candles (`include_shifted_candles`) by feeding in historical data points as the second dimension of the tensor. Technically, this parameter can also be used for regressors, but it only adds computational overhead and does not change the model training/prediction. <br> **Datatype:** Integer. Default: 2.
|
||||
| `keras` | If your model makes use of keras (typical of Tensorflow based prediction models), activate this flag so that the model save/loading follows keras standards. Default value `false` <br> **Datatype:** boolean.
|
||||
| `conv_width` | The width of a convolutional neural network input tensor or the `ReinforcementLearningModel` `window_size`. This replaces the need for `shift` by feeding in historical data points as the second dimension of the tensor. Technically, this parameter can also be used for regressors, but it only adds computational overhead and does not change the model training/prediction. Default value, 2 <br> **Datatype:** integer.
|
||||
|
||||
### Important dataframe key patterns
|
||||
|
||||
|
@ -520,10 +520,7 @@ CONF_SCHEMA = {
|
||||
},
|
||||
},
|
||||
"model_training_parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n_estimators": {"type": "integer", "default": 1000}
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
|
@ -390,10 +390,13 @@ class FreqaiDataDrawer:
|
||||
save_path = Path(dk.data_path)
|
||||
|
||||
# Save the trained model
|
||||
if not dk.keras:
|
||||
model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
if model_type == 'joblib':
|
||||
dump(model, save_path / f"{dk.model_filename}_model.joblib")
|
||||
else:
|
||||
elif model_type == 'keras':
|
||||
model.save(save_path / f"{dk.model_filename}_model.h5")
|
||||
elif model_type == 'stable_baselines':
|
||||
model.save(save_path / f"{dk.model_filename}_model.zip")
|
||||
|
||||
if dk.svm_model is not None:
|
||||
dump(dk.svm_model, save_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
@ -459,15 +462,18 @@ class FreqaiDataDrawer:
|
||||
dk.data_path / f"{dk.model_filename}_trained_df.pkl"
|
||||
)
|
||||
|
||||
model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
# try to access model in memory instead of loading object from disk to save time
|
||||
if dk.live and coin in self.model_dictionary:
|
||||
model = self.model_dictionary[coin]
|
||||
elif not dk.keras:
|
||||
elif model_type == 'joblib':
|
||||
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
||||
else:
|
||||
elif model_type == 'keras':
|
||||
from tensorflow import keras
|
||||
|
||||
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
|
||||
elif model_type == 'stable_baselines':
|
||||
from stable_baselines3.ppo.ppo import PPO
|
||||
model = PPO.load(dk.data_path / f"{dk.model_filename}_model.zip")
|
||||
|
||||
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
|
||||
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
|
147
freqtrade/freqai/example_strats/ReinforcementLearningExample.py
Normal file
147
freqtrade/freqai/example_strats/ReinforcementLearningExample.py
Normal file
@ -0,0 +1,147 @@
|
||||
import logging
|
||||
from functools import reduce
|
||||
|
||||
import pandas as pd
|
||||
import talib.abstract as ta
|
||||
from pandas import DataFrame
|
||||
|
||||
from freqtrade.strategy import DecimalParameter, IntParameter, IStrategy, merge_informative_pair
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReinforcementLearningExample(IStrategy):
|
||||
"""
|
||||
Test strategy - used for testing freqAI functionalities.
|
||||
DO not use in production.
|
||||
"""
|
||||
|
||||
minimal_roi = {"0": 0.1, "240": -1}
|
||||
|
||||
plot_config = {
|
||||
"main_plot": {},
|
||||
"subplots": {
|
||||
"prediction": {"prediction": {"color": "blue"}},
|
||||
"target_roi": {
|
||||
"target_roi": {"color": "brown"},
|
||||
},
|
||||
"do_predict": {
|
||||
"do_predict": {"color": "brown"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
process_only_new_candles = True
|
||||
stoploss = -0.05
|
||||
use_exit_signal = True
|
||||
startup_candle_count: int = 300
|
||||
can_short = False
|
||||
|
||||
linear_roi_offset = DecimalParameter(
|
||||
0.00, 0.02, default=0.005, space="sell", optimize=False, load=True
|
||||
)
|
||||
max_roi_time_long = IntParameter(0, 800, default=400, space="sell", optimize=False, load=True)
|
||||
|
||||
def informative_pairs(self):
|
||||
whitelist_pairs = self.dp.current_whitelist()
|
||||
corr_pairs = self.config["freqai"]["feature_parameters"]["include_corr_pairlist"]
|
||||
informative_pairs = []
|
||||
for tf in self.config["freqai"]["feature_parameters"]["include_timeframes"]:
|
||||
for pair in whitelist_pairs:
|
||||
informative_pairs.append((pair, tf))
|
||||
for pair in corr_pairs:
|
||||
if pair in whitelist_pairs:
|
||||
continue # avoid duplication
|
||||
informative_pairs.append((pair, tf))
|
||||
return informative_pairs
|
||||
|
||||
def populate_any_indicators(
|
||||
self, pair, df, tf, informative=None, set_generalized_indicators=False
|
||||
):
|
||||
|
||||
coin = pair.split('/')[0]
|
||||
|
||||
with self.freqai.lock:
|
||||
if informative is None:
|
||||
informative = self.dp.get_pair_dataframe(pair, tf)
|
||||
|
||||
# first loop is automatically duplicating indicators for time periods
|
||||
for t in self.freqai_info["feature_parameters"]["indicator_periods_candles"]:
|
||||
|
||||
t = int(t)
|
||||
informative[f"%-{coin}rsi-period_{t}"] = ta.RSI(informative, timeperiod=t)
|
||||
informative[f"%-{coin}mfi-period_{t}"] = ta.MFI(informative, timeperiod=t)
|
||||
informative[f"%-{coin}adx-period_{t}"] = ta.ADX(informative, window=t)
|
||||
|
||||
informative[f"%-{coin}pct-change"] = informative["close"].pct_change()
|
||||
informative[f"%-{coin}raw_volume"] = informative["volume"]
|
||||
|
||||
# Raw price currently necessary for RL models:
|
||||
informative[f"%-{coin}raw_price"] = informative["close"]
|
||||
|
||||
indicators = [col for col in informative if col.startswith("%")]
|
||||
# This loop duplicates and shifts all indicators to add a sense of recency to data
|
||||
for n in range(self.freqai_info["feature_parameters"]["include_shifted_candles"] + 1):
|
||||
if n == 0:
|
||||
continue
|
||||
informative_shift = informative[indicators].shift(n)
|
||||
informative_shift = informative_shift.add_suffix("_shift-" + str(n))
|
||||
informative = pd.concat((informative, informative_shift), axis=1)
|
||||
|
||||
df = merge_informative_pair(df, informative, self.config["timeframe"], tf, ffill=True)
|
||||
skip_columns = [
|
||||
(s + "_" + tf) for s in ["date", "open", "high", "low", "close", "volume"]
|
||||
]
|
||||
df = df.drop(columns=skip_columns)
|
||||
|
||||
# Add generalized indicators here (because in live, it will call this
|
||||
# function to populate indicators during training). Notice how we ensure not to
|
||||
# add them multiple times
|
||||
if set_generalized_indicators:
|
||||
df["%-day_of_week"] = (df["date"].dt.dayofweek + 1) / 7
|
||||
df["%-hour_of_day"] = (df["date"].dt.hour + 1) / 25
|
||||
|
||||
# user adds targets here by prepending them with &- (see convention below)
|
||||
# If user wishes to use multiple targets, a multioutput prediction model
|
||||
# needs to be used such as templates/CatboostPredictionMultiModel.py
|
||||
df["&-action"] = 2
|
||||
|
||||
return df
|
||||
|
||||
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
|
||||
|
||||
self.freqai_info = self.config["freqai"]
|
||||
|
||||
dataframe = self.freqai.start(dataframe, metadata, self)
|
||||
|
||||
return dataframe
|
||||
|
||||
def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
|
||||
|
||||
enter_long_conditions = [df["do_predict"] == 1, df["&-action"] == 1]
|
||||
|
||||
if enter_long_conditions:
|
||||
df.loc[
|
||||
reduce(lambda x, y: x & y, enter_long_conditions), ["enter_long", "enter_tag"]
|
||||
] = (1, "long")
|
||||
|
||||
enter_short_conditions = [df["do_predict"] == 1, df["&-action"] == 2]
|
||||
|
||||
if enter_short_conditions:
|
||||
df.loc[
|
||||
reduce(lambda x, y: x & y, enter_short_conditions), ["enter_short", "enter_tag"]
|
||||
] = (1, "short")
|
||||
|
||||
return df
|
||||
|
||||
def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
|
||||
exit_long_conditions = [df["do_predict"] == 1, df["&-action"] == 2]
|
||||
if exit_long_conditions:
|
||||
df.loc[reduce(lambda x, y: x & y, exit_long_conditions), "exit_long"] = 1
|
||||
|
||||
exit_short_conditions = [df["do_predict"] == 1, df["&-action"] == 1]
|
||||
if exit_short_conditions:
|
||||
df.loc[reduce(lambda x, y: x & y, exit_short_conditions), "exit_short"] = 1
|
||||
|
||||
return df
|
@ -657,7 +657,7 @@ class IFreqaiModel(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, data_dictionary: Dict[str, Any]) -> Any:
|
||||
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
||||
"""
|
||||
Most regressors use the same function names and arguments e.g. user
|
||||
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
||||
|
162
freqtrade/freqai/prediction_models/RL/RLPrediction_agent.py
Normal file
162
freqtrade/freqai/prediction_models/RL/RLPrediction_agent.py
Normal file
@ -0,0 +1,162 @@
|
||||
# common library
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3 import DDPG
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3 import TD3
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.noise import NormalActionNoise
|
||||
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise
|
||||
# from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from freqtrade.freqai.prediction_models.RL import config
|
||||
# from meta.env_stock_trading.env_stock_trading import StockTradingEnv
|
||||
|
||||
# RL models from stable-baselines
|
||||
|
||||
|
||||
MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}
|
||||
|
||||
|
||||
MODEL_KWARGS = {x: config.__dict__[f"{x.upper()}_PARAMS"] for x in MODELS.keys()}
|
||||
|
||||
|
||||
NOISE = {
|
||||
"normal": NormalActionNoise,
|
||||
"ornstein_uhlenbeck": OrnsteinUhlenbeckActionNoise,
|
||||
}
|
||||
|
||||
|
||||
class TensorboardCallback(BaseCallback):
|
||||
"""
|
||||
Custom callback for plotting additional values in tensorboard.
|
||||
"""
|
||||
|
||||
def __init__(self, verbose=0):
|
||||
super(TensorboardCallback, self).__init__(verbose)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
try:
|
||||
self.logger.record(key="train/reward", value=self.locals["rewards"][0])
|
||||
except BaseException:
|
||||
self.logger.record(key="train/reward", value=self.locals["reward"][0])
|
||||
return True
|
||||
|
||||
|
||||
class RLPrediction_agent:
|
||||
"""Provides implementations for DRL algorithms
|
||||
Based on:
|
||||
https://github.com/AI4Finance-Foundation/FinRL-Meta/blob/master/agents/stablebaselines3_models.py
|
||||
Attributes
|
||||
----------
|
||||
env: gym environment class
|
||||
user-defined class
|
||||
|
||||
Methods
|
||||
-------
|
||||
get_model()
|
||||
setup DRL algorithms
|
||||
train_model()
|
||||
train DRL algorithms in a train dataset
|
||||
and output the trained model
|
||||
DRL_prediction()
|
||||
make a prediction in a test dataset and get results
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
self.env = env
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name,
|
||||
policy="MlpPolicy",
|
||||
policy_kwargs=None,
|
||||
model_kwargs=None,
|
||||
verbose=1,
|
||||
seed=None,
|
||||
):
|
||||
if model_name not in MODELS:
|
||||
raise NotImplementedError("NotImplementedError")
|
||||
|
||||
if model_kwargs is None:
|
||||
model_kwargs = MODEL_KWARGS[model_name]
|
||||
|
||||
if "action_noise" in model_kwargs:
|
||||
n_actions = self.env.action_space.shape[-1]
|
||||
model_kwargs["action_noise"] = NOISE[model_kwargs["action_noise"]](
|
||||
mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
|
||||
)
|
||||
print(model_kwargs)
|
||||
model = MODELS[model_name](
|
||||
policy=policy,
|
||||
env=self.env,
|
||||
tensorboard_log=f"{config.TENSORBOARD_LOG_DIR}/{model_name}",
|
||||
verbose=verbose,
|
||||
policy_kwargs=policy_kwargs,
|
||||
seed=seed,
|
||||
**model_kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
def train_model(self, model, tb_log_name, total_timesteps=5000):
|
||||
model = model.learn(
|
||||
total_timesteps=total_timesteps,
|
||||
tb_log_name=tb_log_name,
|
||||
callback=TensorboardCallback(),
|
||||
)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def DRL_prediction(model, environment):
|
||||
test_env, test_obs = environment.get_sb_env()
|
||||
"""make a prediction"""
|
||||
account_memory = []
|
||||
actions_memory = []
|
||||
test_env.reset()
|
||||
for i in range(len(environment.df.index.unique())):
|
||||
action, _states = model.predict(test_obs)
|
||||
# account_memory = test_env.env_method(method_name="save_asset_memory")
|
||||
# actions_memory = test_env.env_method(method_name="save_action_memory")
|
||||
test_obs, rewards, dones, info = test_env.step(action)
|
||||
if i == (len(environment.df.index.unique()) - 2):
|
||||
account_memory = test_env.env_method(method_name="save_asset_memory")
|
||||
actions_memory = test_env.env_method(method_name="save_action_memory")
|
||||
if dones[0]:
|
||||
print("hit end!")
|
||||
break
|
||||
return account_memory[0], actions_memory[0]
|
||||
|
||||
@staticmethod
|
||||
def DRL_prediction_load_from_file(model_name, environment, cwd):
|
||||
if model_name not in MODELS:
|
||||
raise NotImplementedError("NotImplementedError")
|
||||
try:
|
||||
# load agent
|
||||
model = MODELS[model_name].load(cwd)
|
||||
print("Successfully load model", cwd)
|
||||
except BaseException:
|
||||
raise ValueError("Fail to load agent!")
|
||||
|
||||
# test on the testing env
|
||||
state = environment.reset()
|
||||
episode_returns = list() # the cumulative_return / initial_account
|
||||
episode_total_assets = list()
|
||||
episode_total_assets.append(environment.initial_total_asset)
|
||||
done = False
|
||||
while not done:
|
||||
action = model.predict(state)[0]
|
||||
state, reward, done, _ = environment.step(action)
|
||||
|
||||
total_asset = (
|
||||
environment.cash
|
||||
+ (environment.price_array[environment.time] * environment.stocks).sum()
|
||||
)
|
||||
episode_total_assets.append(total_asset)
|
||||
episode_return = total_asset / environment.initial_total_asset
|
||||
episode_returns.append(episode_return)
|
||||
|
||||
print("episode_return", episode_return)
|
||||
print("Test Finished!")
|
||||
return episode_total_assets
|
230
freqtrade/freqai/prediction_models/RL/RLPrediction_env.py
Normal file
230
freqtrade/freqai/prediction_models/RL/RLPrediction_env.py
Normal file
@ -0,0 +1,230 @@
|
||||
from enum import Enum
|
||||
|
||||
import gym
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
|
||||
|
||||
class Actions(Enum):
|
||||
Hold = 0
|
||||
Buy = 1
|
||||
Sell = 2
|
||||
|
||||
|
||||
class Positions(Enum):
|
||||
Short = 0
|
||||
Long = 1
|
||||
|
||||
def opposite(self):
|
||||
return Positions.Short if self == Positions.Long else Positions.Long
|
||||
|
||||
|
||||
class GymAnytrading(gym.Env):
|
||||
"""
|
||||
Based on https://github.com/AminHP/gym-anytrading
|
||||
"""
|
||||
|
||||
metadata = {'render.modes': ['human']}
|
||||
|
||||
def __init__(self, signal_features, prices, window_size, fee=0.0):
|
||||
assert signal_features.ndim == 2
|
||||
|
||||
self.seed()
|
||||
self.signal_features = signal_features
|
||||
self.prices = prices
|
||||
self.window_size = window_size
|
||||
self.fee = fee
|
||||
self.shape = (window_size, self.signal_features.shape[1])
|
||||
|
||||
# spaces
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
|
||||
|
||||
# episode
|
||||
self._start_tick = self.window_size
|
||||
self._end_tick = len(self.prices) - 1
|
||||
self._done = None
|
||||
self._current_tick = None
|
||||
self._last_trade_tick = None
|
||||
self._position = None
|
||||
self._position_history = None
|
||||
self._total_reward = None
|
||||
self._total_profit = None
|
||||
self._first_rendering = None
|
||||
self.history = None
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def reset(self):
|
||||
self._done = False
|
||||
self._current_tick = self._start_tick
|
||||
self._last_trade_tick = self._current_tick - 1
|
||||
self._position = Positions.Short
|
||||
self._position_history = (self.window_size * [None]) + [self._position]
|
||||
self._total_reward = 0.
|
||||
self._total_profit = 1. # unit
|
||||
self._first_rendering = True
|
||||
self.history = {}
|
||||
return self._get_observation()
|
||||
|
||||
def step(self, action):
|
||||
self._done = False
|
||||
self._current_tick += 1
|
||||
|
||||
if self._current_tick == self._end_tick:
|
||||
self._done = True
|
||||
|
||||
step_reward = self._calculate_reward(action)
|
||||
self._total_reward += step_reward
|
||||
|
||||
self._update_profit(action)
|
||||
|
||||
trade = False
|
||||
if ((action == Actions.Buy.value and self._position == Positions.Short) or
|
||||
(action == Actions.Sell.value and self._position == Positions.Long)):
|
||||
trade = True
|
||||
|
||||
if trade:
|
||||
self._position = self._position.opposite()
|
||||
self._last_trade_tick = self._current_tick
|
||||
|
||||
self._position_history.append(self._position)
|
||||
observation = self._get_observation()
|
||||
info = dict(
|
||||
total_reward=self._total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value
|
||||
)
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
|
||||
def _get_observation(self):
|
||||
return self.signal_features[(self._current_tick - self.window_size):self._current_tick]
|
||||
|
||||
def _update_history(self, info):
|
||||
if not self.history:
|
||||
self.history = {key: [] for key in info.keys()}
|
||||
|
||||
for key, value in info.items():
|
||||
self.history[key].append(value)
|
||||
|
||||
def render(self, mode='human'):
|
||||
def _plot_position(position, tick):
|
||||
color = None
|
||||
if position == Positions.Short:
|
||||
color = 'red'
|
||||
elif position == Positions.Long:
|
||||
color = 'green'
|
||||
if color:
|
||||
plt.scatter(tick, self.prices[tick], color=color)
|
||||
|
||||
if self._first_rendering:
|
||||
self._first_rendering = False
|
||||
plt.cla()
|
||||
plt.plot(self.prices)
|
||||
start_position = self._position_history[self._start_tick]
|
||||
_plot_position(start_position, self._start_tick)
|
||||
|
||||
_plot_position(self._position, self._current_tick)
|
||||
|
||||
plt.suptitle(
|
||||
"Total Reward: %.6f" % self._total_reward + ' ~ ' +
|
||||
"Total Profit: %.6f" % self._total_profit
|
||||
)
|
||||
|
||||
plt.pause(0.01)
|
||||
|
||||
def render_all(self, mode='human'):
|
||||
window_ticks = np.arange(len(self._position_history))
|
||||
plt.plot(self.prices)
|
||||
|
||||
short_ticks = []
|
||||
long_ticks = []
|
||||
for i, tick in enumerate(window_ticks):
|
||||
if self._position_history[i] == Positions.Short:
|
||||
short_ticks.append(tick)
|
||||
elif self._position_history[i] == Positions.Long:
|
||||
long_ticks.append(tick)
|
||||
|
||||
plt.plot(short_ticks, self.prices[short_ticks], 'ro')
|
||||
plt.plot(long_ticks, self.prices[long_ticks], 'go')
|
||||
|
||||
plt.suptitle(
|
||||
"Total Reward: %.6f" % self._total_reward + ' ~ ' +
|
||||
"Total Profit: %.6f" % self._total_profit
|
||||
)
|
||||
|
||||
def close(self):
|
||||
plt.close()
|
||||
|
||||
def save_rendering(self, filepath):
|
||||
plt.savefig(filepath)
|
||||
|
||||
def pause_rendering(self):
|
||||
plt.show()
|
||||
|
||||
def _calculate_reward(self, action):
|
||||
step_reward = 0
|
||||
|
||||
trade = False
|
||||
if ((action == Actions.Buy.value and self._position == Positions.Short) or
|
||||
(action == Actions.Sell.value and self._position == Positions.Long)):
|
||||
trade = True
|
||||
|
||||
if trade:
|
||||
current_price = self.prices[self._current_tick]
|
||||
last_trade_price = self.prices[self._last_trade_tick]
|
||||
price_diff = current_price - last_trade_price
|
||||
|
||||
if self._position == Positions.Long:
|
||||
step_reward += price_diff
|
||||
|
||||
return step_reward
|
||||
|
||||
def _update_profit(self, action):
|
||||
trade = False
|
||||
if ((action == Actions.Buy.value and self._position == Positions.Short) or
|
||||
(action == Actions.Sell.value and self._position == Positions.Long)):
|
||||
trade = True
|
||||
|
||||
if trade or self._done:
|
||||
current_price = self.prices[self._current_tick]
|
||||
last_trade_price = self.prices[self._last_trade_tick]
|
||||
|
||||
if self._position == Positions.Long:
|
||||
shares = (self._total_profit * (1 - self.fee)) / last_trade_price
|
||||
self._total_profit = (shares * (1 - self.fee)) * current_price
|
||||
|
||||
def max_possible_profit(self):
|
||||
current_tick = self._start_tick
|
||||
last_trade_tick = current_tick - 1
|
||||
profit = 1.
|
||||
|
||||
while current_tick <= self._end_tick:
|
||||
position = None
|
||||
if self.prices[current_tick] < self.prices[current_tick - 1]:
|
||||
while (current_tick <= self._end_tick and
|
||||
self.prices[current_tick] < self.prices[current_tick - 1]):
|
||||
current_tick += 1
|
||||
position = Positions.Short
|
||||
else:
|
||||
while (current_tick <= self._end_tick and
|
||||
self.prices[current_tick] >= self.prices[current_tick - 1]):
|
||||
current_tick += 1
|
||||
position = Positions.Long
|
||||
|
||||
if position == Positions.Long:
|
||||
current_price = self.prices[current_tick - 1]
|
||||
last_trade_price = self.prices[last_trade_tick]
|
||||
shares = profit / last_trade_price
|
||||
profit = shares * current_price
|
||||
last_trade_tick = current_tick - 1
|
||||
print(profit)
|
||||
|
||||
return profit
|
37
freqtrade/freqai/prediction_models/RL/config.py
Normal file
37
freqtrade/freqai/prediction_models/RL/config.py
Normal file
@ -0,0 +1,37 @@
|
||||
# dir
|
||||
DATA_SAVE_DIR = "datasets"
|
||||
TRAINED_MODEL_DIR = "trained_models"
|
||||
TENSORBOARD_LOG_DIR = "tensorboard_log"
|
||||
RESULTS_DIR = "results"
|
||||
|
||||
# Model Parameters
|
||||
A2C_PARAMS = {"n_steps": 5, "ent_coef": 0.01, "learning_rate": 0.0007}
|
||||
PPO_PARAMS = {
|
||||
"n_steps": 2048,
|
||||
"ent_coef": 0.01,
|
||||
"learning_rate": 0.00025,
|
||||
"batch_size": 64,
|
||||
}
|
||||
DDPG_PARAMS = {"batch_size": 128, "buffer_size": 50000, "learning_rate": 0.001}
|
||||
TD3_PARAMS = {
|
||||
"batch_size": 100,
|
||||
"buffer_size": 1000000,
|
||||
"learning_rate": 0.001,
|
||||
}
|
||||
SAC_PARAMS = {
|
||||
"batch_size": 64,
|
||||
"buffer_size": 100000,
|
||||
"learning_rate": 0.0001,
|
||||
"learning_starts": 100,
|
||||
"ent_coef": "auto_0.1",
|
||||
}
|
||||
ERL_PARAMS = {
|
||||
"learning_rate": 3e-5,
|
||||
"batch_size": 2048,
|
||||
"gamma": 0.985,
|
||||
"seed": 312,
|
||||
"net_dimension": 512,
|
||||
"target_step": 5000,
|
||||
"eval_gap": 30,
|
||||
}
|
||||
RLlib_PARAMS = {"lr": 5e-5, "train_batch_size": 500, "gamma": 0.99}
|
157
freqtrade/freqai/prediction_models/ReinforcementLearningModel.py
Normal file
157
freqtrade/freqai/prediction_models/ReinforcementLearningModel.py
Normal file
@ -0,0 +1,157 @@
|
||||
import logging
|
||||
from typing import Any, Tuple, Dict
|
||||
from freqtrade.freqai.prediction_models.RL.RLPrediction_env import GymAnytrading
|
||||
from freqtrade.freqai.prediction_models.RL.RLPrediction_agent import RLPrediction_agent
|
||||
from pandas import DataFrame
|
||||
import pandas as pd
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReinforcementLearningModel(IFreqaiModel):
|
||||
"""
|
||||
User created Reinforcement Learning Model prediction model.
|
||||
"""
|
||||
|
||||
def train(
|
||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
||||
) -> Any:
|
||||
"""
|
||||
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
||||
for storing, saving, loading, and analyzing the data.
|
||||
:param unfiltered_dataframe: Full dataframe for the current training period
|
||||
:param metadata: pair metadata from strategy.
|
||||
:returns:
|
||||
:model: Trained model which can be used to inference (self.predict)
|
||||
"""
|
||||
|
||||
logger.info("--------------------Starting training " f"{pair} --------------------")
|
||||
|
||||
# filter the features requested by user in the configuration file and elegantly handle NaNs
|
||||
features_filtered, labels_filtered = dk.filter_features(
|
||||
unfiltered_dataframe,
|
||||
dk.training_features_list,
|
||||
dk.label_list,
|
||||
training_filter=True,
|
||||
)
|
||||
|
||||
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
||||
features_filtered, labels_filtered)
|
||||
dk.fit_labels() # useless for now, but just satiating append methods
|
||||
|
||||
# normalize all data based on train_dataset only
|
||||
data_dictionary = dk.normalize_data(data_dictionary)
|
||||
|
||||
# optional additional data cleaning/analysis
|
||||
self.data_cleaning_train(dk)
|
||||
|
||||
logger.info(
|
||||
f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features"
|
||||
)
|
||||
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
|
||||
|
||||
model = self.fit(data_dictionary, pair)
|
||||
|
||||
if pair not in self.dd.historic_predictions:
|
||||
self.set_initial_historic_predictions(
|
||||
data_dictionary['train_features'], model, dk, pair)
|
||||
|
||||
self.dd.save_historic_predictions_to_disk()
|
||||
|
||||
logger.info(f"--------------------done training {pair}--------------------")
|
||||
|
||||
return model
|
||||
|
||||
def fit(self, data_dictionary: Dict[str, Any], pair: str = ''):
|
||||
|
||||
train_df = data_dictionary["train_features"]
|
||||
|
||||
sep = '/'
|
||||
coin = pair.split(sep, 1)[0]
|
||||
price = train_df[f"%-{coin}raw_price_{self.config['timeframe']}"]
|
||||
price.reset_index(inplace=True, drop=True)
|
||||
|
||||
model_name = 'ppo'
|
||||
|
||||
env_instance = GymAnytrading(train_df, price, self.CONV_WIDTH)
|
||||
|
||||
agent_params = self.freqai_info['model_training_parameters']
|
||||
total_timesteps = agent_params.get('total_timesteps', 1000)
|
||||
|
||||
agent = RLPrediction_agent(env_instance)
|
||||
|
||||
model = agent.get_model(model_name, model_kwargs=agent_params)
|
||||
trained_model = agent.train_model(model=model,
|
||||
tb_log_name=model_name,
|
||||
total_timesteps=total_timesteps)
|
||||
print('Training finished!')
|
||||
|
||||
return trained_model
|
||||
|
||||
def predict(
|
||||
self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False
|
||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||
"""
|
||||
Filter the prediction features data and predict with it.
|
||||
:param: unfiltered_dataframe: Full dataframe for the current backtest period.
|
||||
:return:
|
||||
:pred_df: dataframe containing the predictions
|
||||
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
|
||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||
"""
|
||||
|
||||
dk.find_features(unfiltered_dataframe)
|
||||
filtered_dataframe, _ = dk.filter_features(
|
||||
unfiltered_dataframe, dk.training_features_list, training_filter=False
|
||||
)
|
||||
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
||||
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
||||
|
||||
# optional additional data cleaning/analysis
|
||||
self.data_cleaning_predict(dk, filtered_dataframe)
|
||||
|
||||
pred_df = self.rl_model_predict(dk.data_dictionary["prediction_features"], dk, self.model)
|
||||
pred_df.fillna(0, inplace=True)
|
||||
|
||||
return (pred_df, dk.do_predict)
|
||||
|
||||
def rl_model_predict(self, dataframe: DataFrame,
|
||||
dk: FreqaiDataKitchen, model: Any) -> DataFrame:
|
||||
|
||||
output = pd.DataFrame(np.full((len(dataframe), 1), 2), columns=dk.label_list)
|
||||
|
||||
def _predict(window):
|
||||
observations = dataframe.iloc[window.index]
|
||||
res, _ = model.predict(observations, deterministic=True)
|
||||
return res
|
||||
|
||||
output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
|
||||
|
||||
return output
|
||||
|
||||
def set_initial_historic_predictions(
|
||||
self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str
|
||||
) -> None:
|
||||
|
||||
pred_df = self.rl_model_predict(df, dk, model)
|
||||
pred_df.fillna(0, inplace=True)
|
||||
self.dd.historic_predictions[pair] = pred_df
|
||||
hist_preds_df = self.dd.historic_predictions[pair]
|
||||
|
||||
for label in hist_preds_df.columns:
|
||||
if hist_preds_df[label].dtype == object:
|
||||
continue
|
||||
hist_preds_df[f'{label}_mean'] = 0
|
||||
hist_preds_df[f'{label}_std'] = 0
|
||||
|
||||
hist_preds_df['do_predict'] = 0
|
||||
|
||||
if self.freqai_info['feature_parameters'].get('DI_threshold', 0) > 0:
|
||||
hist_preds_df['DI_values'] = 0
|
||||
|
||||
for return_str in dk.data['extra_returns_per_train']:
|
||||
hist_preds_df[return_str] = 0
|
Loading…
Reference in New Issue
Block a user