5ac base fixes in logic
This commit is contained in:
		| @@ -26,23 +26,23 @@ class Positions(Enum): | ||||
|     def opposite(self): | ||||
|         return Positions.Short if self == Positions.Long else Positions.Long | ||||
|  | ||||
|  | ||||
| def mean_over_std(x): | ||||
|     std = np.std(x, ddof=1) | ||||
|     mean = np.mean(x) | ||||
|     return mean / std if std > 0 else 0 | ||||
|  | ||||
|  | ||||
| class Base5ActionRLEnv(gym.Env): | ||||
|     """ | ||||
|     Base class for a 5 action environment | ||||
|     """ | ||||
|     metadata = {'render.modes': ['human']} | ||||
|  | ||||
|     def __init__(self, df, prices, reward_kwargs, window_size=10, starting_point=True, ): | ||||
|     def __init__(self, df, prices, reward_kwargs, window_size=10, starting_point=True, | ||||
|                  id: str = 'baseenv-1', seed: int = 1): | ||||
|         assert df.ndim == 2 | ||||
|  | ||||
|         self.seed() | ||||
|         self.id = id | ||||
|         self.seed(seed) | ||||
|         self.df = df | ||||
|         self.signal_features = self.df | ||||
|         self.prices = prices | ||||
| @@ -73,7 +73,7 @@ class Base5ActionRLEnv(gym.Env): | ||||
|         self.history = None | ||||
|         self.trade_history = [] | ||||
|  | ||||
|     def seed(self, seed=None): | ||||
|     def seed(self, seed: int = 1): | ||||
|         self.np_random, seed = seeding.np_random(seed) | ||||
|         return [seed] | ||||
|  | ||||
| @@ -102,7 +102,7 @@ class Base5ActionRLEnv(gym.Env): | ||||
|  | ||||
|         return self._get_observation() | ||||
|  | ||||
|     def step(self, action): | ||||
|     def step(self, action: int): | ||||
|         self._done = False | ||||
|         self._current_tick += 1 | ||||
|  | ||||
| @@ -191,7 +191,7 @@ class Base5ActionRLEnv(gym.Env): | ||||
|         else: | ||||
|             return 0. | ||||
|  | ||||
|     def is_tradesignal(self, action): | ||||
|     def is_tradesignal(self, action: int): | ||||
|         # trade signal | ||||
|         """ | ||||
|         not trade signal is : | ||||
| @@ -200,29 +200,29 @@ class Base5ActionRLEnv(gym.Env): | ||||
|         Action: Short, position: Short -> Hold Short | ||||
|         """ | ||||
|         return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or | ||||
|                     (action == Actions.Neutral.value and self._position == Positions.Short) or | ||||
|                     (action == Actions.Neutral.value and self._position == Positions.Long) or | ||||
|                     (action == Actions.Short_buy.value and self._position == Positions.Short) or | ||||
|                     (action == Actions.Short_sell.value and self._position == Positions.Short) or | ||||
|                     (action == Actions.Short_buy.value and self._position == Positions.Long) or | ||||
|                     (action == Actions.Short_sell.value and self._position == Positions.Short) or | ||||
|                     (action == Actions.Short_sell.value and self._position == Positions.Long) or | ||||
|  | ||||
|                     (action == Actions.Short_sell.value and self._position == Positions.Neutral) or | ||||
|                     (action == Actions.Long_buy.value and self._position == Positions.Long) or | ||||
|                     (action == Actions.Long_sell.value and self._position == Positions.Long) or | ||||
|                     (action == Actions.Long_buy.value and self._position == Positions.Short) or | ||||
|                     (action == Actions.Long_sell.value and self._position == Positions.Short)) | ||||
|                     (action == Actions.Long_sell.value and self._position == Positions.Long) or | ||||
|                     (action == Actions.Long_sell.value and self._position == Positions.Short) or | ||||
|                     (action == Actions.Long_sell.value and self._position == Positions.Neutral)) | ||||
|  | ||||
|     def _is_trade(self, action: Actions): | ||||
|         return ((action == Actions.Long_buy.value and self._position == Positions.Short) or | ||||
|                 (action == Actions.Short_buy.value and self._position == Positions.Long) or | ||||
|                 (action == Actions.Neutral.value and self._position == Positions.Long) or | ||||
|                 (action == Actions.Neutral.value and self._position == Positions.Short) or | ||||
|  | ||||
|                 (action == Actions.Neutral.Short_sell and self._position == Positions.Long) or | ||||
|                 (action == Actions.Neutral.Long_sell and self._position == Positions.Short) | ||||
|                 ) | ||||
|         return ((action == Actions.Long_buy.value and self._position == Positions.Neutral) or | ||||
|                 (action == Actions.Short_buy.value and self._position == Positions.Neutral)) | ||||
|  | ||||
|     def is_hold(self, action): | ||||
|         return ((action == Actions.Short.value and self._position == Positions.Short) | ||||
|                 or (action == Actions.Long.value and self._position == Positions.Long)) | ||||
|         return ((action == Actions.Short_buy.value and self._position == Positions.Short) or | ||||
|                 (action == Actions.Long_buy.value and self._position == Positions.Long) or | ||||
|                 (action == Actions.Neutral.value and self._position == Positions.Long) or | ||||
|                 (action == Actions.Neutral.value and self._position == Positions.Short) or | ||||
|                 (action == Actions.Neutral.value and self._position == Positions.Neutral)) | ||||
|  | ||||
|     def add_buy_fee(self, price): | ||||
|         return price * (1 + self.fee) | ||||
| @@ -240,6 +240,52 @@ class Base5ActionRLEnv(gym.Env): | ||||
|     def get_sharpe_ratio(self): | ||||
|         return mean_over_std(self.get_portfolio_log_returns()) | ||||
|  | ||||
|     def calculate_reward(self, action): | ||||
|  | ||||
|         if self._last_trade_tick is None: | ||||
|             return 0. | ||||
|  | ||||
|         # close long | ||||
|         if action == Actions.Long_sell.value and self._position == Positions.Long: | ||||
|             if len(self.close_trade_profit): | ||||
|                 # aim x2 rw | ||||
|                 if self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|                     last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|                     current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open) | ||||
|                     return float((np.log(current_price) - np.log(last_trade_price)) * 2) | ||||
|                 # less than aim x1 rw | ||||
|                 elif self.close_trade_profit[-1] < self.profit_aim * self.rr: | ||||
|                     last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|                     current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open) | ||||
|                     return float(np.log(current_price) - np.log(last_trade_price)) | ||||
|                 # # less than RR SL x2 neg rw | ||||
|                 # elif self.close_trade_profit[-1] < (self.profit_aim * -1): | ||||
|                 #     last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|                 #     current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open) | ||||
|                 #     return float((np.log(current_price) - np.log(last_trade_price)) * 2) * -1 | ||||
|  | ||||
|  | ||||
|         # close short | ||||
|         if action == Actions.Short_buy.value and self._position == Positions.Short: | ||||
|             if len(self.close_trade_profit): | ||||
|                 # aim x2 rw | ||||
|                 if self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|                     last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|                     current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open) | ||||
|                     return float((np.log(last_trade_price) - np.log(current_price)) * 2) | ||||
|                 # less than aim x1 rw | ||||
|                 elif self.close_trade_profit[-1] < self.profit_aim * self.rr: | ||||
|                     last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|                     current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open) | ||||
|                     return float(np.log(last_trade_price) - np.log(current_price)) | ||||
|                 # # less than RR SL x2 neg rw | ||||
|                 # elif self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|                 #     last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|                 #     current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open) | ||||
|                 #     return float((np.log(last_trade_price) - np.log(current_price)) * 2) * -1 | ||||
|         return 0. | ||||
|  | ||||
|  | ||||
|     def _update_profit(self, action): | ||||
|         # if self._is_trade(action) or self._done: | ||||
|         if self._is_trade(action) or self._done: | ||||
| @@ -255,7 +301,7 @@ class Base5ActionRLEnv(gym.Env): | ||||
|                 self._profits.append((self._current_tick, self._total_profit)) | ||||
|                 self.close_trade_profit.append(pnl) | ||||
|  | ||||
|     def most_recent_return(self, action): | ||||
|     def most_recent_return(self, action: int): | ||||
|         """ | ||||
|         We support Long, Neutral and Short positions. | ||||
|         Return is generated from rising prices in Long | ||||
| @@ -265,7 +311,6 @@ class Base5ActionRLEnv(gym.Env): | ||||
|         # Long positions | ||||
|         if self._position == Positions.Long: | ||||
|             current_price = self.prices.iloc[self._current_tick].open | ||||
|             # if action == Actions.Short.value or action == Actions.Neutral.value: | ||||
|             if action == Actions.Short_buy.value or action == Actions.Neutral.value: | ||||
|                 current_price = self.add_sell_fee(current_price) | ||||
|  | ||||
| @@ -280,7 +325,6 @@ class Base5ActionRLEnv(gym.Env): | ||||
|         # Short positions | ||||
|         if self._position == Positions.Short: | ||||
|             current_price = self.prices.iloc[self._current_tick].open | ||||
|             # if action == Actions.Long.value or action == Actions.Neutral.value: | ||||
|             if action == Actions.Long_buy.value or action == Actions.Neutral.value: | ||||
|                 current_price = self.add_buy_fee(current_price) | ||||
|  | ||||
| @@ -296,9 +340,6 @@ class Base5ActionRLEnv(gym.Env): | ||||
|     def get_portfolio_log_returns(self): | ||||
|         return self.portfolio_log_returns[1:self._current_tick + 1] | ||||
|  | ||||
|     def get_trading_log_return(self): | ||||
|         return self.portfolio_log_returns[self._start_tick:] | ||||
|  | ||||
|     def update_portfolio_log_returns(self, action): | ||||
|         self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action) | ||||
|  | ||||
| @@ -314,37 +355,3 @@ class Base5ActionRLEnv(gym.Env): | ||||
|         returns = np.array(self.close_trade_profit) | ||||
|         reward = (np.mean(returns) - 0. + 1e-9) / (np.std(returns) + 1e-9) | ||||
|         return reward | ||||
|  | ||||
|     def get_bnh_log_return(self): | ||||
|         return np.diff(np.log(self.prices['open'][self._start_tick:])) | ||||
|  | ||||
|     def calculate_reward(self, action): | ||||
|  | ||||
|         if self._last_trade_tick is None: | ||||
|             return 0. | ||||
|  | ||||
|         # close long | ||||
|         if action == Actions.Long_sell.value and self._position == Positions.Long: | ||||
|             last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|             current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open) | ||||
|             return float(np.log(current_price) - np.log(last_trade_price)) | ||||
|  | ||||
|         if action == Actions.Long_sell.value and self._position == Positions.Long: | ||||
|             if self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|                 last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|                 current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open) | ||||
|                 return float((np.log(current_price) - np.log(last_trade_price)) * 2) | ||||
|  | ||||
|         # close short | ||||
|         if action == Actions.Short_buy.value and self._position == Positions.Short: | ||||
|             last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|             current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open) | ||||
|             return float(np.log(last_trade_price) - np.log(current_price)) | ||||
|  | ||||
|         if action == Actions.Short_buy.value and self._position == Positions.Short: | ||||
|             if self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|                 last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open) | ||||
|                 current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open) | ||||
|                 return float((np.log(last_trade_price) - np.log(current_price)) * 2) | ||||
|  | ||||
|         return 0. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user