Add ground work for TensorFlow models, add protections from common mistakes
This commit is contained in:
@@ -24,8 +24,9 @@ class FreqaiModelResolver(IResolver):
|
||||
object_type = IFreqaiModel
|
||||
object_type_str = "FreqaiModel"
|
||||
user_subdir = USERPATH_FREQAIMODELS
|
||||
initial_search_path = Path(__file__).parent.parent.joinpath(
|
||||
"freqai/prediction_models").resolve()
|
||||
initial_search_path = (
|
||||
Path(__file__).parent.parent.joinpath("freqai/prediction_models").resolve()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_freqaimodel(config: Dict) -> IFreqaiModel:
|
||||
@@ -33,6 +34,7 @@ class FreqaiModelResolver(IResolver):
|
||||
Load the custom class from config parameter
|
||||
:param config: configuration dictionary
|
||||
"""
|
||||
disallowed_models = ["BaseRegressionModel", "BaseTensorFlowModel"]
|
||||
|
||||
freqaimodel_name = config.get("freqaimodel")
|
||||
if not freqaimodel_name:
|
||||
@@ -40,6 +42,11 @@ class FreqaiModelResolver(IResolver):
|
||||
"No freqaimodel set. Please use `--freqaimodel` to "
|
||||
"specify the FreqaiModel class to use.\n"
|
||||
)
|
||||
if freqaimodel_name in disallowed_models:
|
||||
raise OperationalException(
|
||||
f"{freqaimodel_name} is a baseclass and cannot be used directly. User must choose "
|
||||
"an existing child class or inherit from this baseclass.\n"
|
||||
)
|
||||
freqaimodel = FreqaiModelResolver.load_object(
|
||||
freqaimodel_name,
|
||||
config,
|
||||
|
Reference in New Issue
Block a user