Add ground work for TensorFlow models, add protections from common mistakes

This commit is contained in:
robcaulk
2022-07-12 18:09:17 +02:00
parent fea63fba12
commit ef409dd345
4 changed files with 44 additions and 21 deletions

View File

@@ -24,8 +24,9 @@ class FreqaiModelResolver(IResolver):
object_type = IFreqaiModel
object_type_str = "FreqaiModel"
user_subdir = USERPATH_FREQAIMODELS
initial_search_path = Path(__file__).parent.parent.joinpath(
"freqai/prediction_models").resolve()
initial_search_path = (
Path(__file__).parent.parent.joinpath("freqai/prediction_models").resolve()
)
@staticmethod
def load_freqaimodel(config: Dict) -> IFreqaiModel:
@@ -33,6 +34,7 @@ class FreqaiModelResolver(IResolver):
Load the custom class from config parameter
:param config: configuration dictionary
"""
disallowed_models = ["BaseRegressionModel", "BaseTensorFlowModel"]
freqaimodel_name = config.get("freqaimodel")
if not freqaimodel_name:
@@ -40,6 +42,11 @@ class FreqaiModelResolver(IResolver):
"No freqaimodel set. Please use `--freqaimodel` to "
"specify the FreqaiModel class to use.\n"
)
if freqaimodel_name in disallowed_models:
raise OperationalException(
f"{freqaimodel_name} is a baseclass and cannot be used directly. User must choose "
"an existing child class or inherit from this baseclass.\n"
)
freqaimodel = FreqaiModelResolver.load_object(
freqaimodel_name,
config,