Update hyperopt.py
remove duplicates from list of asked points
This commit is contained in:
parent
e85c7ca8ff
commit
1559692e47
@ -422,16 +422,23 @@ class Hyperopt:
|
|||||||
5. Repeat until at least `n_points` points in the `asked_non_tried` list
|
5. Repeat until at least `n_points` points in the `asked_non_tried` list
|
||||||
6. Return a list with length truncated at `n_points`
|
6. Return a list with length truncated at `n_points`
|
||||||
'''
|
'''
|
||||||
|
def unique_list(a_list):
|
||||||
|
seen = []
|
||||||
|
for x in a_list:
|
||||||
|
key = repr(x)
|
||||||
|
if key not in seen:
|
||||||
|
seen.append(eval(key))
|
||||||
|
return seen
|
||||||
i = 0
|
i = 0
|
||||||
asked_non_tried: List[List[Any]] = []
|
asked_non_tried: List[List[Any]] = []
|
||||||
is_random: List[bool] = []
|
is_random: List[bool] = []
|
||||||
while i < 5 and len(asked_non_tried) < n_points:
|
while i < 5 and len(asked_non_tried) < n_points:
|
||||||
if i < 3:
|
if i < 3:
|
||||||
self.opt.cache_ = {}
|
self.opt.cache_ = {}
|
||||||
asked = self.opt.ask(n_points=n_points * 5)
|
asked = unique_list(self.opt.ask(n_points=n_points * 5))
|
||||||
is_random = [False for _ in range(len(asked))]
|
is_random = [False for _ in range(len(asked))]
|
||||||
else:
|
else:
|
||||||
asked = self.opt.space.rvs(n_samples=n_points * 5)
|
asked = unique_list(self.opt.space.rvs(n_samples=n_points * 5))
|
||||||
is_random = [True for _ in range(len(asked))]
|
is_random = [True for _ in range(len(asked))]
|
||||||
asked_non_tried += [x for x in asked
|
asked_non_tried += [x for x in asked
|
||||||
if x not in self.opt.Xi
|
if x not in self.opt.Xi
|
||||||
|
Loading…
Reference in New Issue
Block a user