Upgrade spaces and add more tests

This commit is contained in:
D-X-Y
2021-03-18 15:04:14 +08:00
parent 85ee0ad4eb
commit 38409e602f
12 changed files with 386 additions and 84 deletions

View File

@@ -6,3 +6,5 @@
from .basic_space import Categorical
from .basic_space import Continuous
from .basic_op import has_categorical
from .basic_op import has_continuous

16
lib/spaces/basic_op.py Normal file
View File

@@ -0,0 +1,16 @@
from spaces.basic_space import Space
from spaces.basic_space import _EPS
def has_categorical(space_or_value, x):
if isinstance(space_or_value, Space):
return space_or_value.has(x)
else:
return space_or_value == x
def has_continuous(space_or_value, x):
if isinstance(space_or_value, Space):
return space_or_value.has(x)
else:
return abs(space_or_value - x) <= _EPS

View File

@@ -4,28 +4,65 @@
import abc
import math
import copy
import random
import numpy as np
from typing import Optional
_EPS = 1e-9
class Space(metaclass=abc.ABCMeta):
"""Basic search space describing the set of possible candidate values for hyperparameter.
All search space must inherit from this basic class.
"""
@abc.abstractmethod
def random(self, recursion=True):
raise NotImplementedError
@abc.abstractproperty
def determined(self):
raise NotImplementedError
@abc.abstractmethod
def __repr__(self):
raise NotImplementedError
@abc.abstractmethod
def has(self, x):
"""Check whether x is in this search space."""
assert not isinstance(
x, Space
), "The input value itself can not be a search space."
def copy(self):
return copy.deepcopy(self)
class Categorical(Space):
"""A space contains the categorical values.
It can be a nested space, which means that the candidate in this space can also be a search space.
"""
def __init__(self, *data, default: Optional[int] = None):
self._candidates = [*data]
self._default = default
assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format(
len(self._candidates)
)
assert self._default is None or 0 <= self._default < len(
self._candidates
), "default >= {:}".format(len(self._candidates))
assert len(self) > 0, "Please provide at least one candidate"
@property
def determined(self):
if len(self) == 1:
return (
not isinstance(self._candidates[0], Space)
or self._candidates[0].determined
)
else:
return False
def __getitem__(self, index):
return self._candidates[index]
@@ -38,6 +75,15 @@ class Categorical(Space):
name=self.__class__.__name__, cs=self._candidates, default=self._default
)
def has(self, x):
super().has(x)
for candidate in self._candidates:
if isinstance(candidate, Space) and candidate.has(x):
return True
elif candidate == x:
return True
return False
def random(self, recursion=True):
sample = random.choice(self._candidates)
if recursion and isinstance(sample, Space):
@@ -46,12 +92,35 @@ class Categorical(Space):
return sample
np_float_types = (np.float16, np.float32, np.float64)
np_int_types = (
np.uint8,
np.int8,
np.uint16,
np.int16,
np.uint32,
np.int32,
np.uint64,
np.int64,
)
class Continuous(Space):
def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False):
"""A space contains the continuous values."""
def __init__(
self,
lower: float,
upper: float,
default: Optional[float] = None,
log: bool = False,
eps: float = _EPS,
):
self._lower = lower
self._upper = upper
self._default = default
self._log_scale = log
self._eps = eps
@property
def lower(self):
@@ -65,6 +134,10 @@ class Continuous(Space):
def default(self):
return self._default
@property
def determined(self):
return abs(self.lower - self.upper) <= self._eps
def __repr__(self):
return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
name=self.__class__.__name__,
@@ -74,6 +147,23 @@ class Continuous(Space):
log=self._log_scale,
)
def convert(self, x):
if isinstance(x, np_float_types) and x.size == 1:
return float(x), True
elif isinstance(x, np_int_types) and x.size == 1:
return float(x), True
elif isinstance(x, int):
return float(x), True
elif isinstance(x, float):
return float(x), True
else:
return None, False
def has(self, x):
super().has(x)
converted_x, success = self.convert(x)
return success and self.lower <= converted_x <= self.upper
def random(self, recursion=True):
del recursion
if self._log_scale: