add continual retraining feature, handle mypy typing reqs, improve docstrings
@@ -1,6 +1,6 @@
 import logging
 from enum import Enum
-# from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Optional

 import gym
 import numpy as np
@@ -44,14 +44,14 @@ class Base5ActionRLEnv(gym.Env):
     def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                  reward_kwargs: dict = {}, window_size=10, starting_point=True,
                  id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
         assert df.ndim == 2

         self.rl_config = config['freqai']['rl_config']
         self.id = id
         self.seed(seed)
         self.reset_env(df, prices, window_size, reward_kwargs, starting_point)

-    def reset_env(self, df, prices, window_size, reward_kwargs, starting_point=True):
+    def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
+                  reward_kwargs: dict, starting_point=True):
         self.df = df
         self.signal_features = self.df
         self.prices = prices
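A minimal usage sketch for the annotated signature above, assuming pandas DataFrames for df/prices and a config dict carrying the 'freqai' -> 'rl_config' section; the column names, reward_kwargs keys, and the import path are illustrative assumptions, not part of this diff:

import pandas as pd
# Module path assumed for illustration; adjust to the actual project layout.
from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv

# Illustrative inputs only; real callers pass fully populated feature/price frames.
df = pd.DataFrame({"feature_a": [0.1, 0.2, 0.3], "feature_b": [1.0, 0.9, 1.1]})
prices = pd.DataFrame({"open": [100, 101, 102], "close": [101, 102, 103]})
config = {"freqai": {"rl_config": {}}}  # only the keys read in __init__

env = Base5ActionRLEnv(df=df, prices=prices,
                       reward_kwargs={"rr": 1, "profit_aim": 0.025},
                       window_size=10, starting_point=True, config=config)
# reset_env can also be called directly with the same typed arguments:
env.reset_env(df, prices, window_size=10,
              reward_kwargs={"rr": 1, "profit_aim": 0.025})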
@@ -69,18 +69,18 @@ class Base5ActionRLEnv(gym.Env):
             low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)

         # episode
-        self._start_tick = self.window_size
-        self._end_tick = len(self.prices) - 1
-        self._done = None
-        self._current_tick = None
-        self._last_trade_tick = None
+        self._start_tick: int = self.window_size
+        self._end_tick: int = len(self.prices) - 1
+        self._done: bool = False
+        self._current_tick: int = self._start_tick
+        self._last_trade_tick: Optional[int] = None
         self._position = Positions.Neutral
-        self._position_history = None
-        self.total_reward = None
-        self._total_profit = None
-        self._first_rendering = None
-        self.history = None
-        self.trade_history = []
+        self._position_history: list = [None]
+        self.total_reward: float = 0
+        self._total_profit: float = 0
+        self._first_rendering: bool = False
+        self.history: dict = {}
+        self.trade_history: list = []

     def seed(self, seed: int = 1):
         self.np_random, seed = seeding.np_random(seed)
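A small standalone sketch (not taken from the diff) of the mypy issue these typed defaults address: an attribute first assigned None gets inferred as None-typed, so later reassignment to an int is flagged, whereas an explicit annotation or Optional[...] default type-checks cleanly:

from typing import Optional

class Untyped:
    def __init__(self) -> None:
        self.current_tick = None   # mypy infers the attribute type as None
    def step(self) -> None:
        self.current_tick = 5      # mypy error: int not assignable to None

class Typed:
    def __init__(self) -> None:
        self.current_tick: int = 0                   # concrete initial value
        self.last_trade_tick: Optional[int] = None   # explicitly allows None
    def step(self) -> None:
        self.current_tick += 1                       # type-checks cleanly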
@@ -125,8 +125,7 @@ class Base5ActionRLEnv(gym.Env):
         self.total_reward += step_reward

         trade_type = None
-        if self.is_tradesignal(action):  # exclude 3 case not trade
-            # Update position
+        if self.is_tradesignal(action):
             """
             Action: Neutral, position: Long ->  Close Long
             Action: Neutral, position: Short -> Close Short
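The two cases listed in that docstring can be read as a small transition table. The sketch below uses plain strings rather than the project's Actions/Positions enums, purely to illustrate how an (action, current position) pair maps to the next position:

# Illustrative only: (action, current_position) -> position after the update.
CLOSE_RULES = {
    ("neutral", "long"): "neutral",    # Action: Neutral, position: Long  -> Close Long
    ("neutral", "short"): "neutral",   # Action: Neutral, position: Short -> Close Short
}

def next_position(action: str, position: str) -> str:
    # Hold the current position when no close rule applies.
    return CLOSE_RULES.get((action, position), position)

assert next_position("neutral", "long") == "neutral"
assert next_position("neutral", "neutral") == "neutral"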
@@ -223,9 +222,8 @@ class Base5ActionRLEnv(gym.Env):
         # trade signal
         """
-        not trade signal is :
-        Action: Neutral, position: Neutral -> Nothing
-        Action: Long, position: Long -> Hold Long
-        Action: Short, position: Short -> Hold Short
+        Determine if the signal is non sensical
+        e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
         """
         return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
                     (action == Actions.Neutral.value and self._position == Positions.Short) or
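Based on the two disjuncts visible above, is_tradesignal returns False for combinations that would not change anything, such as a Neutral action while already Neutral or while Short. A hedged sketch of that behaviour, assuming env is a constructed Base5ActionRLEnv:

# Grounded only in the two clauses shown in this hunk; the remaining clauses
# of the expression are truncated in the diff.
env._position = Positions.Neutral
assert env.is_tradesignal(Actions.Neutral.value) is False  # Neutral action, Neutral position -> not a trade signal

env._position = Positions.Short
assert env.is_tradesignal(Actions.Neutral.value) is False  # Neutral action, Short position -> excluded combination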
@@ -292,7 +290,7 @@ class Base5ActionRLEnv(gym.Env):
     def most_recent_return(self, action: int):
         """
         We support Long, Neutral and Short positions.
         Calculate the tick to tick return if in a trade.
         Return is generated from rising prices in Long
         and falling prices in Short positions.
         The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
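A worked illustration of that docstring, assuming a simple close-to-close ratio for the tick-to-tick return; the exact formula and the fee handling are not shown in this hunk, so the numbers below are illustrative only:

# Long position: rising prices produce a positive tick-to-tick return.
previous_close, current_close = 100.0, 101.0
long_return = current_close / previous_close - 1       # +0.01  (+1%)

# Short position: falling prices produce a positive return (sign flipped).
previous_close, current_close = 100.0, 99.0
short_return = -(current_close / previous_close - 1)   # +0.01  (+1%)

print(long_return, short_return)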