Update hyperopt.py

remove duplicates from list of asked points
This commit is contained in:
Italo 2022-04-08 11:44:42 +01:00
parent e85c7ca8ff
commit 1559692e47

View File

@ -422,16 +422,23 @@ class Hyperopt:
5. Repeat until at least `n_points` points in the `asked_non_tried` list 5. Repeat until at least `n_points` points in the `asked_non_tried` list
6. Return a list with length truncated at `n_points` 6. Return a list with length truncated at `n_points`
''' '''
def unique_list(a_list):
seen = []
for x in a_list:
key = repr(x)
if key not in seen:
seen.append(eval(key))
return seen
i = 0 i = 0
asked_non_tried: List[List[Any]] = [] asked_non_tried: List[List[Any]] = []
is_random: List[bool] = [] is_random: List[bool] = []
while i < 5 and len(asked_non_tried) < n_points: while i < 5 and len(asked_non_tried) < n_points:
if i < 3: if i < 3:
self.opt.cache_ = {} self.opt.cache_ = {}
asked = self.opt.ask(n_points=n_points * 5) asked = unique_list(self.opt.ask(n_points=n_points * 5))
is_random = [False for _ in range(len(asked))] is_random = [False for _ in range(len(asked))]
else: else:
asked = self.opt.space.rvs(n_samples=n_points * 5) asked = unique_list(self.opt.space.rvs(n_samples=n_points * 5))
is_random = [True for _ in range(len(asked))] is_random = [True for _ in range(len(asked))]
asked_non_tried += [x for x in asked asked_non_tried += [x for x in asked
if x not in self.opt.Xi if x not in self.opt.Xi