commit 5769041c21
Richard Jozsa, 2023-04-12 09:19:24 -07:00 (committed by GitHub)
7 changed files with 38 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
 import logging
 from enum import Enum

-from gym import spaces
+from gymnasium import spaces

 from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions

@@ -94,9 +94,12 @@ class Base3ActionRLEnv(BaseEnvironment):
         observation = self._get_observation()
+        # user can play with time if they want
+        truncated = False
         self._update_history(info)
-        return observation, step_reward, self._done, info
+        return observation, step_reward, self._done, truncated, info

     def is_tradesignal(self, action: int) -> bool:
         """

View File

@@ -1,7 +1,7 @@
 import logging
 from enum import Enum

-from gym import spaces
+from gymnasium import spaces

 from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions

@@ -96,9 +96,12 @@ class Base4ActionRLEnv(BaseEnvironment):
         observation = self._get_observation()
+        # user can play with time if they want
+        truncated = False
         self._update_history(info)
-        return observation, step_reward, self._done, info
+        return observation, step_reward, self._done, truncated, info

     def is_tradesignal(self, action: int) -> bool:
         """

View File

@@ -1,7 +1,7 @@
 import logging
 from enum import Enum

-from gym import spaces
+from gymnasium import spaces

 from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions

@@ -101,10 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment):
         )
         observation = self._get_observation()
+        # user can play with time if they want
+        truncated = False
         self._update_history(info)
-        return observation, step_reward, self._done, info
+        return observation, step_reward, self._done, truncated, info

     def is_tradesignal(self, action: int) -> bool:
         """

View File

@@ -4,11 +4,11 @@ from abc import abstractmethod
 from enum import Enum
 from typing import Optional, Type, Union

-import gym
+import gymnasium as gym
 import numpy as np
 import pandas as pd
-from gym import spaces
-from gym.utils import seeding
+from gymnasium import spaces
+from gymnasium.utils import seeding
 from pandas import DataFrame

@@ -203,7 +203,7 @@ class BaseEnvironment(gym.Env):
         self.close_trade_profit = []
         self._total_unrealized_profit = 1

-        return self._get_observation()
+        return self._get_observation(), self.history

     @abstractmethod
     def step(self, action: int):
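Returning self.history alongside the observation matches Gymnasium's reset() signature, which yields an (observation, info) pair rather than a bare observation. A small self-contained sketch of that contract (SketchEnv is again a placeholder, not a freqtrade class):

# Sketch of the gymnasium reset() contract: (observation, info) instead of a
# bare observation. SketchEnv is a placeholder, not part of this commit.
import gymnasium as gym
from gymnasium import spaces

class SketchEnv(gym.Env):
    observation_space = spaces.Discrete(2)
    action_space = spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        history = {}          # BaseEnvironment returns its history dict in this slot
        return 0, history

observation, info = SketchEnv().reset()  # previously: observation = env.reset()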

View File

@@ -6,7 +6,7 @@ from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

-import gym
+import gymnasium as gym
 import numpy as np
 import numpy.typing as npt
 import pandas as pd

@@ -433,7 +433,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
-             monitor: bool = False,
              env_info: Dict[str, Any] = {}) -> Callable:
     """
     Utility function for multiprocessed env.

@@ -450,8 +449,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
         env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
                       **env_info)
-        if monitor:
-            env = Monitor(env)
         return env
     set_random_seed(seed)
     return _init
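With the per-env monitor flag removed, make_env is a plain SB3-style factory: it returns a zero-argument closure that each SubprocVecEnv worker calls to build its own environment, and monitoring is applied later to the vectorized env (next file). A stripped-down sketch of that pattern, assuming a hypothetical DummyEnv rather than the freqtrade classes:

# Sketch of the env-factory pattern used with SubprocVecEnv; DummyEnv is a
# placeholder, not a freqtrade class. No per-env Monitor wrapping happens here.
from typing import Callable

import gymnasium as gym
from stable_baselines3.common.utils import set_random_seed

class DummyEnv(gym.Env):
    observation_space = gym.spaces.Discrete(2)
    action_space = gym.spaces.Discrete(2)

    def __init__(self, seed: int = 0):
        self._seed = seed

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return 0, {}

    def step(self, action):
        return 0, 0.0, True, False, {}

def make_env(rank: int, seed: int = 1) -> Callable:
    def _init() -> gym.Env:
        return DummyEnv(seed=seed + rank)   # each worker gets its own seed
    set_random_seed(seed)
    return _init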

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv
+from stable_baselines3.common.vec_env import VecMonitor

 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env

@@ -41,22 +41,27 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
         env_info = self.pack_env_dict(dk.pair)

+        eval_freq = len(train_df) // self.max_threads

         env_id = "train_env"
-        self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
-                                                 train_df, prices_train,
-                                                 monitor=True,
-                                                 env_info=env_info) for i
-                                        in range(self.max_threads)])
+        self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
+                                                             train_df, prices_train,
+                                                             env_info=env_info) for i
+                                                   in range(self.max_threads)]))

         eval_env_id = 'eval_env'
-        self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
-                                                test_df, prices_test,
-                                                monitor=True,
-                                                env_info=env_info) for i
-                                       in range(self.max_threads)])
+        self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
+                                                            test_df, prices_test,
+                                                            env_info=env_info) for i
+                                                  in range(self.max_threads)]))

         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=len(train_df),
+                                          render=False, eval_freq=eval_freq,
                                           best_model_save_path=str(dk.data_path))

+        # NOTE: the TensorBoard callback is not recommended with multiple envs;
+        # it reports incorrect information and is not thread safe with SB3.
         actions = self.train_env.env_method("get_actions")[0]
         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
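Wrapping the vectorized env in VecMonitor replaces the per-worker Monitor wrapper, and eval_freq is divided by the number of workers because EvalCallback counts steps per environment. A compact sketch of the same pattern with a toy environment (ToyEnv and the numbers are placeholders, not freqtrade code):

# Sketch: VecMonitor around SubprocVecEnv plus a per-env eval_freq.
# ToyEnv and the constants are placeholders, not part of this commit.
import gymnasium as gym
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor

class ToyEnv(gym.Env):
    observation_space = gym.spaces.Discrete(2)
    action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return 0, {}

    def step(self, action):
        return 0, 0.0, True, False, {}

if __name__ == "__main__":
    max_threads = 4
    train_len = 10_000                      # stand-in for len(train_df)
    eval_freq = train_len // max_threads    # EvalCallback counts steps per env

    train_env = VecMonitor(SubprocVecEnv([ToyEnv for _ in range(max_threads)]))
    eval_env = VecMonitor(SubprocVecEnv([ToyEnv for _ in range(max_threads)]))
    eval_callback = EvalCallback(eval_env, deterministic=True,
                                 render=False, eval_freq=eval_freq)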

View File

@@ -3,8 +3,10 @@
 # Required for freqai-rl
 torch==1.13.1; python_version < '3.11'
-stable-baselines3==1.7.0; python_version < '3.11'
-sb3-contrib==1.7.0; python_version < '3.11'
+# until these branches are released we can use them directly
+git+https://github.com/Farama-Foundation/Gymnasium@main
+git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support
+git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support
 # Gym is forced to this version by stable-baselines3.
 setuptools==65.5.1  # Should be removed when gym is fixed.
-gym==0.21; python_version < '3.11'
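With the pinned releases replaced by the gymnasium-support branches and the gym 0.21 pin dropped, a quick way to confirm the installed stack is the new one is to import the packages and print their versions (a sketch; the exact version strings depend on the branch state at install time):

# Sanity-check sketch: confirm gymnasium is installed and the SB3 packages import.
# Version strings depend on the state of the branches when they were installed.
import gymnasium
import stable_baselines3
import sb3_contrib

print("gymnasium:", gymnasium.__version__)
print("stable-baselines3:", stable_baselines3.__version__)
print("sb3-contrib:", sb3_contrib.__version__)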