add continual retraining feature, handly mypy typing reqs, improve docstrings
This commit is contained in:
parent
b708134c1a
commit
c0cee5df07
@ -85,12 +85,13 @@
|
|||||||
"verbose": 1
|
"verbose": 1
|
||||||
},
|
},
|
||||||
"rl_config": {
|
"rl_config": {
|
||||||
"train_cycles": 10,
|
"train_cycles": 3,
|
||||||
"eval_cycles": 3,
|
"eval_cycles": 3,
|
||||||
"thread_count": 4,
|
"thread_count": 4,
|
||||||
"max_trade_duration_candles": 100,
|
"max_trade_duration_candles": 100,
|
||||||
"model_type": "PPO",
|
"model_type": "PPO",
|
||||||
"policy_type": "MlpPolicy",
|
"policy_type": "MlpPolicy",
|
||||||
|
"continual_retraining": true,
|
||||||
"model_reward_parameters": {
|
"model_reward_parameters": {
|
||||||
"rr": 1,
|
"rr": 1,
|
||||||
"profit_aim": 0.02,
|
"profit_aim": 0.02,
|
||||||
|
@ -1,330 +1,330 @@
|
|||||||
import logging
|
# import logging
|
||||||
from enum import Enum
|
# from enum import Enum
|
||||||
|
|
||||||
import gym
|
# import gym
|
||||||
import numpy as np
|
# import numpy as np
|
||||||
import pandas as pd
|
# import pandas as pd
|
||||||
from gym import spaces
|
# from gym import spaces
|
||||||
from gym.utils import seeding
|
# from gym.utils import seeding
|
||||||
from pandas import DataFrame
|
# from pandas import DataFrame
|
||||||
|
|
||||||
|
|
||||||
# from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
# # from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Actions(Enum):
|
# class Actions(Enum):
|
||||||
Short = 0
|
# Short = 0
|
||||||
Long = 1
|
# Long = 1
|
||||||
Neutral = 2
|
# Neutral = 2
|
||||||
|
|
||||||
|
|
||||||
class Positions(Enum):
|
# class Positions(Enum):
|
||||||
Short = 0
|
# Short = 0
|
||||||
Long = 1
|
# Long = 1
|
||||||
Neutral = 0.5
|
# Neutral = 0.5
|
||||||
|
|
||||||
def opposite(self):
|
# def opposite(self):
|
||||||
return Positions.Short if self == Positions.Long else Positions.Long
|
# return Positions.Short if self == Positions.Long else Positions.Long
|
||||||
|
|
||||||
|
|
||||||
def mean_over_std(x):
|
# def mean_over_std(x):
|
||||||
std = np.std(x, ddof=1)
|
# std = np.std(x, ddof=1)
|
||||||
mean = np.mean(x)
|
# mean = np.mean(x)
|
||||||
return mean / std if std > 0 else 0
|
# return mean / std if std > 0 else 0
|
||||||
|
|
||||||
|
|
||||||
class Base3ActionRLEnv(gym.Env):
|
# class Base3ActionRLEnv(gym.Env):
|
||||||
|
|
||||||
metadata = {'render.modes': ['human']}
|
# metadata = {'render.modes': ['human']}
|
||||||
|
|
||||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
# def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
# reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||||
id: str = 'baseenv-1', seed: int = 1):
|
# id: str = 'baseenv-1', seed: int = 1):
|
||||||
assert df.ndim == 2
|
# assert df.ndim == 2
|
||||||
|
|
||||||
self.id = id
|
# self.id = id
|
||||||
self.seed(seed)
|
# self.seed(seed)
|
||||||
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
# self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
||||||
|
|
||||||
def reset_env(self, df, prices, window_size, reward_kwargs, starting_point=True):
|
# def reset_env(self, df, prices, window_size, reward_kwargs, starting_point=True):
|
||||||
self.df = df
|
# self.df = df
|
||||||
self.signal_features = self.df
|
# self.signal_features = self.df
|
||||||
self.prices = prices
|
# self.prices = prices
|
||||||
self.window_size = window_size
|
# self.window_size = window_size
|
||||||
self.starting_point = starting_point
|
# self.starting_point = starting_point
|
||||||
self.rr = reward_kwargs["rr"]
|
# self.rr = reward_kwargs["rr"]
|
||||||
self.profit_aim = reward_kwargs["profit_aim"]
|
# self.profit_aim = reward_kwargs["profit_aim"]
|
||||||
|
|
||||||
self.fee = 0.0015
|
# self.fee = 0.0015
|
||||||
|
|
||||||
# # spaces
|
# # # spaces
|
||||||
self.shape = (window_size, self.signal_features.shape[1] + 2)
|
# self.shape = (window_size, self.signal_features.shape[1] + 2)
|
||||||
self.action_space = spaces.Discrete(len(Actions))
|
# self.action_space = spaces.Discrete(len(Actions))
|
||||||
self.observation_space = spaces.Box(
|
# self.observation_space = spaces.Box(
|
||||||
low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
|
# low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
|
||||||
|
|
||||||
# episode
|
# # episode
|
||||||
self._start_tick = self.window_size
|
# self._start_tick = self.window_size
|
||||||
self._end_tick = len(self.prices) - 1
|
# self._end_tick = len(self.prices) - 1
|
||||||
self._done = None
|
# self._done = None
|
||||||
self._current_tick = None
|
# self._current_tick = None
|
||||||
self._last_trade_tick = None
|
# self._last_trade_tick = None
|
||||||
self._position = Positions.Neutral
|
# self._position = Positions.Neutral
|
||||||
self._position_history = None
|
# self._position_history = None
|
||||||
self.total_reward = None
|
# self.total_reward = None
|
||||||
self._total_profit = None
|
# self._total_profit = None
|
||||||
self._first_rendering = None
|
# self._first_rendering = None
|
||||||
self.history = None
|
# self.history = None
|
||||||
self.trade_history = []
|
# self.trade_history = []
|
||||||
|
|
||||||
def seed(self, seed: int = 1):
|
# def seed(self, seed: int = 1):
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
# self.np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
# return [seed]
|
||||||
|
|
||||||
def reset(self):
|
# def reset(self):
|
||||||
|
|
||||||
self._done = False
|
# self._done = False
|
||||||
|
|
||||||
if self.starting_point is True:
|
# if self.starting_point is True:
|
||||||
self._position_history = (self._start_tick * [None]) + [self._position]
|
# self._position_history = (self._start_tick * [None]) + [self._position]
|
||||||
else:
|
# else:
|
||||||
self._position_history = (self.window_size * [None]) + [self._position]
|
# self._position_history = (self.window_size * [None]) + [self._position]
|
||||||
|
|
||||||
self._current_tick = self._start_tick
|
# self._current_tick = self._start_tick
|
||||||
self._last_trade_tick = None
|
# self._last_trade_tick = None
|
||||||
self._position = Positions.Neutral
|
# self._position = Positions.Neutral
|
||||||
|
|
||||||
self.total_reward = 0.
|
# self.total_reward = 0.
|
||||||
self._total_profit = 1. # unit
|
# self._total_profit = 1. # unit
|
||||||
self._first_rendering = True
|
# self._first_rendering = True
|
||||||
self.history = {}
|
# self.history = {}
|
||||||
self.trade_history = []
|
# self.trade_history = []
|
||||||
self.portfolio_log_returns = np.zeros(len(self.prices))
|
# self.portfolio_log_returns = np.zeros(len(self.prices))
|
||||||
|
|
||||||
self._profits = [(self._start_tick, 1)]
|
# self._profits = [(self._start_tick, 1)]
|
||||||
self.close_trade_profit = []
|
# self.close_trade_profit = []
|
||||||
|
|
||||||
return self._get_observation()
|
# return self._get_observation()
|
||||||
|
|
||||||
def step(self, action: int):
|
# def step(self, action: int):
|
||||||
self._done = False
|
# self._done = False
|
||||||
self._current_tick += 1
|
# self._current_tick += 1
|
||||||
|
|
||||||
if self._current_tick == self._end_tick:
|
# if self._current_tick == self._end_tick:
|
||||||
self._done = True
|
# self._done = True
|
||||||
|
|
||||||
self.update_portfolio_log_returns(action)
|
# self.update_portfolio_log_returns(action)
|
||||||
|
|
||||||
self._update_profit(action)
|
# self._update_profit(action)
|
||||||
step_reward = self.calculate_reward(action)
|
# step_reward = self.calculate_reward(action)
|
||||||
self.total_reward += step_reward
|
# self.total_reward += step_reward
|
||||||
|
|
||||||
trade_type = None
|
# trade_type = None
|
||||||
if self.is_tradesignal(action): # exclude 3 case not trade
|
# if self.is_tradesignal(action): # exclude 3 case not trade
|
||||||
# Update position
|
# # Update position
|
||||||
"""
|
# """
|
||||||
Action: Neutral, position: Long -> Close Long
|
# Action: Neutral, position: Long -> Close Long
|
||||||
Action: Neutral, position: Short -> Close Short
|
# Action: Neutral, position: Short -> Close Short
|
||||||
|
|
||||||
Action: Long, position: Neutral -> Open Long
|
# Action: Long, position: Neutral -> Open Long
|
||||||
Action: Long, position: Short -> Close Short and Open Long
|
# Action: Long, position: Short -> Close Short and Open Long
|
||||||
|
|
||||||
Action: Short, position: Neutral -> Open Short
|
# Action: Short, position: Neutral -> Open Short
|
||||||
Action: Short, position: Long -> Close Long and Open Short
|
# Action: Short, position: Long -> Close Long and Open Short
|
||||||
"""
|
# """
|
||||||
|
|
||||||
if action == Actions.Neutral.value:
|
# if action == Actions.Neutral.value:
|
||||||
self._position = Positions.Neutral
|
# self._position = Positions.Neutral
|
||||||
trade_type = "neutral"
|
# trade_type = "neutral"
|
||||||
elif action == Actions.Long.value:
|
# elif action == Actions.Long.value:
|
||||||
self._position = Positions.Long
|
# self._position = Positions.Long
|
||||||
trade_type = "long"
|
# trade_type = "long"
|
||||||
elif action == Actions.Short.value:
|
# elif action == Actions.Short.value:
|
||||||
self._position = Positions.Short
|
# self._position = Positions.Short
|
||||||
trade_type = "short"
|
# trade_type = "short"
|
||||||
else:
|
# else:
|
||||||
print("case not defined")
|
# print("case not defined")
|
||||||
|
|
||||||
# Update last trade tick
|
# # Update last trade tick
|
||||||
self._last_trade_tick = self._current_tick
|
# self._last_trade_tick = self._current_tick
|
||||||
|
|
||||||
if trade_type is not None:
|
# if trade_type is not None:
|
||||||
self.trade_history.append(
|
# self.trade_history.append(
|
||||||
{'price': self.current_price(), 'index': self._current_tick,
|
# {'price': self.current_price(), 'index': self._current_tick,
|
||||||
'type': trade_type})
|
# 'type': trade_type})
|
||||||
|
|
||||||
if self._total_profit < 0.2:
|
# if self._total_profit < 0.2:
|
||||||
self._done = True
|
# self._done = True
|
||||||
|
|
||||||
self._position_history.append(self._position)
|
# self._position_history.append(self._position)
|
||||||
observation = self._get_observation()
|
# observation = self._get_observation()
|
||||||
info = dict(
|
# info = dict(
|
||||||
tick=self._current_tick,
|
# tick=self._current_tick,
|
||||||
total_reward=self.total_reward,
|
# total_reward=self.total_reward,
|
||||||
total_profit=self._total_profit,
|
# total_profit=self._total_profit,
|
||||||
position=self._position.value
|
# position=self._position.value
|
||||||
)
|
# )
|
||||||
self._update_history(info)
|
# self._update_history(info)
|
||||||
|
|
||||||
return observation, step_reward, self._done, info
|
# return observation, step_reward, self._done, info
|
||||||
|
|
||||||
def _get_observation(self):
|
# def _get_observation(self):
|
||||||
features_window = self.signal_features[(
|
# features_window = self.signal_features[(
|
||||||
self._current_tick - self.window_size):self._current_tick]
|
# self._current_tick - self.window_size):self._current_tick]
|
||||||
features_and_state = DataFrame(np.zeros((len(features_window), 2)),
|
# features_and_state = DataFrame(np.zeros((len(features_window), 2)),
|
||||||
columns=['current_profit_pct', 'position'],
|
# columns=['current_profit_pct', 'position'],
|
||||||
index=features_window.index)
|
# index=features_window.index)
|
||||||
|
|
||||||
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
# features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
||||||
features_and_state['position'] = self._position.value
|
# features_and_state['position'] = self._position.value
|
||||||
features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
# features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
||||||
return features_and_state
|
# return features_and_state
|
||||||
|
|
||||||
def get_unrealized_profit(self):
|
# def get_unrealized_profit(self):
|
||||||
|
|
||||||
if self._last_trade_tick is None:
|
# if self._last_trade_tick is None:
|
||||||
return 0.
|
# return 0.
|
||||||
|
|
||||||
if self._position == Positions.Neutral:
|
# if self._position == Positions.Neutral:
|
||||||
return 0.
|
# return 0.
|
||||||
elif self._position == Positions.Short:
|
# elif self._position == Positions.Short:
|
||||||
current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
|
# current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
|
||||||
last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
|
# last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||||
return (last_trade_price - current_price) / last_trade_price
|
# return (last_trade_price - current_price) / last_trade_price
|
||||||
elif self._position == Positions.Long:
|
# elif self._position == Positions.Long:
|
||||||
current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
|
# current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
|
||||||
last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
|
# last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||||
return (current_price - last_trade_price) / last_trade_price
|
# return (current_price - last_trade_price) / last_trade_price
|
||||||
else:
|
# else:
|
||||||
return 0.
|
# return 0.
|
||||||
|
|
||||||
def is_tradesignal(self, action: int):
|
# def is_tradesignal(self, action: int):
|
||||||
# trade signal
|
# # trade signal
|
||||||
"""
|
# """
|
||||||
not trade signal is :
|
# not trade signal is :
|
||||||
Action: Neutral, position: Neutral -> Nothing
|
# Action: Neutral, position: Neutral -> Nothing
|
||||||
Action: Long, position: Long -> Hold Long
|
# Action: Long, position: Long -> Hold Long
|
||||||
Action: Short, position: Short -> Hold Short
|
# Action: Short, position: Short -> Hold Short
|
||||||
"""
|
# """
|
||||||
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral)
|
# return not ((action == Actions.Neutral.value and self._position == Positions.Neutral)
|
||||||
or (action == Actions.Short.value and self._position == Positions.Short)
|
# or (action == Actions.Short.value and self._position == Positions.Short)
|
||||||
or (action == Actions.Long.value and self._position == Positions.Long))
|
# or (action == Actions.Long.value and self._position == Positions.Long))
|
||||||
|
|
||||||
def _is_trade(self, action: Actions):
|
# def _is_trade(self, action: Actions):
|
||||||
return ((action == Actions.Long.value and self._position == Positions.Short) or
|
# return ((action == Actions.Long.value and self._position == Positions.Short) or
|
||||||
(action == Actions.Short.value and self._position == Positions.Long) or
|
# (action == Actions.Short.value and self._position == Positions.Long) or
|
||||||
(action == Actions.Neutral.value and self._position == Positions.Long) or
|
# (action == Actions.Neutral.value and self._position == Positions.Long) or
|
||||||
(action == Actions.Neutral.value and self._position == Positions.Short)
|
# (action == Actions.Neutral.value and self._position == Positions.Short)
|
||||||
)
|
# )
|
||||||
|
|
||||||
def is_hold(self, action):
|
# def is_hold(self, action):
|
||||||
return ((action == Actions.Short.value and self._position == Positions.Short)
|
# return ((action == Actions.Short.value and self._position == Positions.Short)
|
||||||
or (action == Actions.Long.value and self._position == Positions.Long))
|
# or (action == Actions.Long.value and self._position == Positions.Long))
|
||||||
|
|
||||||
def add_buy_fee(self, price):
|
# def add_buy_fee(self, price):
|
||||||
return price * (1 + self.fee)
|
# return price * (1 + self.fee)
|
||||||
|
|
||||||
def add_sell_fee(self, price):
|
# def add_sell_fee(self, price):
|
||||||
return price / (1 + self.fee)
|
# return price / (1 + self.fee)
|
||||||
|
|
||||||
def _update_history(self, info):
|
# def _update_history(self, info):
|
||||||
if not self.history:
|
# if not self.history:
|
||||||
self.history = {key: [] for key in info.keys()}
|
# self.history = {key: [] for key in info.keys()}
|
||||||
|
|
||||||
for key, value in info.items():
|
# for key, value in info.items():
|
||||||
self.history[key].append(value)
|
# self.history[key].append(value)
|
||||||
|
|
||||||
def get_sharpe_ratio(self):
|
# def get_sharpe_ratio(self):
|
||||||
return mean_over_std(self.get_portfolio_log_returns())
|
# return mean_over_std(self.get_portfolio_log_returns())
|
||||||
|
|
||||||
def calculate_reward(self, action):
|
# def calculate_reward(self, action):
|
||||||
|
|
||||||
if self._last_trade_tick is None:
|
# if self._last_trade_tick is None:
|
||||||
return 0.
|
# return 0.
|
||||||
|
|
||||||
# close long
|
# # close long
|
||||||
if (action == Actions.Short.value or
|
# if (action == Actions.Short.value or
|
||||||
action == Actions.Neutral.value) and self._position == Positions.Long:
|
# action == Actions.Neutral.value) and self._position == Positions.Long:
|
||||||
last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
|
# last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||||
current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
|
# current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
|
||||||
return float(np.log(current_price) - np.log(last_trade_price))
|
# return float(np.log(current_price) - np.log(last_trade_price))
|
||||||
|
|
||||||
# close short
|
# # close short
|
||||||
if (action == Actions.Long.value or
|
# if (action == Actions.Long.value or
|
||||||
action == Actions.Neutral.value) and self._position == Positions.Short:
|
# action == Actions.Neutral.value) and self._position == Positions.Short:
|
||||||
last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
|
# last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||||
current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
|
# current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
|
||||||
return float(np.log(last_trade_price) - np.log(current_price))
|
# return float(np.log(last_trade_price) - np.log(current_price))
|
||||||
|
|
||||||
return 0.
|
# return 0.
|
||||||
|
|
||||||
def _update_profit(self, action):
|
# def _update_profit(self, action):
|
||||||
if self._is_trade(action) or self._done:
|
# if self._is_trade(action) or self._done:
|
||||||
pnl = self.get_unrealized_profit()
|
# pnl = self.get_unrealized_profit()
|
||||||
|
|
||||||
if self._position == Positions.Long:
|
# if self._position == Positions.Long:
|
||||||
self._total_profit = self._total_profit + self._total_profit * pnl
|
# self._total_profit = self._total_profit + self._total_profit * pnl
|
||||||
self._profits.append((self._current_tick, self._total_profit))
|
# self._profits.append((self._current_tick, self._total_profit))
|
||||||
self.close_trade_profit.append(pnl)
|
# self.close_trade_profit.append(pnl)
|
||||||
|
|
||||||
if self._position == Positions.Short:
|
# if self._position == Positions.Short:
|
||||||
self._total_profit = self._total_profit + self._total_profit * pnl
|
# self._total_profit = self._total_profit + self._total_profit * pnl
|
||||||
self._profits.append((self._current_tick, self._total_profit))
|
# self._profits.append((self._current_tick, self._total_profit))
|
||||||
self.close_trade_profit.append(pnl)
|
# self.close_trade_profit.append(pnl)
|
||||||
|
|
||||||
def most_recent_return(self, action: int):
|
# def most_recent_return(self, action: int):
|
||||||
"""
|
# """
|
||||||
We support Long, Neutral and Short positions.
|
# We support Long, Neutral and Short positions.
|
||||||
Return is generated from rising prices in Long
|
# Return is generated from rising prices in Long
|
||||||
and falling prices in Short positions.
|
# and falling prices in Short positions.
|
||||||
The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
|
# The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
|
||||||
"""
|
# """
|
||||||
# Long positions
|
# # Long positions
|
||||||
if self._position == Positions.Long:
|
# if self._position == Positions.Long:
|
||||||
current_price = self.prices.iloc[self._current_tick].open
|
# current_price = self.prices.iloc[self._current_tick].open
|
||||||
if action == Actions.Short.value or action == Actions.Neutral.value:
|
# if action == Actions.Short.value or action == Actions.Neutral.value:
|
||||||
current_price = self.add_sell_fee(current_price)
|
# current_price = self.add_sell_fee(current_price)
|
||||||
|
|
||||||
previous_price = self.prices.iloc[self._current_tick - 1].open
|
# previous_price = self.prices.iloc[self._current_tick - 1].open
|
||||||
|
|
||||||
if (self._position_history[self._current_tick - 1] == Positions.Short
|
# if (self._position_history[self._current_tick - 1] == Positions.Short
|
||||||
or self._position_history[self._current_tick - 1] == Positions.Neutral):
|
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
|
||||||
previous_price = self.add_buy_fee(previous_price)
|
# previous_price = self.add_buy_fee(previous_price)
|
||||||
|
|
||||||
return np.log(current_price) - np.log(previous_price)
|
# return np.log(current_price) - np.log(previous_price)
|
||||||
|
|
||||||
# Short positions
|
# # Short positions
|
||||||
if self._position == Positions.Short:
|
# if self._position == Positions.Short:
|
||||||
current_price = self.prices.iloc[self._current_tick].open
|
# current_price = self.prices.iloc[self._current_tick].open
|
||||||
if action == Actions.Long.value or action == Actions.Neutral.value:
|
# if action == Actions.Long.value or action == Actions.Neutral.value:
|
||||||
current_price = self.add_buy_fee(current_price)
|
# current_price = self.add_buy_fee(current_price)
|
||||||
|
|
||||||
previous_price = self.prices.iloc[self._current_tick - 1].open
|
# previous_price = self.prices.iloc[self._current_tick - 1].open
|
||||||
if (self._position_history[self._current_tick - 1] == Positions.Long
|
# if (self._position_history[self._current_tick - 1] == Positions.Long
|
||||||
or self._position_history[self._current_tick - 1] == Positions.Neutral):
|
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
|
||||||
previous_price = self.add_sell_fee(previous_price)
|
# previous_price = self.add_sell_fee(previous_price)
|
||||||
|
|
||||||
return np.log(previous_price) - np.log(current_price)
|
# return np.log(previous_price) - np.log(current_price)
|
||||||
|
|
||||||
return 0
|
# return 0
|
||||||
|
|
||||||
def get_portfolio_log_returns(self):
|
# def get_portfolio_log_returns(self):
|
||||||
return self.portfolio_log_returns[1:self._current_tick + 1]
|
# return self.portfolio_log_returns[1:self._current_tick + 1]
|
||||||
|
|
||||||
def update_portfolio_log_returns(self, action):
|
# def update_portfolio_log_returns(self, action):
|
||||||
self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)
|
# self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)
|
||||||
|
|
||||||
def current_price(self) -> float:
|
# def current_price(self) -> float:
|
||||||
return self.prices.iloc[self._current_tick].open
|
# return self.prices.iloc[self._current_tick].open
|
||||||
|
|
||||||
def prev_price(self) -> float:
|
# def prev_price(self) -> float:
|
||||||
return self.prices.iloc[self._current_tick - 1].open
|
# return self.prices.iloc[self._current_tick - 1].open
|
||||||
|
|
||||||
def sharpe_ratio(self) -> float:
|
# def sharpe_ratio(self) -> float:
|
||||||
if len(self.close_trade_profit) == 0:
|
# if len(self.close_trade_profit) == 0:
|
||||||
return 0.
|
# return 0.
|
||||||
returns = np.array(self.close_trade_profit)
|
# returns = np.array(self.close_trade_profit)
|
||||||
reward = (np.mean(returns) - 0. + 1e-9) / (np.std(returns) + 1e-9)
|
# reward = (np.mean(returns) - 0. + 1e-9) / (np.std(returns) + 1e-9)
|
||||||
return reward
|
# return reward
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
# from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
from typing import Optional
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -44,14 +44,14 @@ class Base5ActionRLEnv(gym.Env):
|
|||||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
|
id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
|
||||||
assert df.ndim == 2
|
|
||||||
|
|
||||||
self.rl_config = config['freqai']['rl_config']
|
self.rl_config = config['freqai']['rl_config']
|
||||||
self.id = id
|
self.id = id
|
||||||
self.seed(seed)
|
self.seed(seed)
|
||||||
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
||||||
|
|
||||||
def reset_env(self, df, prices, window_size, reward_kwargs, starting_point=True):
|
def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
|
||||||
|
reward_kwargs: dict, starting_point=True):
|
||||||
self.df = df
|
self.df = df
|
||||||
self.signal_features = self.df
|
self.signal_features = self.df
|
||||||
self.prices = prices
|
self.prices = prices
|
||||||
@ -69,18 +69,18 @@ class Base5ActionRLEnv(gym.Env):
|
|||||||
low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
|
low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
|
||||||
|
|
||||||
# episode
|
# episode
|
||||||
self._start_tick = self.window_size
|
self._start_tick: int = self.window_size
|
||||||
self._end_tick = len(self.prices) - 1
|
self._end_tick: int = len(self.prices) - 1
|
||||||
self._done = None
|
self._done: bool = False
|
||||||
self._current_tick = None
|
self._current_tick: int = self._start_tick
|
||||||
self._last_trade_tick = None
|
self._last_trade_tick: Optional[int] = None
|
||||||
self._position = Positions.Neutral
|
self._position = Positions.Neutral
|
||||||
self._position_history = None
|
self._position_history: list = [None]
|
||||||
self.total_reward = None
|
self.total_reward: float = 0
|
||||||
self._total_profit = None
|
self._total_profit: float = 0
|
||||||
self._first_rendering = None
|
self._first_rendering: bool = False
|
||||||
self.history = None
|
self.history: dict = {}
|
||||||
self.trade_history = []
|
self.trade_history: list = []
|
||||||
|
|
||||||
def seed(self, seed: int = 1):
|
def seed(self, seed: int = 1):
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
@ -125,8 +125,7 @@ class Base5ActionRLEnv(gym.Env):
|
|||||||
self.total_reward += step_reward
|
self.total_reward += step_reward
|
||||||
|
|
||||||
trade_type = None
|
trade_type = None
|
||||||
if self.is_tradesignal(action): # exclude 3 case not trade
|
if self.is_tradesignal(action):
|
||||||
# Update position
|
|
||||||
"""
|
"""
|
||||||
Action: Neutral, position: Long -> Close Long
|
Action: Neutral, position: Long -> Close Long
|
||||||
Action: Neutral, position: Short -> Close Short
|
Action: Neutral, position: Short -> Close Short
|
||||||
@ -223,9 +222,8 @@ class Base5ActionRLEnv(gym.Env):
|
|||||||
# trade signal
|
# trade signal
|
||||||
"""
|
"""
|
||||||
not trade signal is :
|
not trade signal is :
|
||||||
Action: Neutral, position: Neutral -> Nothing
|
Determine if the signal is non sensical
|
||||||
Action: Long, position: Long -> Hold Long
|
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||||
Action: Short, position: Short -> Hold Short
|
|
||||||
"""
|
"""
|
||||||
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
|
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
|
||||||
(action == Actions.Neutral.value and self._position == Positions.Short) or
|
(action == Actions.Neutral.value and self._position == Positions.Short) or
|
||||||
@ -292,7 +290,7 @@ class Base5ActionRLEnv(gym.Env):
|
|||||||
|
|
||||||
def most_recent_return(self, action: int):
|
def most_recent_return(self, action: int):
|
||||||
"""
|
"""
|
||||||
We support Long, Neutral and Short positions.
|
Calculate the tick to tick return if in a trade.
|
||||||
Return is generated from rising prices in Long
|
Return is generated from rising prices in Long
|
||||||
and falling prices in Short positions.
|
and falling prices in Short positions.
|
||||||
The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
|
The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
|
||||||
|
@ -19,6 +19,7 @@ from typing import Callable
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from stable_baselines3.common.utils import set_random_seed
|
from stable_baselines3.common.utils import set_random_seed
|
||||||
import gym
|
import gym
|
||||||
|
from pathlib import Path
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||||
@ -40,6 +41,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
self.eval_env: Base5ActionRLEnv = None
|
self.eval_env: Base5ActionRLEnv = None
|
||||||
self.eval_callback: EvalCallback = None
|
self.eval_callback: EvalCallback = None
|
||||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||||
|
self.rl_config = self.freqai_info['rl_config']
|
||||||
|
self.continual_retraining = self.rl_config['continual_retraining']
|
||||||
if self.model_type in SB3_MODELS:
|
if self.model_type in SB3_MODELS:
|
||||||
import_str = 'stable_baselines3'
|
import_str = 'stable_baselines3'
|
||||||
elif self.model_type in SB3_CONTRIB_MODELS:
|
elif self.model_type in SB3_CONTRIB_MODELS:
|
||||||
@ -68,7 +71,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
logger.info("--------------------Starting training " f"{pair} --------------------")
|
logger.info("--------------------Starting training " f"{pair} --------------------")
|
||||||
|
|
||||||
# filter the features requested by user in the configuration file and elegantly handle NaNs
|
|
||||||
features_filtered, labels_filtered = dk.filter_features(
|
features_filtered, labels_filtered = dk.filter_features(
|
||||||
unfiltered_dataframe,
|
unfiltered_dataframe,
|
||||||
dk.training_features_list,
|
dk.training_features_list,
|
||||||
@ -78,19 +80,19 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
||||||
features_filtered, labels_filtered)
|
features_filtered, labels_filtered)
|
||||||
dk.fit_labels() # useless for now, but just satiating append methods
|
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
||||||
|
|
||||||
# normalize all data based on train_dataset only
|
# normalize all data based on train_dataset only
|
||||||
prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
|
prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
|
||||||
data_dictionary = dk.normalize_data(data_dictionary)
|
data_dictionary = dk.normalize_data(data_dictionary)
|
||||||
|
|
||||||
# optional additional data cleaning/analysis
|
# data cleaning/analysis
|
||||||
self.data_cleaning_train(dk)
|
self.data_cleaning_train(dk)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features"
|
f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
|
||||||
|
f' features and {len(data_dictionary["train_features"])} data points'
|
||||||
)
|
)
|
||||||
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
|
|
||||||
|
|
||||||
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
|
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
|
||||||
|
|
||||||
@ -100,9 +102,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
|
def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
|
||||||
|
prices_train: DataFrame, prices_test: DataFrame,
|
||||||
|
dk: FreqaiDataKitchen):
|
||||||
"""
|
"""
|
||||||
User overrides this as shown here if they are using a custom MyRLEnv
|
User can override this if they are using a custom MyRLEnv
|
||||||
"""
|
"""
|
||||||
train_df = data_dictionary["train_features"]
|
train_df = data_dictionary["train_features"]
|
||||||
test_df = data_dictionary["test_features"]
|
test_df = data_dictionary["test_features"]
|
||||||
@ -114,18 +118,22 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
reward_kwargs=self.reward_params, config=self.config)
|
reward_kwargs=self.reward_params, config=self.config)
|
||||||
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
|
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
|
||||||
window_size=self.CONV_WIDTH,
|
window_size=self.CONV_WIDTH,
|
||||||
reward_kwargs=self.reward_params, config=self.config), ".")
|
reward_kwargs=self.reward_params, config=self.config),
|
||||||
|
str(Path(dk.data_path / 'monitor')))
|
||||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||||
render=False, eval_freq=eval_freq,
|
render=False, eval_freq=eval_freq,
|
||||||
best_model_save_path=dk.data_path)
|
best_model_save_path=str(dk.data_path))
|
||||||
else:
|
else:
|
||||||
self.train_env.reset()
|
self.train_env.reset()
|
||||||
self.eval_env.reset()
|
self.eval_env.reset()
|
||||||
self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
|
self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
|
||||||
self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params)
|
self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params)
|
||||||
|
# self.eval_callback.eval_env = self.eval_env
|
||||||
|
# self.eval_callback.best_model_save_path = str(dk.data_path)
|
||||||
|
# self.eval_callback._init_callback()
|
||||||
self.eval_callback.__init__(self.eval_env, deterministic=True,
|
self.eval_callback.__init__(self.eval_env, deterministic=True,
|
||||||
render=False, eval_freq=eval_freq,
|
render=False, eval_freq=eval_freq,
|
||||||
best_model_save_path=dk.data_path)
|
best_model_save_path=str(dk.data_path))
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
||||||
@ -137,19 +145,20 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def get_state_info(self, pair):
|
def get_state_info(self, pair: str):
|
||||||
open_trades = Trade.get_trades_proxy(is_open=True)
|
open_trades = Trade.get_trades_proxy(is_open=True)
|
||||||
market_side = 0.5
|
market_side = 0.5
|
||||||
current_profit = 0
|
current_profit: float = 0
|
||||||
trade_duration = 0
|
trade_duration = 0
|
||||||
for trade in open_trades:
|
for trade in open_trades:
|
||||||
if trade.pair == pair:
|
if trade.pair == pair:
|
||||||
|
# FIXME: mypy typing doesnt like that strategy may be "None" (it never will be)
|
||||||
current_value = self.strategy.dp._exchange.get_rate(
|
current_value = self.strategy.dp._exchange.get_rate(
|
||||||
pair, refresh=False, side="exit", is_short=trade.is_short)
|
pair, refresh=False, side="exit", is_short=trade.is_short)
|
||||||
openrate = trade.open_rate
|
openrate = trade.open_rate
|
||||||
now = datetime.now(timezone.utc).timestamp()
|
now = datetime.now(timezone.utc).timestamp()
|
||||||
trade_duration = (now - trade.open_date.timestamp()) / self.base_tf_seconds
|
trade_duration = int((now - trade.open_date.timestamp()) / self.base_tf_seconds)
|
||||||
if 'long' in trade.enter_tag:
|
if 'long' in str(trade.enter_tag):
|
||||||
market_side = 1
|
market_side = 1
|
||||||
current_profit = (current_value - openrate) / openrate
|
current_profit = (current_value - openrate) / openrate
|
||||||
else:
|
else:
|
||||||
@ -245,8 +254,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame,
|
||||||
reward_params, window_size, monitor=False, config={}) -> Callable:
|
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||||
|
config: Dict[str, Any] = {}) -> Callable:
|
||||||
"""
|
"""
|
||||||
Utility function for multiprocessed env.
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
|
@ -22,6 +22,12 @@ class ReinforcementLearnerCustomAgent(BaseReinforcementLearningModel):
|
|||||||
"""
|
"""
|
||||||
User can customize agent by defining the class and using it directly.
|
User can customize agent by defining the class and using it directly.
|
||||||
Here the example is "TDQN"
|
Here the example is "TDQN"
|
||||||
|
|
||||||
|
Warning!
|
||||||
|
This is an advanced example of how a user may create and use a highly
|
||||||
|
customized model class (which can inherit from existing classes,
|
||||||
|
similar to how the example below inherits from DQN).
|
||||||
|
This file is for example purposes only, and should not be run.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
||||||
@ -34,7 +40,7 @@ class ReinforcementLearnerCustomAgent(BaseReinforcementLearningModel):
|
|||||||
|
|
||||||
# TDQN is a custom agent defined below
|
# TDQN is a custom agent defined below
|
||||||
model = TDQN(self.policy_type, self.train_env,
|
model = TDQN(self.policy_type, self.train_env,
|
||||||
tensorboard_log=Path(dk.data_path / "tensorboard"),
|
tensorboard_log=str(Path(dk.data_path / "tensorboard")),
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
**self.freqai_info['model_training_parameters']
|
**self.freqai_info['model_training_parameters']
|
||||||
)
|
)
|
||||||
@ -217,7 +223,7 @@ class TDQN(DQN):
|
|||||||
exploration_initial_eps: float = 1.0,
|
exploration_initial_eps: float = 1.0,
|
||||||
exploration_final_eps: float = 0.05,
|
exploration_final_eps: float = 0.05,
|
||||||
max_grad_norm: float = 10,
|
max_grad_norm: float = 10,
|
||||||
tensorboard_log: Optional[Path] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
create_eval_env: bool = False,
|
create_eval_env: bool = False,
|
||||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
verbose: int = 1,
|
verbose: int = 1,
|
@ -485,6 +485,10 @@ class FreqaiDataDrawer:
|
|||||||
f"Unable to load model, ensure model exists at " f"{dk.data_path} "
|
f"Unable to load model, ensure model exists at " f"{dk.data_path} "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# load it into ram if it was loaded from disk
|
||||||
|
if coin not in self.model_dictionary:
|
||||||
|
self.model_dictionary[coin] = model
|
||||||
|
|
||||||
if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
|
if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
|
||||||
dk.pca = cloudpickle.load(
|
dk.pca = cloudpickle.load(
|
||||||
open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "rb")
|
open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "rb")
|
||||||
|
@ -76,7 +76,8 @@ class ReinforcementLearningExample5ac(IStrategy):
|
|||||||
informative[f"%-{coin}pct-change"] = informative["close"].pct_change()
|
informative[f"%-{coin}pct-change"] = informative["close"].pct_change()
|
||||||
informative[f"%-{coin}raw_volume"] = informative["volume"]
|
informative[f"%-{coin}raw_volume"] = informative["volume"]
|
||||||
|
|
||||||
# The following features are necessary for RL models
|
# FIXME: add these outside the user strategy?
|
||||||
|
# The following columns are necessary for RL models.
|
||||||
informative[f"%-{coin}raw_close"] = informative["close"]
|
informative[f"%-{coin}raw_close"] = informative["close"]
|
||||||
informative[f"%-{coin}raw_open"] = informative["open"]
|
informative[f"%-{coin}raw_open"] = informative["open"]
|
||||||
informative[f"%-{coin}raw_high"] = informative["high"]
|
informative[f"%-{coin}raw_high"] = informative["high"]
|
||||||
|
@ -57,9 +57,9 @@ class BaseClassifierModel(IFreqaiModel):
|
|||||||
self.data_cleaning_train(dk)
|
self.data_cleaning_train(dk)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features"
|
f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
|
||||||
|
f' features and {len(data_dictionary["train_features"])} data points'
|
||||||
)
|
)
|
||||||
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
|
|
||||||
|
|
||||||
model = self.fit(data_dictionary)
|
model = self.fit(data_dictionary)
|
||||||
|
|
||||||
|
@ -56,9 +56,9 @@ class BaseRegressionModel(IFreqaiModel):
|
|||||||
self.data_cleaning_train(dk)
|
self.data_cleaning_train(dk)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features"
|
f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
|
||||||
|
f' features and {len(data_dictionary["train_features"])} data points'
|
||||||
)
|
)
|
||||||
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
|
|
||||||
|
|
||||||
model = self.fit(data_dictionary)
|
model = self.fit(data_dictionary)
|
||||||
|
|
||||||
|
@ -53,9 +53,9 @@ class BaseTensorFlowModel(IFreqaiModel):
|
|||||||
self.data_cleaning_train(dk)
|
self.data_cleaning_train(dk)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features"
|
f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
|
||||||
|
f' features and {len(data_dictionary["train_features"])} data points'
|
||||||
)
|
)
|
||||||
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
|
|
||||||
|
|
||||||
model = self.fit(data_dictionary)
|
model = self.fit(data_dictionary)
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict # , Tuple
|
from typing import Any, Dict
|
||||||
|
|
||||||
# import numpy.typing as npt
|
|
||||||
import torch as th
|
import torch as th
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
||||||
@ -22,12 +21,18 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
||||||
|
|
||||||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||||
net_arch=[256, 256, 128])
|
net_arch=[512, 512, 256])
|
||||||
|
|
||||||
|
if dk.pair not in self.dd.model_dictionary or not self.continual_retraining:
|
||||||
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
||||||
tensorboard_log=Path(dk.data_path / "tensorboard"),
|
tensorboard_log=Path(dk.data_path / "tensorboard"),
|
||||||
**self.freqai_info['model_training_parameters']
|
**self.freqai_info['model_training_parameters']
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger.info('Continual training activated - starting training from previously '
|
||||||
|
'trained agent.')
|
||||||
|
model = self.dd.model_dictionary[dk.pair]
|
||||||
|
model.set_env(self.train_env)
|
||||||
|
|
||||||
model.learn(
|
model.learn(
|
||||||
total_timesteps=int(total_timesteps),
|
total_timesteps=int(total_timesteps),
|
||||||
|
Loading…
Reference in New Issue
Block a user