add tests. add guardrails.
This commit is contained in:
@@ -21,7 +21,7 @@ from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
from freqtrade.persistence import Trade
|
||||
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,7 +45,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
self.eval_callback: EvalCallback = None
|
||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||
self.rl_config = self.freqai_info['rl_config']
|
||||
self.continual_learning = self.rl_config.get('continual_learning', False)
|
||||
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
||||
if self.model_type in SB3_MODELS:
|
||||
import_str = 'stable_baselines3'
|
||||
elif self.model_type in SB3_CONTRIB_MODELS:
|
||||
@@ -59,14 +59,30 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
self.model_type])
|
||||
self.MODELCLASS = getattr(mod, self.model_type)
|
||||
self.policy_type = self.freqai_info['rl_config']['policy_type']
|
||||
self.unset_outlier_removal()
|
||||
|
||||
def unset_outlier_removal(self):
|
||||
"""
|
||||
If user has activated any function that may remove training points, this
|
||||
function will set them to false and warn them
|
||||
"""
|
||||
if self.ft_params.get('use_SVM_to_remove_outliers', False):
|
||||
self.ft_params.update({'use_SVM_to_remove_outliers': False})
|
||||
logger.warning('User tried to use SVM with RL. Deactivating SVM.')
|
||||
if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
|
||||
self.ft_params.update({'use_SVM_to_remove_outliers': False})
|
||||
logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
|
||||
if self.freqai_info['data_split_parameters'].get('shuffle', False):
|
||||
self.freqai_info['data_split_parameters'].update('shuffle', False)
|
||||
logger.warning('User tried to shuffle training data. Setting shuffle to False')
|
||||
|
||||
def train(
|
||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
||||
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
||||
for storing, saving, loading, and analyzing the data.
|
||||
:param unfiltered_dataframe: Full dataframe for the current training period
|
||||
:param unfiltered_df: Full dataframe for the current training period
|
||||
:param metadata: pair metadata from strategy.
|
||||
:returns:
|
||||
:model: Trained model which can be used to inference (self.predict)
|
||||
@@ -75,7 +91,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
logger.info("--------------------Starting training " f"{pair} --------------------")
|
||||
|
||||
features_filtered, labels_filtered = dk.filter_features(
|
||||
unfiltered_dataframe,
|
||||
unfiltered_df,
|
||||
dk.training_features_list,
|
||||
dk.label_list,
|
||||
training_filter=True,
|
||||
@@ -99,7 +115,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
|
||||
|
||||
model = self.fit_rl(data_dictionary, dk)
|
||||
model = self.fit(data_dictionary, dk)
|
||||
|
||||
logger.info(f"--------------------done training {pair}--------------------")
|
||||
|
||||
@@ -124,7 +140,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
@abstractmethod
|
||||
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||
"""
|
||||
Agent customizations and abstract Reinforcement Learning customizations
|
||||
go in here. Abstract method, so this function must be overridden by
|
||||
@@ -142,6 +158,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
# FIXME: mypy typing doesnt like that strategy may be "None" (it never will be)
|
||||
# FIXME: get_rate and trade_udration shouldn't work with backtesting,
|
||||
# we need to use candle dates and prices to compute that.
|
||||
pytest.set_trace()
|
||||
current_value = self.strategy.dp._exchange.get_rate(
|
||||
pair, refresh=False, side="exit", is_short=trade.is_short)
|
||||
openrate = trade.open_rate
|
||||
@@ -162,7 +179,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
return market_side, current_profit, int(trade_duration)
|
||||
|
||||
def predict(
|
||||
self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False
|
||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||
"""
|
||||
Filter the prediction features data and predict with it.
|
||||
@@ -173,9 +190,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||
"""
|
||||
|
||||
dk.find_features(unfiltered_dataframe)
|
||||
dk.find_features(unfiltered_df)
|
||||
filtered_dataframe, _ = dk.filter_features(
|
||||
unfiltered_dataframe, dk.training_features_list, training_filter=False
|
||||
unfiltered_df, dk.training_features_list, training_filter=False
|
||||
)
|
||||
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
||||
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
||||
@@ -305,8 +322,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
|
||||
# all the other existing fit() functions to include dk argument. For now we instantiate and
|
||||
# leave it.
|
||||
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
||||
return
|
||||
# def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
||||
# return
|
||||
|
||||
|
||||
def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
|
||||
|
Reference in New Issue
Block a user