stable/freqtrade/freqai/prediction_models/RL/RLPrediction_env.py

231 lines
7.2 KiB
Python

from enum import Enum
import gym
import matplotlib.pyplot as plt
import numpy as np
from gym import spaces
from gym.utils import seeding
class Actions(Enum):
    """Discrete actions the agent can take; values are the action-space indices."""

    Hold = 0  # never matched by the env's trade check, so the position is kept
    Buy = 1   # switches Short -> Long (no effect when already Long)
    Sell = 2  # switches Long -> Short (no effect when already Short)
class Positions(Enum):
    """The two positions an agent can hold in the trading environment."""

    Short = 0
    Long = 1

    def opposite(self):
        """Return the other position: Short maps to Long and Long maps to Short."""
        if self is Positions.Short:
            return Positions.Long
        return Positions.Short
class GymAnytrading(gym.Env):
    """
    Single-asset trading environment for reinforcement learning.

    Based on https://github.com/AminHP/gym-anytrading

    The agent is always either Short or Long. A trade happens when it sends
    Buy while Short or Sell while Long; every other action (including Hold)
    leaves the position unchanged. Only Long holding periods earn reward or
    change the running profit; Short periods contribute nothing.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, signal_features, prices, window_size, fee=0.0):
        """
        :param signal_features: 2-D feature matrix; rows are indexed by tick and
            observations are ``window_size``-row slices of it.
        :param prices: per-tick prices indexed by tick (presumably a 1-D numpy
            array -- ``render_all`` indexes it with a list of ticks).
        :param window_size: rows per observation; also the first tradable tick.
        :param fee: fractional fee charged on both entry and exit of a Long trade.
        """
        assert signal_features.ndim == 2

        self.seed()
        self.signal_features = signal_features
        self.prices = prices
        self.window_size = window_size
        self.fee = fee
        self.shape = (window_size, self.signal_features.shape[1])

        # spaces
        self.action_space = spaces.Discrete(len(Actions))
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)

        # episode state; populated by reset()
        self._start_tick = self.window_size
        self._end_tick = len(self.prices) - 1
        self._done = None
        self._current_tick = None
        self._last_trade_tick = None
        self._position = None
        self._position_history = None
        self._total_reward = None
        self._total_profit = None
        self._first_rendering = None
        self.history = None

    def seed(self, seed=None):
        """Create the environment RNG via gym's seeding helper; return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Start a new episode at ``_start_tick``, Short, with unit profit.

        :return: the first observation.
        """
        self._done = False
        self._current_tick = self._start_tick
        self._last_trade_tick = self._current_tick - 1
        self._position = Positions.Short
        # Pad with None so that index i of the history always refers to tick i
        # (step() appends exactly one entry per tick).
        self._position_history = (self.window_size * [None]) + [self._position]
        self._total_reward = 0.
        self._total_profit = 1.  # unit
        self._first_rendering = True
        self.history = {}
        return self._get_observation()

    def _is_trade(self, action):
        """Return True when ``action`` flips the current position.

        Buy while Short or Sell while Long is a trade; any other combination
        (including Hold) keeps the position as is.
        """
        return ((action == Actions.Buy.value and self._position == Positions.Short) or
                (action == Actions.Sell.value and self._position == Positions.Long))

    def step(self, action):
        """Advance one tick and apply ``action`` (an ``Actions`` value as int).

        :return: ``(observation, step_reward, done, info)``.
        """
        self._done = False
        self._current_tick += 1

        # NOTE(review): done is only raised when the tick equals _end_tick;
        # calling step() after the episode ended is not guarded.
        if self._current_tick == self._end_tick:
            self._done = True

        # Reward and profit are evaluated against the position held *before*
        # this action is applied.
        step_reward = self._calculate_reward(action)
        self._total_reward += step_reward

        self._update_profit(action)

        if self._is_trade(action):
            self._position = self._position.opposite()
            self._last_trade_tick = self._current_tick

        self._position_history.append(self._position)
        observation = self._get_observation()
        info = dict(
            total_reward=self._total_reward,
            total_profit=self._total_profit,
            position=self._position.value
        )
        self._update_history(info)

        return observation, step_reward, self._done, info

    def _get_observation(self):
        """Return the ``window_size`` feature rows that end just before the current tick."""
        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]

    def _update_history(self, info):
        """Append each value of ``info`` to the per-key lists in ``self.history``."""
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)

    def _render_title(self):
        """Plot title summarising the episode's running reward and profit."""
        return ("Total Reward: %.6f" % self._total_reward + ' ~ ' +
                "Total Profit: %.6f" % self._total_profit)

    def render(self, mode='human'):
        """Incrementally plot the price series and the position at the current tick."""

        def _plot_position(position, tick):
            color = None
            if position == Positions.Short:
                color = 'red'
            elif position == Positions.Long:
                color = 'green'
            if color:
                plt.scatter(tick, self.prices[tick], color=color)

        if self._first_rendering:
            # First frame: clear the axes, draw the full price line and the
            # starting position once.
            self._first_rendering = False
            plt.cla()
            plt.plot(self.prices)
            start_position = self._position_history[self._start_tick]
            _plot_position(start_position, self._start_tick)

        _plot_position(self._position, self._current_tick)

        plt.suptitle(self._render_title())
        plt.pause(0.01)

    def render_all(self, mode='human'):
        """Plot the whole price series with a marker per recorded position."""
        plt.plot(self.prices)

        # History index i is tick i (see reset), so the index is the x value.
        # The None padding entries match neither position and are skipped.
        short_ticks = [tick for tick, position in enumerate(self._position_history)
                       if position == Positions.Short]
        long_ticks = [tick for tick, position in enumerate(self._position_history)
                      if position == Positions.Long]

        plt.plot(short_ticks, self.prices[short_ticks], 'ro')
        plt.plot(long_ticks, self.prices[long_ticks], 'go')

        plt.suptitle(self._render_title())

    def close(self):
        plt.close()

    def save_rendering(self, filepath):
        plt.savefig(filepath)

    def pause_rendering(self):
        plt.show()

    def _calculate_reward(self, action):
        """Reward for this step: the price change since entry when closing a Long.

        Non-trades and trades out of a Short position yield 0.
        """
        step_reward = 0

        if self._is_trade(action):
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]
            price_diff = current_price - last_trade_price

            if self._position == Positions.Long:
                step_reward += price_diff

        return step_reward

    def _update_profit(self, action):
        """Compound ``_total_profit`` over a Long holding period.

        Runs on a trade or on the final tick, so an open Long is also valued
        at episode end. ``fee`` is charged once on entry and once on exit.
        """
        if self._is_trade(action) or self._done:
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]

            if self._position == Positions.Long:
                shares = (self._total_profit * (1 - self.fee)) / last_trade_price
                self._total_profit = (shares * (1 - self.fee)) * current_price

    def max_possible_profit(self):
        """Profit multiple of a strategy holding Long through every non-falling run.

        Splits the prices from ``_start_tick`` to ``_end_tick`` into maximal
        falling and non-falling runs and compounds a unit stake over each
        non-falling run, from the price just before it to its last price.
        No fees are applied.
        """
        current_tick = self._start_tick
        last_trade_tick = current_tick - 1
        profit = 1.

        while current_tick <= self._end_tick:
            if self.prices[current_tick] < self.prices[current_tick - 1]:
                while (current_tick <= self._end_tick and
                       self.prices[current_tick] < self.prices[current_tick - 1]):
                    current_tick += 1
                position = Positions.Short
            else:
                while (current_tick <= self._end_tick and
                       self.prices[current_tick] >= self.prices[current_tick - 1]):
                    current_tick += 1
                position = Positions.Long

            if position == Positions.Long:
                current_price = self.prices[current_tick - 1]
                last_trade_price = self.prices[last_trade_tick]
                shares = profit / last_trade_price
                profit = shares * current_price
            # The end of this run is the entry point for the next one.
            last_trade_tick = current_tick - 1

        return profit