fix model loading from disk bug, improve doc, clarify installation/docker instructions, add a torch tag to the freqairl docker image. Fix seriously outdated prediction_model docstrings

This commit is contained in:
robcaulk 2023-04-08 12:09:53 +02:00
parent a655524221
commit 48d3c8e62e
21 changed files with 195 additions and 83 deletions

View File

@ -7,6 +7,7 @@ TAG=$(echo "${BRANCH_NAME}" | sed -e "s/\//_/g")
TAG_PLOT=${TAG}_plot TAG_PLOT=${TAG}_plot
TAG_FREQAI=${TAG}_freqai TAG_FREQAI=${TAG}_freqai
TAG_FREQAI_RL=${TAG_FREQAI}rl TAG_FREQAI_RL=${TAG_FREQAI}rl
TAG_FREQAI_RL=${TAG_FREQAI}torch
TAG_PI="${TAG}_pi" TAG_PI="${TAG}_pi"
PI_PLATFORM="linux/arm/v7" PI_PLATFORM="linux/arm/v7"
@ -64,6 +65,7 @@ docker build --cache-from freqtrade:${TAG_FREQAI} --build-arg sourceimage=${CACH
docker tag freqtrade:$TAG_PLOT ${CACHE_IMAGE}:$TAG_PLOT docker tag freqtrade:$TAG_PLOT ${CACHE_IMAGE}:$TAG_PLOT
docker tag freqtrade:$TAG_FREQAI ${CACHE_IMAGE}:$TAG_FREQAI docker tag freqtrade:$TAG_FREQAI ${CACHE_IMAGE}:$TAG_FREQAI
docker tag freqtrade:$TAG_FREQAI_RL ${CACHE_IMAGE}:$TAG_FREQAI_RL docker tag freqtrade:$TAG_FREQAI_RL ${CACHE_IMAGE}:$TAG_FREQAI_RL
docker tag freqtrade:$TAG_FREQAI_RL ${CACHE_IMAGE}:$TAG_FREQAI_TORCH
# Run backtest # Run backtest
docker run --rm -v $(pwd)/config_examples/config_bittrex.example.json:/freqtrade/config.json:ro -v $(pwd)/tests:/tests freqtrade:${TAG} backtesting --datadir /tests/testdata --strategy-path /tests/strategy/strats/ --strategy StrategyTestV3 docker run --rm -v $(pwd)/config_examples/config_bittrex.example.json:/freqtrade/config.json:ro -v $(pwd)/tests:/tests freqtrade:${TAG} backtesting --datadir /tests/testdata --strategy-path /tests/strategy/strats/ --strategy StrategyTestV3

View File

@ -237,7 +237,7 @@ df['&s-up_or_down'] = np.where( df["close"].shift(-100) > df["close"], 'up', 'do
df['&s-up_or_down'] = np.where( df["close"].shift(-100) == df["close"], 'same', df['&s-up_or_down']) df['&s-up_or_down'] = np.where( df["close"].shift(-100) == df["close"], 'same', df['&s-up_or_down'])
``` ```
## PyTorch Models ## PyTorch Module
### Quick start ### Quick start
@ -247,14 +247,16 @@ The easiest way to quickly run a pytorch model is with the following command (fo
freqtrade trade --config config_examples/config_freqai.example.json --strategy FreqaiExampleStrategy --freqaimodel PyTorchMLPRegressor --strategy-path freqtrade/templates freqtrade trade --config config_examples/config_freqai.example.json --strategy FreqaiExampleStrategy --freqaimodel PyTorchMLPRegressor --strategy-path freqtrade/templates
``` ```
!!! note "Installation/docker"
The PyTorch module requires large packages such as `torch`, which should be explicitly requested during `./setup.sh -i` by answering "y" to the question "Do you also want dependencies for freqai-rl or PyTorch (~700mb additional space required) [y/N]?".
Users who prefer docker should ensure they use the docker image appended with `_freqaitorch`.
### Structure ### Structure
#### Model #### Model
You can use any pytorch model. Here is an example of logistic regression model implementation using pytorch (should be used with nn.BCELoss criterion) for classification tasks. You can construct your own Neural Network architecture in PyTorch by simply defining your `nn.Module` class inside your custom [`IFreqaiModel` file](#using-different-prediction-models) and then using that class in your `def train()` function. Here is an example of logistic regression model implementation using PyTorch (should be used with nn.BCELoss criterion) for classification tasks.
```python ```python
import torch.nn as nn
import torch
class LogisticRegression(nn.Module): class LogisticRegression(nn.Module):
def __init__(self, input_size: int): def __init__(self, input_size: int):
@ -268,11 +270,59 @@ class LogisticRegression(nn.Module):
out = self.linear(x) out = self.linear(x)
out = self.activation(out) out = self.activation(out)
return out return out
class MyCoolPyTorchClassifier(BasePyTorchClassifier):
"""
This is a custom IFreqaiModel showing how a user might setup their own
custom Neural Network architecture for their training.
"""
@property
def data_convertor(self) -> PyTorchDataConvertor:
return DefaultPyTorchDataConvertor(target_tensor_type=torch.float)
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
config = self.freqai_info.get("model_training_parameters", {})
self.learning_rate: float = config.get("learning_rate", 3e-4)
self.model_kwargs: Dict[str, Any] = config.get("model_kwargs", {})
self.trainer_kwargs: Dict[str, Any] = config.get("trainer_kwargs", {})
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
"""
User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary holding all data for train, test,
labels, weights
:param dk: The datakitchen object for the current coin/model
"""
class_names = self.get_class_names()
self.convert_label_column_to_int(data_dictionary, dk, class_names)
n_features = data_dictionary["train_features"].shape[-1]
model = LogisticRegression(
input_dim=n_features
)
model.to(self.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
criterion = torch.nn.CrossEntropyLoss()
init_model = self.get_init_model(dk.pair)
trainer = PyTorchModelTrainer(
model=model,
optimizer=optimizer,
criterion=criterion,
model_meta_data={"class_names": class_names},
device=self.device,
init_model=init_model,
data_convertor=self.data_convertor,
**self.trainer_kwargs,
)
trainer.fit(data_dictionary, self.splits)
return trainer
``` ```
#### Trainer #### Trainer
The `PyTorchModelTrainer` performs the idiomatic pytorch train loop: The `PyTorchModelTrainer` performs the idiomatic PyTorch train loop:
Define our model, loss function, and optimizer, and then move them to the appropriate device (GPU or CPU). Inside the loop, we iterate through the batches in the dataloader, move the data to the device, compute the prediction and loss, backpropagate, and update the model parameters using the optimizer. Define our model, loss function, and optimizer, and then move them to the appropriate device (GPU or CPU). Inside the loop, we iterate through the batches in the dataloader, move the data to the device, compute the prediction and loss, backpropagate, and update the model parameters using the optimizer.
In addition, the trainer is responsible for the following: In addition, the trainer is responsible for the following:

View File

@ -539,7 +539,9 @@ class FreqaiDataDrawer:
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model") model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
elif self.model_type == 'pytorch': elif self.model_type == 'pytorch':
import torch import torch
model = torch.load(dk.data_path / f"{dk.model_filename}_model.zip") zip = torch.load(dk.data_path / f"{dk.model_filename}_model.zip")
model = zip["pytrainer"]
model = model.load_from_checkpoint(zip)
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file(): if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib") dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")

View File

@ -14,16 +14,20 @@ logger = logging.getLogger(__name__)
class CatboostClassifier(BaseClassifierModel): class CatboostClassifier(BaseClassifierModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
train_data = Pool( train_data = Pool(

View File

@ -15,16 +15,20 @@ logger = logging.getLogger(__name__)
class CatboostClassifierMultiTarget(BaseClassifierModel): class CatboostClassifierMultiTarget(BaseClassifierModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
cbc = CatBoostClassifier( cbc = CatBoostClassifier(

View File

@ -14,16 +14,20 @@ logger = logging.getLogger(__name__)
class CatboostRegressor(BaseRegressionModel): class CatboostRegressor(BaseRegressionModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
train_data = Pool( train_data = Pool(

View File

@ -15,16 +15,20 @@ logger = logging.getLogger(__name__)
class CatboostRegressorMultiTarget(BaseRegressionModel): class CatboostRegressorMultiTarget(BaseRegressionModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
cbr = CatBoostRegressor( cbr = CatBoostRegressor(

View File

@ -12,16 +12,20 @@ logger = logging.getLogger(__name__)
class LightGBMClassifier(BaseClassifierModel): class LightGBMClassifier(BaseClassifierModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0: if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0:

View File

@ -13,16 +13,20 @@ logger = logging.getLogger(__name__)
class LightGBMClassifierMultiTarget(BaseClassifierModel): class LightGBMClassifierMultiTarget(BaseClassifierModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
lgb = LGBMClassifier(**self.model_training_parameters) lgb = LGBMClassifier(**self.model_training_parameters)

View File

@ -12,18 +12,20 @@ logger = logging.getLogger(__name__)
class LightGBMRegressor(BaseRegressionModel): class LightGBMRegressor(BaseRegressionModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
Most regressors use the same function names and arguments e.g. user User sets up the training and test data to fit their desired model here
can drop in LGBMRegressor in place of CatBoostRegressor and all data :param data_dictionary: the dictionary holding all data for train, test,
management will be properly handled by Freqai. labels, weights
:param data_dictionary: the dictionary constructed by DataHandler to hold :param dk: The datakitchen object for the current coin/model
all the training and test data/labels.
""" """
if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0: if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0:

View File

@ -13,16 +13,20 @@ logger = logging.getLogger(__name__)
class LightGBMRegressorMultiTarget(BaseRegressionModel): class LightGBMRegressorMultiTarget(BaseRegressionModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
lgb = LGBMRegressor(**self.model_training_parameters) lgb = LGBMRegressor(**self.model_training_parameters)

View File

@ -57,8 +57,9 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
:raises ValueError: If self.class_names is not defined in the parent class. :raises ValueError: If self.class_names is not defined in the parent class.
""" """

View File

@ -55,8 +55,9 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
n_features = data_dictionary["train_features"].shape[-1] n_features = data_dictionary["train_features"].shape[-1]

View File

@ -18,16 +18,20 @@ logger = logging.getLogger(__name__)
class XGBoostClassifier(BaseClassifierModel): class XGBoostClassifier(BaseClassifierModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
X = data_dictionary["train_features"].to_numpy() X = data_dictionary["train_features"].to_numpy()

View File

@ -18,16 +18,20 @@ logger = logging.getLogger(__name__)
class XGBoostRFClassifier(BaseClassifierModel): class XGBoostRFClassifier(BaseClassifierModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
X = data_dictionary["train_features"].to_numpy() X = data_dictionary["train_features"].to_numpy()

View File

@ -12,16 +12,20 @@ logger = logging.getLogger(__name__)
class XGBoostRFRegressor(BaseRegressionModel): class XGBoostRFRegressor(BaseRegressionModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
X = data_dictionary["train_features"] X = data_dictionary["train_features"]

View File

@ -12,16 +12,20 @@ logger = logging.getLogger(__name__)
class XGBoostRegressor(BaseRegressionModel): class XGBoostRegressor(BaseRegressionModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
X = data_dictionary["train_features"] X = data_dictionary["train_features"]

View File

@ -13,16 +13,20 @@ logger = logging.getLogger(__name__)
class XGBoostRegressorMultiTarget(BaseRegressionModel): class XGBoostRegressorMultiTarget(BaseRegressionModel):
""" """
User created prediction model. The class needs to override three necessary User created prediction model. The class inherits IFreqaiModel, which
functions, predict(), train(), fit(). The class inherits ModelHandler which means it has full access to all Frequency AI functionality. Typically,
has its own DataHandler where data is held, saved, loaded, and managed. users would use this to override the common `fit()`, `train()`, or
`predict()` methods to add their custom data handling tools or change
various aspects of the training that cannot be configured via the
top level config.json file.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary holding all data for train, test,
all the training and test data/labels. labels, weights
:param dk: The datakitchen object for the current coin/model
""" """
xgb = XGBRegressor(**self.model_training_parameters) xgb = XGBRegressor(**self.model_training_parameters)

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional
import pandas as pd import pandas as pd
import torch import torch
import torch.nn as nn from torch import nn
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader, TensorDataset
@ -169,6 +169,12 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
n_batches = math.ceil(n_obs // batch_size) n_batches = math.ceil(n_obs // batch_size)
epochs = math.ceil(n_iters // n_batches) epochs = math.ceil(n_iters // n_batches)
if epochs <= 10:
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
f" {epochs} epochs. Please consider increasing this value accordingly")
if epochs <= 1:
logger.warning("Epochs set to 1. Please review your `max_iters` value")
epochs = 1
return epochs return epochs
def save(self, path: Path): def save(self, path: Path):
@ -182,6 +188,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"model_state_dict": self.model.state_dict(), "model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(),
"model_meta_data": self.model_meta_data, "model_meta_data": self.model_meta_data,
"pytrainer": self
}, path) }, path)
def load(self, path: Path): def load(self, path: Path):
@ -195,7 +202,6 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
you can access this dict from any class that inherits IFreqaiModel by calling you can access this dict from any class that inherits IFreqaiModel by calling
get_init_model method. get_init_model method.
""" """
self.model.load_state_dict(checkpoint["model_state_dict"]) self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.model_meta_data = checkpoint["model_meta_data"] self.model_meta_data = checkpoint["model_meta_data"]

View File

@ -4,7 +4,7 @@ from typing import Dict, List
import pandas as pd import pandas as pd
import torch import torch
import torch.nn as nn from torch import nn
class PyTorchTrainerInterface(ABC): class PyTorchTrainerInterface(ABC):

View File

@ -85,7 +85,7 @@ function updateenv() {
if [[ $REPLY =~ ^[Yy]$ ]] if [[ $REPLY =~ ^[Yy]$ ]]
then then
REQUIREMENTS_FREQAI="-r requirements-freqai.txt --use-pep517" REQUIREMENTS_FREQAI="-r requirements-freqai.txt --use-pep517"
read -p "Do you also want dependencies for freqai-rl (~700mb additional space required) [y/N]? " read -p "Do you also want dependencies for freqai-rl or PyTorch (~700mb additional space required) [y/N]? "
if [[ $REPLY =~ ^[Yy]$ ]] if [[ $REPLY =~ ^[Yy]$ ]]
then then
REQUIREMENTS_FREQAI="-r requirements-freqai-rl.txt" REQUIREMENTS_FREQAI="-r requirements-freqai-rl.txt"