appease mypy

This commit is contained in:
robcaulk
2022-05-06 16:20:52 +02:00
parent a4f5811a5b
commit 178c2014b0
2 changed files with 25 additions and 14 deletions

View File

@@ -8,6 +8,7 @@ from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import numpy.typing as npt
import pandas as pd
from joblib import dump, load
from pandas import DataFrame
@@ -35,14 +36,14 @@ class FreqaiDataKitchen:
self.data_dictionary: Dict[Any, Any] = {}
self.config = config
self.freqai_config = config["freqai"]
self.predictions = np.array([])
self.do_predict = np.array([])
self.target_mean = np.array([])
self.target_std = np.array([])
self.full_predictions = np.array([])
self.full_do_predict = np.array([])
self.full_target_mean = np.array([])
self.full_target_std = np.array([])
self.predictions: npt.ArrayLike = np.array([])
self.do_predict: npt.ArrayLike = np.array([])
self.target_mean: npt.ArrayLike = np.array([])
self.target_std: npt.ArrayLike = np.array([])
self.full_predictions: npt.ArrayLike = np.array([])
self.full_do_predict: npt.ArrayLike = np.array([])
self.full_target_mean: npt.ArrayLike = np.array([])
self.full_target_std: npt.ArrayLike = np.array([])
self.model_path = Path()
self.model_filename = ""
@@ -123,6 +124,7 @@ class FreqaiDataKitchen:
:labels: cleaned labels ready to be split.
"""
weights: npt.ArrayLike
if self.config["freqai"]["feature_parameters"]["weight_factor"] > 0:
weights = self.set_weights_higher_recent(len(filtered_dataframe))
else:
@@ -519,12 +521,13 @@ class FreqaiDataKitchen:
self.do_predict += do_predict
self.do_predict -= 1
def set_weights_higher_recent(self, num_weights: int) -> int:
def set_weights_higher_recent(self, num_weights: int) -> npt.ArrayLike:
"""
Set weights so that recent data is more heavily weighted during
training than older data.
"""
weights = np.zeros(num_weights)
weights = np.zeros_like(num_weights)
for i in range(1, len(weights)):
weights[len(weights) - i] = np.exp(
-i / (self.config["freqai"]["feature_parameters"]["weight_factor"] * num_weights)