ghf
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import collections
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
import shutil
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
@@ -94,14 +95,16 @@ class FreqaiDataDrawer:
|
||||
self.save_lock = threading.Lock()
|
||||
self.pair_dict_lock = threading.Lock()
|
||||
self.metric_tracker_lock = threading.Lock()
|
||||
self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False)
|
||||
self.old_DBSCAN_eps: Dict[str, float] = {}
|
||||
|
||||
self.empty_pair_dict: pair_info = {
|
||||
"model_filename": "", "trained_timestamp": 0,
|
||||
"data_path": "", "extras": {}}
|
||||
if 'Reinforcement' in self.config['freqaimodel']:
|
||||
self.model_type = 'stable_baselines'
|
||||
logger.warning('User passed a ReinforcementLearner model, FreqAI will '
|
||||
'now use stable_baselines3 to save models.')
|
||||
if 'rl_config' in self.freqai_info:
|
||||
self.model_type = self.freqai_info['model_save_type']
|
||||
logger.warning(f'User passed a ReinforcementLearner model, FreqAI will '
|
||||
'now use {self.model_type} to save models.')
|
||||
else:
|
||||
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
|
||||
@@ -488,6 +491,8 @@ class FreqaiDataDrawer:
|
||||
model.save(save_path / f"{dk.model_filename}_model.h5")
|
||||
elif 'stable_baselines' in self.model_type:
|
||||
model.save(save_path / f"{dk.model_filename}_model.zip")
|
||||
elif 'sb3_contrib' in self.model_type:
|
||||
model.save(save_path / f"{dk.model_filename}_model.zip")
|
||||
|
||||
if dk.svm_model is not None:
|
||||
dump(dk.svm_model, save_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
@@ -565,7 +570,7 @@ class FreqaiDataDrawer:
|
||||
dk.label_list = dk.data["label_list"]
|
||||
|
||||
# try to access model in memory instead of loading object from disk to save time
|
||||
if dk.live and coin in self.model_dictionary:
|
||||
if dk.live and coin in self.model_dictionary and not self.limit_ram_use:
|
||||
model = self.model_dictionary[coin]
|
||||
elif self.model_type == 'joblib':
|
||||
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
||||
@@ -576,10 +581,16 @@ class FreqaiDataDrawer:
|
||||
mod = __import__('stable_baselines3', fromlist=[
|
||||
self.freqai_info['rl_config']['model_type']])
|
||||
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
|
||||
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
|
||||
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model", device="cpu")
|
||||
|
||||
elif self.model_type == 'sb3_contrib':
|
||||
mod = __import__('sb3_contrib', fromlist=[
|
||||
self.freqai_info['rl_config']['model_type']])
|
||||
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
|
||||
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model", device="cpu")
|
||||
|
||||
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
|
||||
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
dk.svm_model = load(path=dk.data_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
|
||||
if not model:
|
||||
raise OperationalException(
|
||||
|
Reference in New Issue
Block a user