fix: persist a single training environment for PPO
@@ -1,13 +1,16 @@
 import logging
 from enum import Enum
 # from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import gym
 import numpy as np
+import pandas as pd
 from gym import spaces
 from gym.utils import seeding
+from pandas import DataFrame
+# from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 
 logger = logging.getLogger(__name__)
 
 
@@ -43,6 +46,9 @@ class Base3ActionRLEnv(gym.Env):
 
         self.id = id
         self.seed(seed)
+        self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
+
+    def reset_env(self, df, prices, window_size, reward_kwargs, starting_point=True):
         self.df = df
         self.signal_features = self.df
         self.prices = prices
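This hunk is the heart of the fix: the environment is constructed once, and later training cycles push fresh data into it through reset_env() instead of building a new gym.Env each time. Below is a minimal sketch of how a caller might exploit that with a stable-baselines3 PPO model; the PPO wiring, the reward_kwargs values, and the train_df/price_df/retraining_batches names are illustrative assumptions, not code from this commit:

    # Sketch only: assumes stable-baselines3 and the Base3ActionRLEnv in this diff.
    from stable_baselines3 import PPO

    env = Base3ActionRLEnv(df=train_df, prices=price_df, window_size=10,
                           reward_kwargs={'rr': 1, 'profit_aim': 0.02})  # values invented
    model = PPO('MlpPolicy', env)

    for train_df, price_df in retraining_batches:  # hypothetical iterable of new data
        # Reuse the one persisted env: swap in new data rather than
        # constructing a fresh environment (and PPO model) every cycle.
        env.reset_env(train_df, price_df, window_size=10,
                      reward_kwargs={'rr': 1, 'profit_aim': 0.02})
        model.set_env(env)
        model.learn(total_timesteps=10_000)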
@@ -54,7 +60,7 @@ class Base3ActionRLEnv(gym.Env):
         self.fee = 0.0015
 
         # # spaces
-        self.shape = (window_size, self.signal_features.shape[1])
+        self.shape = (window_size, self.signal_features.shape[1] + 2)
         self.action_space = spaces.Discrete(len(Actions))
         self.observation_space = spaces.Box(
             low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
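The + 2 reserves room for the two state columns (current_profit_pct and position) that the reworked _get_observation below appends to each feature window, keeping the Box shape consistent with what the env actually returns. A quick sanity check with made-up sizes:

    # Illustrative numbers only: 5 signal features, a window of 10 candles.
    window_size, n_features = 10, 5
    shape = (window_size, n_features + 2)  # two extra columns of trade state
    assert shape == (10, 7)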
@@ -165,7 +171,16 @@ class Base3ActionRLEnv(gym.Env):
         return observation, step_reward, self._done, info
 
     def _get_observation(self):
-        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]
+        features_window = self.signal_features[(
+            self._current_tick - self.window_size):self._current_tick]
+        features_and_state = DataFrame(np.zeros((len(features_window), 2)),
+                                       columns=['current_profit_pct', 'position'],
+                                       index=features_window.index)
+
+        features_and_state['current_profit_pct'] = self.get_unrealized_profit()
+        features_and_state['position'] = self._position.value
+        features_and_state = pd.concat([features_window, features_and_state], axis=1)
+        return features_and_state
 
     def get_unrealized_profit(self):
 
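A standalone rendition of the new observation logic, showing that the returned frame is simply the feature window with the two state columns concatenated on the right. The toy signal_features frame and the profit/position values are invented; only the column names and the concat pattern come from the diff:

    import numpy as np
    import pandas as pd
    from pandas import DataFrame

    # Toy stand-ins for self.signal_features, self.get_unrealized_profit()
    # and self._position.value.
    signal_features = DataFrame(np.random.randn(100, 5),
                                columns=[f'feat_{i}' for i in range(5)])
    current_tick, window_size = 50, 10
    unrealized_profit, position_value = 0.013, 1

    features_window = signal_features[(current_tick - window_size):current_tick]
    features_and_state = DataFrame(np.zeros((len(features_window), 2)),
                                   columns=['current_profit_pct', 'position'],
                                   index=features_window.index)
    features_and_state['current_profit_pct'] = unrealized_profit
    features_and_state['position'] = position_value
    obs = pd.concat([features_window, features_and_state], axis=1)

    assert obs.shape == (window_size, 5 + 2)  # matches the widened self.shape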
@@ -307,7 +322,7 @@ class Base3ActionRLEnv(gym.Env):
     def prev_price(self) -> float:
         return self.prices.iloc[self._current_tick - 1].open
 
-    def sharpe_ratio(self):
+    def sharpe_ratio(self) -> float:
         if len(self.close_trade_profit) == 0:
             return 0.
         returns = np.array(self.close_trade_profit)
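The hunk is truncated right after the returns array is built. For reference, a typical way such a method finishes, under the usual Sharpe-ratio definition; this completion is a guess, not the commit's actual code:

    import numpy as np

    def sharpe_ratio(close_trade_profit) -> float:
        # Hypothetical completion: mean trade return over its volatility,
        # with a small epsilon so a constant profit series cannot divide by zero.
        if len(close_trade_profit) == 0:
            return 0.
        returns = np.array(close_trade_profit)
        return float(np.mean(returns) / (np.std(returns) + 1e-9))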