Fix: persist a single training environment for PPO

This commit is contained in:
sonnhfit
2022-08-19 01:49:11 +07:00
committed by robcaulk
parent f95602f6bd
commit 4baa36bdcf
3 changed files with 51 additions and 25 deletions

View File

@@ -1,13 +1,16 @@
import logging
from enum import Enum
# from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import pandas as pd
from gym import spaces
from gym.utils import seeding
from pandas import DataFrame
# from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
logger = logging.getLogger(__name__)
@@ -43,6 +46,9 @@ class Base3ActionRLEnv(gym.Env):
self.id = id
self.seed(seed)
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
def reset_env(self, df, prices, window_size, reward_kwargs, starting_point=True):
self.df = df
self.signal_features = self.df
self.prices = prices
@@ -54,7 +60,7 @@ class Base3ActionRLEnv(gym.Env):
self.fee = 0.0015
# # spaces
self.shape = (window_size, self.signal_features.shape[1])
self.shape = (window_size, self.signal_features.shape[1] + 2)
self.action_space = spaces.Discrete(len(Actions))
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
@@ -165,7 +171,16 @@ class Base3ActionRLEnv(gym.Env):
return observation, step_reward, self._done, info
def _get_observation(self):
    """Return the current observation window with agent state appended.

    The window is the slice of ``self.signal_features`` covering the
    ``self.window_size`` rows that end just before ``self._current_tick``.
    Two extra columns are appended, each holding a single value repeated on
    every row of the window:

    * ``current_profit_pct`` -- the result of ``self.get_unrealized_profit()``
    * ``position`` -- ``self._position.value``

    Returns:
        DataFrame: the feature window plus the two state columns, i.e. two
        columns wider than ``self.signal_features``, which is the width the
        observation space is declared with.
    """
    window_start = self._current_tick - self.window_size
    features_window = self.signal_features[window_start:self._current_tick]

    # Scalars are broadcast over the window's index, so the state frame lines
    # up row-for-row with the features and the concat cannot introduce NaNs.
    state = DataFrame(
        {
            'current_profit_pct': self.get_unrealized_profit(),
            'position': self._position.value,
        },
        index=features_window.index,
    )
    return pd.concat([features_window, state], axis=1)
def get_unrealized_profit(self):
@@ -307,7 +322,7 @@ class Base3ActionRLEnv(gym.Env):
def prev_price(self) -> float:
return self.prices.iloc[self._current_tick - 1].open
def sharpe_ratio(self):
def sharpe_ratio(self) -> float:
if len(self.close_trade_profit) == 0:
return 0.
returns = np.array(self.close_trade_profit)