Improve DBSCAN epsilon identification (#7269)
* Improve DBSCAN epsilon identification
This commit is contained in:
parent
60ba921f56
commit
5ce1c69803
@ -601,6 +601,8 @@ class FreqaiDataKitchen:
|
|||||||
is an outlier.
|
is an outlier.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from math import cos, sin
|
||||||
|
|
||||||
if predict:
|
if predict:
|
||||||
train_ft_df = self.data_dictionary['train_features']
|
train_ft_df = self.data_dictionary['train_features']
|
||||||
pred_ft_df = self.data_dictionary['prediction_features']
|
pred_ft_df = self.data_dictionary['prediction_features']
|
||||||
@ -619,23 +621,47 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
def normalise_distances(distances):
|
||||||
|
normalised_distances = (distances - distances.min()) / \
|
||||||
|
(distances.max() - distances.min())
|
||||||
|
return normalised_distances
|
||||||
|
|
||||||
|
def rotate_point(origin, point, angle):
|
||||||
|
# rotate a point counterclockwise by a given angle (in radians)
|
||||||
|
# around a given origin
|
||||||
|
x = origin[0] + cos(angle) * (point[0] - origin[0]) - \
|
||||||
|
sin(angle) * (point[1] - origin[1])
|
||||||
|
y = origin[1] + sin(angle) * (point[0] - origin[0]) + \
|
||||||
|
cos(angle) * (point[1] - origin[1])
|
||||||
|
return (x, y)
|
||||||
|
|
||||||
MinPts = len(self.data_dictionary['train_features'].columns) * 2
|
MinPts = len(self.data_dictionary['train_features'].columns) * 2
|
||||||
# measure pairwise distances to train_features.shape[1]*2 nearest neighbours
|
# measure pairwise distances to train_features.shape[1]*2 nearest neighbours
|
||||||
neighbors = NearestNeighbors(
|
neighbors = NearestNeighbors(
|
||||||
n_neighbors=MinPts, n_jobs=self.thread_count)
|
n_neighbors=MinPts, n_jobs=self.thread_count)
|
||||||
neighbors_fit = neighbors.fit(self.data_dictionary['train_features'])
|
neighbors_fit = neighbors.fit(self.data_dictionary['train_features'])
|
||||||
distances, _ = neighbors_fit.kneighbors(self.data_dictionary['train_features'])
|
distances, _ = neighbors_fit.kneighbors(self.data_dictionary['train_features'])
|
||||||
distances = np.sort(distances, axis=0)
|
distances = np.sort(distances, axis=0).mean(axis=1)
|
||||||
index_ten_pct = int(len(distances[:, 1]) * 0.1)
|
|
||||||
distances = distances[index_ten_pct:, 1]
|
normalised_distances = normalise_distances(distances)
|
||||||
epsilon = distances[-1]
|
x_range = np.linspace(0, 1, len(distances))
|
||||||
|
line = np.linspace(normalised_distances[0],
|
||||||
|
normalised_distances[-1], len(normalised_distances))
|
||||||
|
deflection = np.abs(normalised_distances - line)
|
||||||
|
max_deflection_loc = np.where(deflection == deflection.max())[0][0]
|
||||||
|
origin = x_range[max_deflection_loc], line[max_deflection_loc]
|
||||||
|
point = x_range[max_deflection_loc], normalised_distances[max_deflection_loc]
|
||||||
|
rot_angle = np.pi / 4
|
||||||
|
elbow_loc = rotate_point(origin, point, rot_angle)
|
||||||
|
|
||||||
|
epsilon = elbow_loc[1] * (distances[-1] - distances[0]) + distances[0]
|
||||||
|
|
||||||
clustering = DBSCAN(eps=epsilon, min_samples=MinPts,
|
clustering = DBSCAN(eps=epsilon, min_samples=MinPts,
|
||||||
n_jobs=int(self.thread_count)).fit(
|
n_jobs=int(self.thread_count)).fit(
|
||||||
self.data_dictionary['train_features']
|
self.data_dictionary['train_features']
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'DBSCAN found eps of {epsilon}.')
|
logger.info(f'DBSCAN found eps of {epsilon:.2f}.')
|
||||||
|
|
||||||
self.data['DBSCAN_eps'] = epsilon
|
self.data['DBSCAN_eps'] = epsilon
|
||||||
self.data['DBSCAN_min_samples'] = MinPts
|
self.data['DBSCAN_min_samples'] = MinPts
|
||||||
@ -698,7 +724,7 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
if (len(do_predict) - do_predict.sum()) > 0:
|
if (len(do_predict) - do_predict.sum()) > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"DI tossed {len(do_predict) - do_predict.sum():.2f} predictions for "
|
f"DI tossed {len(do_predict) - do_predict.sum()} predictions for "
|
||||||
"being too far from training data"
|
"being too far from training data"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user