Working base for reinforcement learning model
This commit is contained in:
162
freqtrade/freqai/prediction_models/RL/RLPrediction_agent.py
Normal file
162
freqtrade/freqai/prediction_models/RL/RLPrediction_agent.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# common library
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3 import DDPG
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3 import TD3
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.noise import NormalActionNoise
|
||||
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise
|
||||
# from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from freqtrade.freqai.prediction_models.RL import config
|
||||
# from meta.env_stock_trading.env_stock_trading import StockTradingEnv
|
||||
|
||||
# RL models from stable-baselines
|
||||
|
||||
|
||||
# Supported stable-baselines3 algorithm classes, keyed by lowercase model name.
MODELS = {
    "a2c": A2C,
    "ddpg": DDPG,
    "td3": TD3,
    "sac": SAC,
    "ppo": PPO,
}

# Default constructor kwargs for each model, read from the ``<NAME>_PARAMS``
# attribute of the config module (e.g. "ppo" -> config.PPO_PARAMS).
MODEL_KWARGS = {name: vars(config)[f"{name.upper()}_PARAMS"] for name in MODELS}

# Action-noise classes that can be selected by name via the "action_noise" model kwarg.
NOISE = {
    "normal": NormalActionNoise,
    "ornstein_uhlenbeck": OrnsteinUhlenbeckActionNoise,
}
|
||||
|
||||
|
||||
class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.

    On every step the first element of the current reward value is recorded
    under the ``train/reward`` key.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        """
        Record the step reward with the logger.

        :return: always True
        """
        training_locals = self.locals
        # The reward is looked up under "rewards" first and "reward" second.
        # NOTE(review): presumably the key differs between algorithms / library
        # versions - confirm. Only a missing key triggers the fallback, so
        # KeyboardInterrupt, SystemExit and genuine logger errors propagate.
        try:
            step_rewards = training_locals["rewards"]
        except KeyError:
            step_rewards = training_locals["reward"]

        self.logger.record(key="train/reward", value=step_rewards[0])
        return True
