persist a single training environment.
This commit is contained in:
@@ -7,7 +7,7 @@ import numpy as np
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
from pandas import DataFrame
|
||||
|
||||
import pandas as pd
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -47,6 +47,9 @@ class Base5ActionRLEnv(gym.Env):
|
||||
|
||||
self.id = id
|
||||
self.seed(seed)
|
||||
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
||||
|
||||
def reset_env(self, df, prices, window_size, reward_kwargs, starting_point=True):
|
||||
self.df = df
|
||||
self.signal_features = self.df
|
||||
self.prices = prices
|
||||
@@ -178,10 +181,15 @@ class Base5ActionRLEnv(gym.Env):
|
||||
return observation, step_reward, self._done, info
|
||||
|
||||
def _get_observation(self):
|
||||
features_and_state = self.signal_features[(
|
||||
features_window = self.signal_features[(
|
||||
self._current_tick - self.window_size):self._current_tick]
|
||||
features_and_state = DataFrame(np.zeros((len(features_window), 2)),
|
||||
columns=['current_profit_pct', 'position'],
|
||||
index=features_window.index)
|
||||
|
||||
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
||||
features_and_state['position'] = self._position.value
|
||||
features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
||||
return features_and_state
|
||||
|
||||
def get_unrealized_profit(self):
|
||||
|
||||
Reference in New Issue
Block a user