# stable/freqtrade/freqai/prediction_models/RL/RLPrediction_env.py
#
# Repository viewer metadata: 231 lines, 7.2 KiB, Python.
# ("Raw", "Normal View" and "History" were viewer controls, not module content.)

from enum import Enum
import gym
import matplotlib.pyplot as plt
import numpy as np
from gym import spaces
from gym.utils import seeding
class Actions(Enum):
    """Discrete actions an agent can submit to the environment.

    The integer values are what the environment compares incoming actions
    against (e.g. ``action == Actions.Buy.value``), and the number of members
    determines the size of the discrete action space.
    """

    Hold = 0
    Buy = 1
    Sell = 2
class Positions(Enum):
    """The two positions the trading environment can be in."""

    Short = 0
    Long = 1

    def opposite(self):
        """Return the other position: ``Long`` -> ``Short``, ``Short`` -> ``Long``."""
        return Positions.Short if self == Positions.Long else Positions.Long
class GymAnytrading(gym.Env):
    """
    Single-series trading environment with a discrete Hold/Buy/Sell action space.

    Based on https://github.com/AminHP/gym-anytrading

    The environment is always in one of two positions (``Positions.Short`` or
    ``Positions.Long``). A ``Buy`` while Short or a ``Sell`` while Long flips
    the position ("a trade"); every other action leaves the position unchanged.

    * Observation: the ``window_size`` rows of ``signal_features`` that precede
      the current tick.
    * Reward: the price difference since the last trade, credited only when a
      Long position is closed.
    * Profit: a multiplicative factor starting at 1.0, updated when a Long
      position is closed or when the episode ends while Long; ``fee`` is
      applied once on entry and once on exit.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, signal_features, prices, window_size, fee=0.0):
        """
        Set up the spaces and the (not yet started) episode state.

        :param signal_features: 2-D array; rows are ticks, columns are features.
        :param prices: price per tick, indexable by tick (and by a list of
            ticks in ``render_all``).
        :param window_size: number of past ticks contained in one observation.
        :param fee: proportional fee applied on both entry and exit when the
            profit is updated.
        """
        assert signal_features.ndim == 2, "signal_features must be a 2-D array"

        self.seed()
        self.signal_features = signal_features
        self.prices = prices
        self.window_size = window_size
        self.fee = fee
        self.shape = (window_size, self.signal_features.shape[1])

        # spaces
        self.action_space = spaces.Discrete(len(Actions))
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)

        # episode
        self._start_tick = self.window_size
        self._end_tick = len(self.prices) - 1
        self._done = None
        self._current_tick = None
        self._last_trade_tick = None
        self._position = None
        self._position_history = None
        self._total_reward = None
        self._total_profit = None
        self._first_rendering = None
        self.history = None

    def seed(self, seed=None):
        """Seed the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """
        Start a new episode and return the first observation.

        The episode begins at ``_start_tick`` in the Short position with a
        total reward of 0 and a total profit of 1 (one unit).
        """
        self._done = False
        self._current_tick = self._start_tick
        self._last_trade_tick = self._current_tick - 1
        self._position = Positions.Short
        # Pad with None so that the history can be indexed by tick.
        self._position_history = (self.window_size * [None]) + [self._position]
        self._total_reward = 0.
        self._total_profit = 1.  # unit
        self._first_rendering = True
        self.history = {}
        return self._get_observation()

    def _is_trade(self, action):
        """Return True when ``action`` flips the current position."""
        return bool(
            (action == Actions.Buy.value and self._position == Positions.Short) or
            (action == Actions.Sell.value and self._position == Positions.Long)
        )

    def step(self, action):
        """
        Advance one tick and apply ``action``.

        :param action: one of the ``Actions`` values (0 = Hold, 1 = Buy, 2 = Sell).
        :return: tuple ``(observation, step_reward, done, info)`` where ``info``
            holds ``total_reward``, ``total_profit`` and the ``position`` value.
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        # Reward and profit are computed against the position held *before*
        # the action takes effect, so they must run before the flip below.
        step_reward = self._calculate_reward(action)
        self._total_reward += step_reward
        self._update_profit(action)

        if self._is_trade(action):
            self._position = self._position.opposite()
            self._last_trade_tick = self._current_tick

        self._position_history.append(self._position)
        observation = self._get_observation()
        info = dict(
            total_reward=self._total_reward,
            total_profit=self._total_profit,
            position=self._position.value
        )
        self._update_history(info)

        return observation, step_reward, self._done, info

    def _get_observation(self):
        """Return the ``window_size`` feature rows preceding the current tick."""
        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]

    def _update_history(self, info):
        """Append each value of ``info`` to the per-key lists in ``self.history``."""
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)

    def render(self, mode='human'):
        """Plot the price series and mark the current position on it."""

        def _plot_position(position, tick):
            # Short is drawn red, Long green; anything else is not drawn.
            color = None
            if position == Positions.Short:
                color = 'red'
            elif position == Positions.Long:
                color = 'green'
            if color:
                plt.scatter(tick, self.prices[tick], color=color)

        if self._first_rendering:
            # First call of the episode: draw the price line and start marker.
            self._first_rendering = False
            plt.cla()
            plt.plot(self.prices)
            start_position = self._position_history[self._start_tick]
            _plot_position(start_position, self._start_tick)

        _plot_position(self._position, self._current_tick)

        plt.suptitle(
            "Total Reward: %.6f" % self._total_reward + ' ~ ' +
            "Total Profit: %.6f" % self._total_profit
        )

        plt.pause(0.01)

    def render_all(self, mode='human'):
        """Plot the price series with every recorded position (red = Short, green = Long)."""
        plt.plot(self.prices)

        # Index in the position history equals the tick it belongs to.
        short_ticks = [
            tick for tick, position in enumerate(self._position_history)
            if position == Positions.Short
        ]
        long_ticks = [
            tick for tick, position in enumerate(self._position_history)
            if position == Positions.Long
        ]

        plt.plot(short_ticks, self.prices[short_ticks], 'ro')
        plt.plot(long_ticks, self.prices[long_ticks], 'go')

        plt.suptitle(
            "Total Reward: %.6f" % self._total_reward + ' ~ ' +
            "Total Profit: %.6f" % self._total_profit
        )

    def close(self):
        """Close the matplotlib figure."""
        plt.close()

    def save_rendering(self, filepath):
        """Save the current matplotlib figure to ``filepath``."""
        plt.savefig(filepath)

    def pause_rendering(self):
        """Show the matplotlib figure."""
        plt.show()

    def _calculate_reward(self, action):
        """
        Return the reward for ``action`` at the current tick.

        The reward is the price change since the last trade when the action
        closes a Long position, and 0 in every other case.
        """
        step_reward = 0

        if self._is_trade(action):
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]
            price_diff = current_price - last_trade_price

            if self._position == Positions.Long:
                step_reward += price_diff

        return step_reward

    def _update_profit(self, action):
        """
        Update ``_total_profit`` when a Long position is closed.

        Runs when ``action`` is a trade or the episode is done; the fee is
        applied once on the entry price and once on the exit price.
        """
        if self._is_trade(action) or self._done:
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]

            if self._position == Positions.Long:
                shares = (self._total_profit * (1 - self.fee)) / last_trade_price
                self._total_profit = (shares * (1 - self.fee)) * current_price

    def max_possible_profit(self):
        """
        Return the profit factor of a trader with perfect foresight (fees ignored).

        The price series is split into maximal falling runs (treated as Short)
        and non-falling runs (treated as Long); profit is compounded over every
        Long run, starting from 1.0.
        """
        current_tick = self._start_tick
        last_trade_tick = current_tick - 1
        profit = 1.

        while current_tick <= self._end_tick:
            position = None
            if self.prices[current_tick] < self.prices[current_tick - 1]:
                # Skip to the end of the falling run.
                while (current_tick <= self._end_tick and
                       self.prices[current_tick] < self.prices[current_tick - 1]):
                    current_tick += 1
                position = Positions.Short
            else:
                # Skip to the end of the non-falling run.
                while (current_tick <= self._end_tick and
                       self.prices[current_tick] >= self.prices[current_tick - 1]):
                    current_tick += 1
                position = Positions.Long

            if position == Positions.Long:
                current_price = self.prices[current_tick - 1]
                last_trade_price = self.prices[last_trade_tick]
                shares = profit / last_trade_price
                profit = shares * current_price
            last_trade_tick = current_tick - 1
            # NOTE(review): debug output kept from the original; consider
            # routing it through logging instead of stdout.
            print(profit)

        return profit