Upgrade spaces and add more tests
This commit is contained in:
@@ -6,3 +6,5 @@
|
||||
|
||||
from .basic_space import Categorical
|
||||
from .basic_space import Continuous
|
||||
from .basic_op import has_categorical
|
||||
from .basic_op import has_continuous
|
||||
|
16
lib/spaces/basic_op.py
Normal file
16
lib/spaces/basic_op.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from spaces.basic_space import Space
|
||||
from spaces.basic_space import _EPS
|
||||
|
||||
|
||||
def has_categorical(space_or_value, x):
|
||||
if isinstance(space_or_value, Space):
|
||||
return space_or_value.has(x)
|
||||
else:
|
||||
return space_or_value == x
|
||||
|
||||
|
||||
def has_continuous(space_or_value, x):
|
||||
if isinstance(space_or_value, Space):
|
||||
return space_or_value.has(x)
|
||||
else:
|
||||
return abs(space_or_value - x) <= _EPS
|
@@ -4,28 +4,65 @@
|
||||
|
||||
import abc
|
||||
import math
|
||||
import copy
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from typing import Optional
|
||||
|
||||
_EPS = 1e-9
|
||||
|
||||
|
||||
class Space(metaclass=abc.ABCMeta):
|
||||
"""Basic search space describing the set of possible candidate values for hyperparameter.
|
||||
All search space must inherit from this basic class.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def random(self, recursion=True):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
def determined(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __repr__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def has(self, x):
|
||||
"""Check whether x is in this search space."""
|
||||
assert not isinstance(
|
||||
x, Space
|
||||
), "The input value itself can not be a search space."
|
||||
|
||||
def copy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
class Categorical(Space):
|
||||
"""A space contains the categorical values.
|
||||
It can be a nested space, which means that the candidate in this space can also be a search space.
|
||||
"""
|
||||
|
||||
def __init__(self, *data, default: Optional[int] = None):
|
||||
self._candidates = [*data]
|
||||
self._default = default
|
||||
assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format(
|
||||
len(self._candidates)
|
||||
)
|
||||
assert self._default is None or 0 <= self._default < len(
|
||||
self._candidates
|
||||
), "default >= {:}".format(len(self._candidates))
|
||||
assert len(self) > 0, "Please provide at least one candidate"
|
||||
|
||||
@property
|
||||
def determined(self):
|
||||
if len(self) == 1:
|
||||
return (
|
||||
not isinstance(self._candidates[0], Space)
|
||||
or self._candidates[0].determined
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._candidates[index]
|
||||
@@ -38,6 +75,15 @@ class Categorical(Space):
|
||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||
)
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space) and candidate.has(x):
|
||||
return True
|
||||
elif candidate == x:
|
||||
return True
|
||||
return False
|
||||
|
||||
def random(self, recursion=True):
|
||||
sample = random.choice(self._candidates)
|
||||
if recursion and isinstance(sample, Space):
|
||||
@@ -46,12 +92,35 @@ class Categorical(Space):
|
||||
return sample
|
||||
|
||||
|
||||
np_float_types = (np.float16, np.float32, np.float64)
|
||||
np_int_types = (
|
||||
np.uint8,
|
||||
np.int8,
|
||||
np.uint16,
|
||||
np.int16,
|
||||
np.uint32,
|
||||
np.int32,
|
||||
np.uint64,
|
||||
np.int64,
|
||||
)
|
||||
|
||||
|
||||
class Continuous(Space):
|
||||
def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False):
|
||||
"""A space contains the continuous values."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower: float,
|
||||
upper: float,
|
||||
default: Optional[float] = None,
|
||||
log: bool = False,
|
||||
eps: float = _EPS,
|
||||
):
|
||||
self._lower = lower
|
||||
self._upper = upper
|
||||
self._default = default
|
||||
self._log_scale = log
|
||||
self._eps = eps
|
||||
|
||||
@property
|
||||
def lower(self):
|
||||
@@ -65,6 +134,10 @@ class Continuous(Space):
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def determined(self):
|
||||
return abs(self.lower - self.upper) <= self._eps
|
||||
|
||||
def __repr__(self):
|
||||
return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||
name=self.__class__.__name__,
|
||||
@@ -74,6 +147,23 @@ class Continuous(Space):
|
||||
log=self._log_scale,
|
||||
)
|
||||
|
||||
def convert(self, x):
|
||||
if isinstance(x, np_float_types) and x.size == 1:
|
||||
return float(x), True
|
||||
elif isinstance(x, np_int_types) and x.size == 1:
|
||||
return float(x), True
|
||||
elif isinstance(x, int):
|
||||
return float(x), True
|
||||
elif isinstance(x, float):
|
||||
return float(x), True
|
||||
else:
|
||||
return None, False
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
converted_x, success = self.convert(x)
|
||||
return success and self.lower <= converted_x <= self.upper
|
||||
|
||||
def random(self, recursion=True):
|
||||
del recursion
|
||||
if self._log_scale:
|
||||
|
Reference in New Issue
Block a user