give user ability to analyze live trade dataframe inside custom prediction model. Add documentation to explain new functionality

This commit is contained in:
robcaulk
2022-08-02 20:14:02 +02:00
parent 895ebbfd18
commit 95d3009a95
4 changed files with 147 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ import copy
import datetime
import logging
import shutil
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Tuple
@@ -39,7 +40,7 @@ class FreqaiDataKitchen:
Robert Caulk @robcaulk
Theoretical brainstorming:
Elin Törnquist @thorntwig
Elin Törnquist @th0rntwig
Code review, software architecture brainstorming:
@xmatthias
@@ -84,6 +85,12 @@ class FreqaiDataKitchen:
config["freqai"]["backtest_period_days"],
)
db_url = self.config.get('db_url', None)
self.database_path = '' if db_url == 'sqlite://' else str(db_url).split('///')[1]
self.trade_database_df: DataFrame = pd.DataFrame()
self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
def set_paths(
self,
pair: str,
@@ -101,7 +108,7 @@ class FreqaiDataKitchen:
self.data_path = Path(
self.full_path
/ str("sub-train" + "-" + pair.split("/")[0] + "_" + str(trained_timestamp))
/ f"sub-train-{pair.split('/')[0]}_{trained_timestamp}"
)
return
@@ -328,7 +335,7 @@ class FreqaiDataKitchen:
"""
for label in self.label_list:
if df[label].dtype == str:
if df[label].dtype == object:
continue
df[label] = (
(df[label] + 1)
@@ -493,7 +500,6 @@ class FreqaiDataKitchen:
tc = self.freqai_config.get("model_training_parameters", {}).get("thread_count", -1)
pairwise = pairwise_distances(self.data_dictionary["train_features"], n_jobs=tc)
avg_mean_dist = pairwise.mean(axis=1).mean()
logger.info(f"avg_mean_dist {avg_mean_dist:.2f}")
return avg_mean_dist
@@ -599,10 +605,11 @@ class FreqaiDataKitchen:
from the training data set.
"""
tc = self.freqai_config.get("model_training_parameters", {}).get("thread_count", -1)
distance = pairwise_distances(
self.data_dictionary["train_features"],
self.data_dictionary["prediction_features"],
n_jobs=-1,
n_jobs=tc,
)
self.DI_values = distance.min(axis=0) / self.data["avg_mean_dist"]
@@ -946,6 +953,19 @@ class FreqaiDataKitchen:
]
return dataframe[to_keep]
def get_current_trade_database(self) -> None:
if self.database_path == '':
logger.warning('No trade databse found. Skipping analysis.')
return
data = sqlite3.connect(self.database_path)
query = data.execute("SELECT * From trades")
cols = [column[0] for column in query.description]
df = pd.DataFrame.from_records(data=query.fetchall(), columns=cols)
self.trade_database_df = df.dropna(subset='close_date')
data.close()
def np_encoder(self, object):
if isinstance(object, np.generic):
return object.item()