|
||||
|
||||
|
||||
class RLPrediction_agent:
    """Provides implementations for DRL algorithms

    Based on:
    https://github.com/AI4Finance-Foundation/FinRL-Meta/blob/master/agents/stablebaselines3_models.py

    Attributes
    ----------
        env: gym environment class
            user-defined class

    Methods
    -------
        get_model()
            setup DRL algorithms
        train_model()
            train DRL algorithms in a train dataset
            and output the trained model
        DRL_prediction()
            make a prediction in a test dataset and get results
        DRL_prediction_load_from_file()
            load a saved model and run it over an environment until done
    """

    def __init__(self, env):
        self.env = env

    def get_model(
        self,
        model_name,
        policy="MlpPolicy",
        policy_kwargs=None,
        model_kwargs=None,
        verbose=1,
        seed=None,
    ):
        """
        Instantiate one of the algorithms registered in ``MODELS`` on ``self.env``.

        :param model_name: key into ``MODELS``
        :param policy: forwarded to the model constructor
        :param policy_kwargs: forwarded to the model constructor
        :param model_kwargs: extra constructor kwargs; when None the defaults from
            ``MODEL_KWARGS[model_name]`` are used. An "action_noise" entry is
            treated as a key into ``NOISE`` and replaced by a noise instance.
        :param verbose: forwarded to the model constructor
        :param seed: forwarded to the model constructor
        :return: the constructed model
        :raises NotImplementedError: if ``model_name`` is not in ``MODELS``
        """
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Unsupported model {model_name!r}; expected one of {sorted(MODELS)}"
            )

        if model_kwargs is None:
            model_kwargs = MODEL_KWARGS[model_name]
        # Work on a shallow copy: the "action_noise" name is swapped for an
        # instance below, and doing that in place would corrupt the shared
        # module-level defaults (or the caller's dict) for the next call.
        model_kwargs = dict(model_kwargs)

        if "action_noise" in model_kwargs:
            n_actions = self.env.action_space.shape[-1]
            noise_cls = NOISE[model_kwargs["action_noise"]]
            model_kwargs["action_noise"] = noise_cls(
                mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
            )
        print(model_kwargs)

        model = MODELS[model_name](
            policy=policy,
            env=self.env,
            tensorboard_log=f"{config.TENSORBOARD_LOG_DIR}/{model_name}",
            verbose=verbose,
            policy_kwargs=policy_kwargs,
            seed=seed,
            **model_kwargs,
        )
        return model

    def train_model(self, model, tb_log_name, total_timesteps=5000):
        """
        Call ``model.learn`` with a ``TensorboardCallback`` attached.

        :param model: object exposing ``learn(total_timesteps, tb_log_name, callback)``
        :param tb_log_name: forwarded to ``learn`` as ``tb_log_name``
        :param total_timesteps: forwarded to ``learn`` as ``total_timesteps``
        :return: whatever ``model.learn`` returns
        """
        model = model.learn(
            total_timesteps=total_timesteps,
            tb_log_name=tb_log_name,
            callback=TensorboardCallback(),
        )
        return model

    @staticmethod
    def DRL_prediction(model, environment):
        """
        make a prediction

        Steps the environment returned by ``environment.get_sb_env()`` once per
        unique value of ``environment.df.index`` (stopping early when the first
        ``dones`` flag is set) using actions from ``model.predict``.

        :param model: object exposing ``predict(obs)`` returning ``(action, state)``
        :param environment: object exposing ``get_sb_env()`` and ``df``
        :return: tuple ``(account_memory[0], actions_memory[0])`` taken from the
            "save_asset_memory" / "save_action_memory" env methods
        :raises IndexError: if the loop ends before the snapshot step is reached
        """
        test_env, test_obs = environment.get_sb_env()
        account_memory = []
        actions_memory = []
        test_env.reset()

        # Loop-invariant, so computed once instead of on every iteration.
        n_steps = len(environment.df.index.unique())
        # The memories are captured on the second-to-last step.
        # NOTE(review): presumably because the wrapped env resets itself on the
        # terminal step, which would wipe the memory - confirm.
        snapshot_step = n_steps - 2

        for i in range(n_steps):
            action, _states = model.predict(test_obs)
            test_obs, rewards, dones, info = test_env.step(action)
            if i == snapshot_step:
                account_memory = test_env.env_method(method_name="save_asset_memory")
                actions_memory = test_env.env_method(method_name="save_action_memory")
            if dones[0]:
                print("hit end!")
                break
        return account_memory[0], actions_memory[0]

    @staticmethod
    def DRL_prediction_load_from_file(model_name, environment, cwd):
        """
        Load a saved model and run it on ``environment`` until it reports done.

        :param model_name: key into ``MODELS`` selecting the class used for loading
        :param environment: object exposing ``reset()``, ``step(action)``,
            ``initial_total_asset``, ``cash``, ``price_array``, ``time`` and ``stocks``
        :param cwd: passed to the model class' ``load``
        :return: list of total asset values, starting with
            ``environment.initial_total_asset`` followed by one entry per step
        :raises NotImplementedError: if ``model_name`` is not in ``MODELS``
        :raises ValueError: if loading the model fails
        """
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Unsupported model {model_name!r}; expected one of {sorted(MODELS)}"
            )
        try:
            # load agent
            model = MODELS[model_name].load(cwd)
        except Exception as err:
            # Exception (not BaseException) so Ctrl-C / SystemExit are not
            # converted into a ValueError; the cause is kept for debugging.
            raise ValueError("Fail to load agent!") from err
        print("Successfully load model", cwd)

        # test on the testing env
        state = environment.reset()
        episode_total_assets = [environment.initial_total_asset]
        done = False
        while not done:
            action = model.predict(state)[0]
            state, reward, done, _ = environment.step(action)

            total_asset = (
                environment.cash
                + (environment.price_array[environment.time] * environment.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            # the cumulative_return / initial_account
            episode_return = total_asset / environment.initial_total_asset

        print("episode_return", episode_return)
        print("Test Finished!")
        return episode_total_assets
|
230
freqtrade/freqai/prediction_models/RL/RLPrediction_env.py
Normal file
230
freqtrade/freqai/prediction_models/RL/RLPrediction_env.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from enum import Enum
|
||||
|
||||
import gym
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
|
||||
|
||||
class Actions(Enum):
    """Actions an agent can submit to the environment on each step."""

    Hold = 0
    Buy = 1
    Sell = 2
|
||||
|
||||
|
||||
class Positions(Enum):
    """The two positions the environment can be in."""

    Short = 0
    Long = 1

    def opposite(self):
        """Return the other position: Short for Long, Long otherwise."""
        if self == Positions.Long:
            return Positions.Short
        return Positions.Long
|
||||
|
||||
|
||||
class GymAnytrading(gym.Env):
    """
    Based on https://github.com/AminHP/gym-anytrading

    Trading environment with a discrete action space (one entry per member of
    ``Actions``) and observations that are ``window_size``-row slices of
    ``signal_features``. The environment is always in one of the two
    ``Positions``; a Buy while Short or a Sell while Long flips it.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, signal_features, prices, window_size, fee=0.0):
        """
        :param signal_features: 2-D array; observations are row slices of it
        :param prices: price sequence indexed by tick
        :param window_size: number of rows in each observation
        :param fee: fraction deducted twice per long trade in the profit
            calculation (once on the entry price, once on the exit price)
        """
        assert signal_features.ndim == 2

        self.seed()
        self.signal_features = signal_features
        self.prices = prices
        self.window_size = window_size
        self.fee = fee
        self.shape = (window_size, self.signal_features.shape[1])

        # spaces
        self.action_space = spaces.Discrete(len(Actions))
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)

        # episode
        self._start_tick = self.window_size
        self._end_tick = len(self.prices) - 1
        # The remaining episode state is populated by reset().
        self._done = None
        self._current_tick = None
        self._last_trade_tick = None
        self._position = None
        self._position_history = None
        self._total_reward = None
        self._total_profit = None
        self._first_rendering = None
        self.history = None

    def seed(self, seed=None):
        """Create the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """
        Start a new episode at ``_start_tick`` in the Short position.

        :return: the first observation
        """
        self._done = False
        self._current_tick = self._start_tick
        self._last_trade_tick = self._current_tick - 1
        self._position = Positions.Short
        # Pad with None so that list index == tick for every recorded position.
        self._position_history = (self.window_size * [None]) + [self._position]
        self._total_reward = 0.
        self._total_profit = 1.  # unit
        self._first_rendering = True
        self.history = {}
        return self._get_observation()

    def step(self, action):
        """
        Advance one tick and apply ``action``.

        :param action: value compared against ``Actions.Buy.value`` / ``Actions.Sell.value``
        :return: tuple ``(observation, step_reward, done, info)`` where ``info``
            holds ``total_reward``, ``total_profit`` and ``position``
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        # Reward and profit must be computed before the position is flipped:
        # both depend on the position held up to this tick.
        step_reward = self._calculate_reward(action)
        self._total_reward += step_reward

        self._update_profit(action)

        if self._is_trade(action):
            self._position = self._position.opposite()
            self._last_trade_tick = self._current_tick

        self._position_history.append(self._position)
        observation = self._get_observation()
        info = dict(
            total_reward=self._total_reward,
            total_profit=self._total_profit,
            position=self._position.value
        )
        self._update_history(info)

        return observation, step_reward, self._done, info

    def _is_trade(self, action):
        """Return True when ``action`` flips the position (Buy while Short / Sell while Long)."""
        closes_short = action == Actions.Buy.value and self._position == Positions.Short
        closes_long = action == Actions.Sell.value and self._position == Positions.Long
        return bool(closes_short or closes_long)

    def _get_observation(self):
        """Return the ``window_size`` feature rows that precede the current tick."""
        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]

    def _update_history(self, info):
        """Append every value of ``info`` to the per-key lists in ``self.history``."""
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)

    def _render_title(self):
        """Return the figure title shared by ``render`` and ``render_all``."""
        return (
            f"Total Reward: {self._total_reward:.6f} ~ "
            f"Total Profit: {self._total_profit:.6f}"
        )

    def render(self, mode='human'):
        """Plot the current position on the price curve (red = Short, green = Long)."""

        def _plot_position(position, tick):
            color = None
            if position == Positions.Short:
                color = 'red'
            elif position == Positions.Long:
                color = 'green'
            if color:
                plt.scatter(tick, self.prices[tick], color=color)

        if self._first_rendering:
            # First call of the episode: draw the price curve and the start position.
            self._first_rendering = False
            plt.cla()
            plt.plot(self.prices)
            start_position = self._position_history[self._start_tick]
            _plot_position(start_position, self._start_tick)

        _plot_position(self._position, self._current_tick)

        plt.suptitle(self._render_title())

        plt.pause(0.01)

    def render_all(self, mode='human'):
        """Plot the price curve with every recorded position (red = Short, green = Long)."""
        plt.plot(self.prices)

        # History index equals tick; the leading None padding matches neither position.
        short_ticks = []
        long_ticks = []
        for tick, position in enumerate(self._position_history):
            if position == Positions.Short:
                short_ticks.append(tick)
            elif position == Positions.Long:
                long_ticks.append(tick)

        plt.plot(short_ticks, self.prices[short_ticks], 'ro')
        plt.plot(long_ticks, self.prices[long_ticks], 'go')

        plt.suptitle(self._render_title())

    def close(self):
        """Close the matplotlib figure."""
        plt.close()

    def save_rendering(self, filepath):
        """Save the current matplotlib figure to ``filepath``."""
        plt.savefig(filepath)

    def pause_rendering(self):
        """Show the matplotlib figure."""
        plt.show()

    def _calculate_reward(self, action):
        """
        Return the reward for ``action`` at the current tick.

        The reward is non-zero only when the action closes a Long position; it
        is then the price difference between the current tick and the last
        trade tick.
        """
        step_reward = 0

        if self._is_trade(action):
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]
            price_diff = current_price - last_trade_price

            if self._position == Positions.Long:
                step_reward += price_diff

        return step_reward

    def _update_profit(self, action):
        """
        Update ``_total_profit`` when a Long position is closed.

        A Long position counts as closed when ``action`` flips it or when the
        episode is done. The fee is applied to both the entry and exit side.
        """
        if self._is_trade(action) or self._done:
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]

            if self._position == Positions.Long:
                shares = (self._total_profit * (1 - self.fee)) / last_trade_price
                self._total_profit = (shares * (1 - self.fee)) * current_price

    def max_possible_profit(self):
        """
        Return the fee-free profit multiplier of holding Long through every
        run of non-decreasing prices between ``_start_tick`` and ``_end_tick``.
        """
        current_tick = self._start_tick
        last_trade_tick = current_tick - 1
        profit = 1.

        while current_tick <= self._end_tick:
            position = None
            if self.prices[current_tick] < self.prices[current_tick - 1]:
                # Skip over a falling run: nothing is gained while Short.
                while (current_tick <= self._end_tick and
                       self.prices[current_tick] < self.prices[current_tick - 1]):
                    current_tick += 1
                position = Positions.Short
            else:
                # Ride a non-decreasing run to its last tick.
                while (current_tick <= self._end_tick and
                       self.prices[current_tick] >= self.prices[current_tick - 1]):
                    current_tick += 1
                position = Positions.Long

            if position == Positions.Long:
                current_price = self.prices[current_tick - 1]
                last_trade_price = self.prices[last_trade_tick]
                shares = profit / last_trade_price
                profit = shares * current_price
            last_trade_tick = current_tick - 1

        return profit
|
37
freqtrade/freqai/prediction_models/RL/config.py
Normal file
37
freqtrade/freqai/prediction_models/RL/config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Directory names
DATA_SAVE_DIR = "datasets"
TRAINED_MODEL_DIR = "trained_models"
TENSORBOARD_LOG_DIR = "tensorboard_log"
RESULTS_DIR = "results"

# Model parameters: one ``<NAME>_PARAMS`` dict of keyword arguments per model.
A2C_PARAMS = {
    "n_steps": 5,
    "ent_coef": 0.01,
    "learning_rate": 0.0007,
}
PPO_PARAMS = {
    "n_steps": 2048,
    "ent_coef": 0.01,
    "learning_rate": 0.00025,
    "batch_size": 64,
}
DDPG_PARAMS = {
    "batch_size": 128,
    "buffer_size": 50000,
    "learning_rate": 0.001,
}
TD3_PARAMS = {"batch_size": 100, "buffer_size": 1000000, "learning_rate": 0.001}
SAC_PARAMS = {
    "batch_size": 64,
    "buffer_size": 100000,
    "learning_rate": 0.0001,
    "learning_starts": 100,
    "ent_coef": "auto_0.1",
}
ERL_PARAMS = {
    "learning_rate": 3e-5,
    "batch_size": 2048,
    "gamma": 0.985,
    "seed": 312,
    "net_dimension": 512,
    "target_step": 5000,
    "eval_gap": 30,
}
RLlib_PARAMS = {
    "lr": 5e-5,
    "train_batch_size": 500,
    "gamma": 0.99,
}
|
Reference in New Issue
Block a user