Merge 7cbc0ce80a
into 0afd5a7385
This commit is contained in:
commit
5769041c21
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
@ -94,9 +94,12 @@ class Base3ActionRLEnv(BaseEnvironment):
|
||||
|
||||
observation = self._get_observation()
|
||||
|
||||
#user can play with time if they want
|
||||
truncated = False
|
||||
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
return observation, step_reward, self._done,truncated, info
|
||||
|
||||
def is_tradesignal(self, action: int) -> bool:
|
||||
"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
@ -96,9 +96,12 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||
|
||||
observation = self._get_observation()
|
||||
|
||||
#user can play with time if they want
|
||||
truncated = False
|
||||
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
return observation, step_reward, self._done,truncated, info
|
||||
|
||||
def is_tradesignal(self, action: int) -> bool:
|
||||
"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
@ -101,10 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment):
|
||||
)
|
||||
|
||||
observation = self._get_observation()
|
||||
#user can play with time if they want
|
||||
truncated = False
|
||||
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
return observation, step_reward, self._done,truncated, info
|
||||
|
||||
def is_tradesignal(self, action: int) -> bool:
|
||||
"""
|
||||
|
@ -4,11 +4,11 @@ from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
from gymnasium import spaces
|
||||
from gymnasium.utils import seeding
|
||||
from pandas import DataFrame
|
||||
|
||||
|
||||
@ -203,7 +203,7 @@ class BaseEnvironment(gym.Env):
|
||||
self.close_trade_profit = []
|
||||
self._total_unrealized_profit = 1
|
||||
|
||||
return self._get_observation()
|
||||
return self._get_observation(), self.history
|
||||
|
||||
@abstractmethod
|
||||
def step(self, action: int):
|
||||
|
@ -6,7 +6,7 @@ from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pandas as pd
|
||||
@ -433,7 +433,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
seed: int, train_df: DataFrame, price: DataFrame,
|
||||
monitor: bool = False,
|
||||
env_info: Dict[str, Any] = {}) -> Callable:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
@ -450,8 +449,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
|
||||
env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
|
||||
**env_info)
|
||||
if monitor:
|
||||
env = Monitor(env)
|
||||
|
||||
return env
|
||||
set_random_seed(seed)
|
||||
return _init
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
from stable_baselines3.common.vec_env import VecMonitor
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||
@ -41,22 +41,27 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
||||
|
||||
env_info = self.pack_env_dict(dk.pair)
|
||||
|
||||
eval_freq = len(train_df) // self.max_threads
|
||||
|
||||
env_id = "train_env"
|
||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
|
||||
self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
|
||||
train_df, prices_train,
|
||||
monitor=True,
|
||||
|
||||
env_info=env_info) for i
|
||||
in range(self.max_threads)])
|
||||
in range(self.max_threads)]))
|
||||
|
||||
eval_env_id = 'eval_env'
|
||||
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||
self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||
test_df, prices_test,
|
||||
monitor=True,
|
||||
|
||||
env_info=env_info) for i
|
||||
in range(self.max_threads)])
|
||||
in range(self.max_threads)]))
|
||||
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
render=False, eval_freq=eval_freq,
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
|
||||
# TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS, IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!
|
||||
actions = self.train_env.env_method("get_actions")[0]
|
||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||
|
@ -3,8 +3,10 @@
|
||||
|
||||
# Required for freqai-rl
|
||||
torch==1.13.1; python_version < '3.11'
|
||||
stable-baselines3==1.7.0; python_version < '3.11'
|
||||
sb3-contrib==1.7.0; python_version < '3.11'
|
||||
#until these branches will be released we can use this
|
||||
git+https://github.com/Farama-Foundation/Gymnasium@main
|
||||
git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support
|
||||
git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support
|
||||
# Gym is forced to this version by stable-baselines3.
|
||||
setuptools==65.5.1 # Should be removed when gym is fixed.
|
||||
gym==0.21; python_version < '3.11'
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user