fix
- clear cache before calling `ask` - avoid errors in case asked_non_tried has less than n_points elements
This commit is contained in:
parent
f8a674f24d
commit
37a43019d6
@ -426,6 +426,7 @@ class Hyperopt:
|
|||||||
asked_non_tried: List[List[Any]] = []
|
asked_non_tried: List[List[Any]] = []
|
||||||
while i < 100 and len(asked_non_tried) < n_points:
|
while i < 100 and len(asked_non_tried) < n_points:
|
||||||
if i < 3:
|
if i < 3:
|
||||||
|
self.opt.cache_ = {}
|
||||||
asked = self.opt.ask(n_points=n_points)
|
asked = self.opt.ask(n_points=n_points)
|
||||||
else:
|
else:
|
||||||
asked = self.opt.space.rvs(n_samples=n_points * 5)
|
asked = self.opt.space.rvs(n_samples=n_points * 5)
|
||||||
@ -434,7 +435,7 @@ class Hyperopt:
|
|||||||
and x not in asked_non_tried]
|
and x not in asked_non_tried]
|
||||||
i += 1
|
i += 1
|
||||||
if asked_non_tried:
|
if asked_non_tried:
|
||||||
return asked_non_tried[:n_points]
|
return asked_non_tried[:min(len(asked_non_tried), n_points)]
|
||||||
else:
|
else:
|
||||||
return self.opt.ask(n_points=n_points)
|
return self.opt.ask(n_points=n_points)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user