fix model loading from disk bug, improve doc, clarify installation/docker instructions, add a torch tag to the freqairl docker image. Fix seriously outdated prediction_model docstrings
This commit is contained in:
parent
a655524221
commit
48d3c8e62e
@ -7,6 +7,7 @@ TAG=$(echo "${BRANCH_NAME}" | sed -e "s/\//_/g")
|
||||
TAG_PLOT=${TAG}_plot
|
||||
TAG_FREQAI=${TAG}_freqai
|
||||
TAG_FREQAI_RL=${TAG_FREQAI}rl
|
||||
TAG_FREQAI_RL=${TAG_FREQAI}torch
|
||||
TAG_PI="${TAG}_pi"
|
||||
|
||||
PI_PLATFORM="linux/arm/v7"
|
||||
@ -64,6 +65,7 @@ docker build --cache-from freqtrade:${TAG_FREQAI} --build-arg sourceimage=${CACH
|
||||
docker tag freqtrade:$TAG_PLOT ${CACHE_IMAGE}:$TAG_PLOT
|
||||
docker tag freqtrade:$TAG_FREQAI ${CACHE_IMAGE}:$TAG_FREQAI
|
||||
docker tag freqtrade:$TAG_FREQAI_RL ${CACHE_IMAGE}:$TAG_FREQAI_RL
|
||||
docker tag freqtrade:$TAG_FREQAI_RL ${CACHE_IMAGE}:$TAG_FREQAI_TORCH
|
||||
|
||||
# Run backtest
|
||||
docker run --rm -v $(pwd)/config_examples/config_bittrex.example.json:/freqtrade/config.json:ro -v $(pwd)/tests:/tests freqtrade:${TAG} backtesting --datadir /tests/testdata --strategy-path /tests/strategy/strats/ --strategy StrategyTestV3
|
||||
|
@ -237,7 +237,7 @@ df['&s-up_or_down'] = np.where( df["close"].shift(-100) > df["close"], 'up', 'do
|
||||
df['&s-up_or_down'] = np.where( df["close"].shift(-100) == df["close"], 'same', df['&s-up_or_down'])
|
||||
```
|
||||
|
||||
## PyTorch Models
|
||||
## PyTorch Module
|
||||
|
||||
### Quick start
|
||||
|
||||
@ -247,14 +247,16 @@ The easiest way to quickly run a pytorch model is with the following command (fo
|
||||
freqtrade trade --config config_examples/config_freqai.example.json --strategy FreqaiExampleStrategy --freqaimodel PyTorchMLPRegressor --strategy-path freqtrade/templates
|
||||
```
|
||||
|
||||
!!! note "Installation/docker"
|
||||
The PyTorch module requires large packages such as `torch`, which should be explicitly requested during `./setup.sh -i` by answering "y" to the question "Do you also want dependencies for freqai-rl or PyTorch (~700mb additional space required) [y/N]?".
|
||||
Users who prefer docker should ensure they use the docker image appended with `_freqaitorch`.
|
||||
|
||||
### Structure
|
||||
|
||||
#### Model
|
||||
You can use any pytorch model. Here is an example of logistic regression model implementation using pytorch (should be used with nn.BCELoss criterion) for classification tasks.
|
||||
You can construct your own Neural Network architecture in PyTorch by simply defining your `nn.Module` class inside your custom [`IFreqaiModel` file](#using-different-prediction-models) and then using that class in your `def train()` function. Here is an example of logistic regression model implementation using PyTorch (should be used with nn.BCELoss criterion) for classification tasks.
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
class LogisticRegression(nn.Module):
|
||||
def __init__(self, input_size: int):
|
||||
@ -268,11 +270,59 @@ class LogisticRegression(nn.Module):
|
||||
out = self.linear(x)
|
||||
out = self.activation(out)
|
||||
return out
|
||||
|
||||
class MyCoolPyTorchClassifier(BasePyTorchClassifier):
|
||||
"""
|
||||
This is a custom IFreqaiModel showing how a user might setup their own
|
||||
custom Neural Network architecture for their training.
|
||||
"""
|
||||
|
||||
@property
|
||||
def data_convertor(self) -> PyTorchDataConvertor:
|
||||
return DefaultPyTorchDataConvertor(target_tensor_type=torch.float)
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
config = self.freqai_info.get("model_training_parameters", {})
|
||||
self.learning_rate: float = config.get("learning_rate", 3e-4)
|
||||
self.model_kwargs: Dict[str, Any] = config.get("model_kwargs", {})
|
||||
self.trainer_kwargs: Dict[str, Any] = config.get("trainer_kwargs", {})
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
class_names = self.get_class_names()
|
||||
self.convert_label_column_to_int(data_dictionary, dk, class_names)
|
||||
n_features = data_dictionary["train_features"].shape[-1]
|
||||
model = LogisticRegression(
|
||||
input_dim=n_features
|
||||
)
|
||||
model.to(self.device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
init_model = self.get_init_model(dk.pair)
|
||||
trainer = PyTorchModelTrainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
model_meta_data={"class_names": class_names},
|
||||
device=self.device,
|
||||
init_model=init_model,
|
||||
data_convertor=self.data_convertor,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
trainer.fit(data_dictionary, self.splits)
|
||||
return trainer
|
||||
|
||||
```
|
||||
|
||||
|
||||
#### Trainer
|
||||
The `PyTorchModelTrainer` performs the idiomatic pytorch train loop:
|
||||
The `PyTorchModelTrainer` performs the idiomatic PyTorch train loop:
|
||||
Define our model, loss function, and optimizer, and then move them to the appropriate device (GPU or CPU). Inside the loop, we iterate through the batches in the dataloader, move the data to the device, compute the prediction and loss, backpropagate, and update the model parameters using the optimizer.
|
||||
|
||||
In addition, the trainer is responsible for the following:
|
||||
|
@ -539,7 +539,9 @@ class FreqaiDataDrawer:
|
||||
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
|
||||
elif self.model_type == 'pytorch':
|
||||
import torch
|
||||
model = torch.load(dk.data_path / f"{dk.model_filename}_model.zip")
|
||||
zip = torch.load(dk.data_path / f"{dk.model_filename}_model.zip")
|
||||
model = zip["pytrainer"]
|
||||
model = model.load_from_checkpoint(zip)
|
||||
|
||||
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
|
||||
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
|
@ -14,16 +14,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class CatboostClassifier(BaseClassifierModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
train_data = Pool(
|
||||
|
@ -15,16 +15,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class CatboostClassifierMultiTarget(BaseClassifierModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
cbc = CatBoostClassifier(
|
||||
|
@ -14,16 +14,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class CatboostRegressor(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
train_data = Pool(
|
||||
|
@ -15,16 +15,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class CatboostRegressorMultiTarget(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
cbr = CatBoostRegressor(
|
||||
|
@ -12,16 +12,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class LightGBMClassifier(BaseClassifierModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0:
|
||||
|
@ -13,16 +13,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class LightGBMClassifierMultiTarget(BaseClassifierModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
lgb = LGBMClassifier(**self.model_training_parameters)
|
||||
|
@ -12,18 +12,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class LightGBMRegressor(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
Most regressors use the same function names and arguments e.g. user
|
||||
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
||||
management will be properly handled by Freqai.
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0:
|
||||
|
@ -13,16 +13,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class LightGBMRegressorMultiTarget(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
lgb = LGBMRegressor(**self.model_training_parameters)
|
||||
|
@ -57,8 +57,9 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
:raises ValueError: If self.class_names is not defined in the parent class.
|
||||
"""
|
||||
|
||||
|
@ -55,8 +55,9 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
n_features = data_dictionary["train_features"].shape[-1]
|
||||
|
@ -18,16 +18,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class XGBoostClassifier(BaseClassifierModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
X = data_dictionary["train_features"].to_numpy()
|
||||
|
@ -18,16 +18,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class XGBoostRFClassifier(BaseClassifierModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
X = data_dictionary["train_features"].to_numpy()
|
||||
|
@ -12,16 +12,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class XGBoostRFRegressor(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
X = data_dictionary["train_features"]
|
||||
|
@ -12,16 +12,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class XGBoostRegressor(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
X = data_dictionary["train_features"]
|
||||
|
@ -13,16 +13,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class XGBoostRegressorMultiTarget(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
User created prediction model. The class inherits IFreqaiModel, which
|
||||
means it has full access to all Frequency AI functionality. Typically,
|
||||
users would use this to override the common `fit()`, `train()`, or
|
||||
`predict()` methods to add their custom data handling tools or change
|
||||
various aspects of the training that cannot be configured via the
|
||||
top level config.json file.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
:param data_dictionary: the dictionary holding all data for train, test,
|
||||
labels, weights
|
||||
:param dk: The datakitchen object for the current coin/model
|
||||
"""
|
||||
|
||||
xgb = XGBRegressor(**self.model_training_parameters)
|
||||
|
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
@ -169,6 +169,12 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
|
||||
n_batches = math.ceil(n_obs // batch_size)
|
||||
epochs = math.ceil(n_iters // n_batches)
|
||||
if epochs <= 10:
|
||||
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
|
||||
f" {epochs} epochs. Please consider increasing this value accordingly")
|
||||
if epochs <= 1:
|
||||
logger.warning("Epochs set to 1. Please review your `max_iters` value")
|
||||
epochs = 1
|
||||
return epochs
|
||||
|
||||
def save(self, path: Path):
|
||||
@ -182,6 +188,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
"model_state_dict": self.model.state_dict(),
|
||||
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||
"model_meta_data": self.model_meta_data,
|
||||
"pytrainer": self
|
||||
}, path)
|
||||
|
||||
def load(self, path: Path):
|
||||
@ -195,7 +202,6 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
you can access this dict from any class that inherits IFreqaiModel by calling
|
||||
get_init_model method.
|
||||
"""
|
||||
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
self.model_meta_data = checkpoint["model_meta_data"]
|
||||
|
@ -4,7 +4,7 @@ from typing import Dict, List
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PyTorchTrainerInterface(ABC):
|
||||
|
2
setup.sh
2
setup.sh
@ -85,7 +85,7 @@ function updateenv() {
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]
|
||||
then
|
||||
REQUIREMENTS_FREQAI="-r requirements-freqai.txt --use-pep517"
|
||||
read -p "Do you also want dependencies for freqai-rl (~700mb additional space required) [y/N]? "
|
||||
read -p "Do you also want dependencies for freqai-rl or PyTorch (~700mb additional space required) [y/N]? "
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]
|
||||
then
|
||||
REQUIREMENTS_FREQAI="-r requirements-freqai-rl.txt"
|
||||
|
Loading…
Reference in New Issue
Block a user