This commit is contained in:
richardjozsa
2022-11-25 01:00:53 +01:00
parent cd5016d0c7
commit 3939106b5a
4 changed files with 81 additions and 8 deletions

View File

@@ -1,6 +1,7 @@
import collections
import logging
import re
import json
import shutil
import threading
from datetime import datetime, timezone
@@ -94,14 +95,16 @@ class FreqaiDataDrawer:
self.save_lock = threading.Lock()
self.pair_dict_lock = threading.Lock()
self.metric_tracker_lock = threading.Lock()
self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False)
self.old_DBSCAN_eps: Dict[str, float] = {}
self.empty_pair_dict: pair_info = {
"model_filename": "", "trained_timestamp": 0,
"data_path": "", "extras": {}}
if 'Reinforcement' in self.config['freqaimodel']:
self.model_type = 'stable_baselines'
logger.warning('User passed a ReinforcementLearner model, FreqAI will '
'now use stable_baselines3 to save models.')
if 'rl_config' in self.freqai_info:
self.model_type = self.freqai_info['model_save_type']
logger.warning(f'User passed a ReinforcementLearner model, FreqAI will '
'now use {self.model_type} to save models.')
else:
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
@@ -488,6 +491,8 @@ class FreqaiDataDrawer:
model.save(save_path / f"{dk.model_filename}_model.h5")
elif 'stable_baselines' in self.model_type:
model.save(save_path / f"{dk.model_filename}_model.zip")
elif 'sb3_contrib' in self.model_type:
model.save(save_path / f"{dk.model_filename}_model.zip")
if dk.svm_model is not None:
dump(dk.svm_model, save_path / f"{dk.model_filename}_svm_model.joblib")
@@ -565,7 +570,7 @@ class FreqaiDataDrawer:
dk.label_list = dk.data["label_list"]
# try to access model in memory instead of loading object from disk to save time
if dk.live and coin in self.model_dictionary:
if dk.live and coin in self.model_dictionary and not self.limit_ram_use:
model = self.model_dictionary[coin]
elif self.model_type == 'joblib':
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
@@ -576,10 +581,16 @@ class FreqaiDataDrawer:
mod = __import__('stable_baselines3', fromlist=[
self.freqai_info['rl_config']['model_type']])
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model", device="cpu")
elif self.model_type == 'sb3_contrib':
mod = __import__('sb3_contrib', fromlist=[
self.freqai_info['rl_config']['model_type']])
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model", device="cpu")
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")
dk.svm_model = load(path=dk.data_path / f"{dk.model_filename}_svm_model.joblib")
if not model:
raise OperationalException(