initial commit

This commit is contained in:
Yinon Polak
2023-03-05 16:59:24 +02:00
parent 108a578772
commit 751b205618
5 changed files with 254 additions and 1 deletions

View File

@@ -0,0 +1,69 @@
import logging
from time import time
from typing import Any, Dict
import torch
from pandas import DataFrame
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
logger = logging.getLogger(__name__)
class BasePytorchModel(IFreqaiModel):
"""
Base class for TensorFlow type models.
User *must* inherit from this class and set fit() and predict().
"""
def __init__(self, **kwargs):
super().__init__(config=kwargs['config'])
self.dd.model_type = 'pytorch'
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
def train(
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
) -> Any:
"""
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
for storing, saving, loading, and analyzing the data.
:param unfiltered_df: Full dataframe for the current training period
:param metadata: pair metadata from strategy.
:return:
:model: Trained model which can be used to inference (self.predict)
"""
logger.info(f"-------------------- Starting training {pair} --------------------")
start_time = time()
features_filtered, labels_filtered = dk.filter_features(
unfiltered_df,
dk.training_features_list,
dk.label_list,
training_filter=True,
)
# split data into train/test data.
data_dictionary = dk.make_train_test_datasets(features_filtered, labels_filtered)
if not self.freqai_info.get("fit_live_predictions", 0) or not self.live:
dk.fit_labels()
# normalize all data based on train_dataset only
data_dictionary = dk.normalize_data(data_dictionary)
# optional additional data cleaning/analysis
self.data_cleaning_train(dk)
logger.info(
f"Training model on {len(dk.data_dictionary['train_features'].columns)} features"
)
logger.info(f"Training model on {len(data_dictionary['train_features'])} data points")
model = self.fit(data_dictionary, dk)
end_time = time()
logger.info(f"-------------------- Done training {pair} "
f"({end_time - start_time:.2f} secs) --------------------")
return model

View File

@@ -0,0 +1,51 @@
import logging
from pathlib import Path
from typing import Dict
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
class PytorchModelTrainer:
def __init__(self, model: nn.Module, optimizer, init_model: Dict):
self.model = model
self.optimizer = optimizer
if init_model:
self.load_from_checkpoint(init_model)
def fit(self, tensor_dictionary, max_iters, batch_size):
for iter in range(max_iters):
# todo add validation evaluation here
xb, yb = self.get_batch(tensor_dictionary, 'train', batch_size)
logits, loss = self.model(xb, yb)
self.optimizer.zero_grad(set_to_none=True)
loss.backward()
self.optimizer.step()
def save(self, path):
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
}, path)
def load_from_file(self, path: Path):
checkpoint = torch.load(path)
return self.load_from_checkpoint(checkpoint)
def load_from_checkpoint(self, checkpoint: Dict):
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return self
@staticmethod
def get_batch(tensor_dictionary: Dict, split: str, batch_size: int):
ix = torch.randint(len(tensor_dictionary[f'{split}_labels']), (batch_size,))
x = tensor_dictionary[f'{split}_features'][ix]
y = tensor_dictionary[f'{split}_labels'][ix]
return x, y