Move to xautodl
This commit is contained in:
17
xautodl/spaces/__init__.py
Normal file
17
xautodl/spaces/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
|
||||
#####################################################
|
||||
# Define complex searc space for AutoDL #
|
||||
#####################################################
|
||||
|
||||
from .basic_space import Categorical
|
||||
from .basic_space import Continuous
|
||||
from .basic_space import Integer
|
||||
from .basic_space import Space
|
||||
from .basic_space import VirtualNode
|
||||
from .basic_op import has_categorical
|
||||
from .basic_op import has_continuous
|
||||
from .basic_op import is_determined
|
||||
from .basic_op import get_determined_value
|
||||
from .basic_op import get_min
|
||||
from .basic_op import get_max
|
71
xautodl/spaces/basic_op.py
Normal file
71
xautodl/spaces/basic_op.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from spaces.basic_space import Space
|
||||
from spaces.basic_space import VirtualNode
|
||||
from spaces.basic_space import Integer
|
||||
from spaces.basic_space import Continuous
|
||||
from spaces.basic_space import Categorical
|
||||
from spaces.basic_space import _EPS
|
||||
|
||||
|
||||
def has_categorical(space_or_value, x):
|
||||
if isinstance(space_or_value, Space):
|
||||
return space_or_value.has(x)
|
||||
else:
|
||||
return space_or_value == x
|
||||
|
||||
|
||||
def has_continuous(space_or_value, x):
|
||||
if isinstance(space_or_value, Space):
|
||||
return space_or_value.has(x)
|
||||
else:
|
||||
return abs(space_or_value - x) <= _EPS
|
||||
|
||||
|
||||
def is_determined(space_or_value):
|
||||
if isinstance(space_or_value, Space):
|
||||
return space_or_value.determined
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def get_determined_value(space_or_value):
|
||||
if not is_determined(space_or_value):
|
||||
raise ValueError("This input is not determined: {:}".format(space_or_value))
|
||||
if isinstance(space_or_value, Space):
|
||||
if isinstance(space_or_value, Continuous):
|
||||
return space_or_value.lower
|
||||
elif isinstance(space_or_value, Categorical):
|
||||
return get_determined_value(space_or_value[0])
|
||||
else: # VirtualNode
|
||||
return space_or_value.value
|
||||
else:
|
||||
return space_or_value
|
||||
|
||||
|
||||
def get_max(space_or_value):
|
||||
if isinstance(space_or_value, Integer):
|
||||
return max(space_or_value.candidates)
|
||||
elif isinstance(space_or_value, Continuous):
|
||||
return space_or_value.upper
|
||||
elif isinstance(space_or_value, Categorical):
|
||||
values = []
|
||||
for index in range(len(space_or_value)):
|
||||
max_value = get_max(space_or_value[index])
|
||||
values.append(max_value)
|
||||
return max(values)
|
||||
else:
|
||||
return space_or_value
|
||||
|
||||
|
||||
def get_min(space_or_value):
|
||||
if isinstance(space_or_value, Integer):
|
||||
return min(space_or_value.candidates)
|
||||
elif isinstance(space_or_value, Continuous):
|
||||
return space_or_value.lower
|
||||
elif isinstance(space_or_value, Categorical):
|
||||
values = []
|
||||
for index in range(len(space_or_value)):
|
||||
min_value = get_min(space_or_value[index])
|
||||
values.append(min_value)
|
||||
return min(values)
|
||||
else:
|
||||
return space_or_value
|
434
xautodl/spaces/basic_space.py
Normal file
434
xautodl/spaces/basic_space.py
Normal file
@@ -0,0 +1,434 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
|
||||
import abc
|
||||
import math
|
||||
import copy
|
||||
import random
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
from typing import Optional, Text
|
||||
|
||||
|
||||
__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]
|
||||
|
||||
_EPS = 1e-9
|
||||
|
||||
|
||||
class Space(metaclass=abc.ABCMeta):
|
||||
"""Basic search space describing the set of possible candidate values for hyperparameter.
|
||||
All search space must inherit from this basic class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# used to avoid duplicate sample
|
||||
self._last_sample = None
|
||||
self._last_abstract = None
|
||||
|
||||
@abc.abstractproperty
|
||||
def xrepr(self, depth=0) -> Text:
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self) -> Text:
|
||||
return self.xrepr()
|
||||
|
||||
@abc.abstractproperty
|
||||
def abstract(self, reuse_last=False) -> "Space":
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def clean_last_sample(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def clean_last_abstract(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def clean_last(self):
|
||||
self.clean_last_sample()
|
||||
self.clean_last_abstract()
|
||||
|
||||
@abc.abstractproperty
|
||||
def determined(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def has(self, x) -> bool:
|
||||
"""Check whether x is in this search space."""
|
||||
assert not isinstance(
|
||||
x, Space
|
||||
), "The input value itself can not be a search space."
|
||||
|
||||
@abc.abstractmethod
|
||||
def __eq__(self, other):
|
||||
raise NotImplementedError
|
||||
|
||||
def copy(self) -> "Space":
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
class VirtualNode(Space):
|
||||
"""For a nested search space, we represent it as a tree structure.
|
||||
|
||||
For example,
|
||||
"""
|
||||
|
||||
def __init__(self, id=None, value=None):
|
||||
super(VirtualNode, self).__init__()
|
||||
self._id = id
|
||||
self._value = value
|
||||
self._attributes = OrderedDict()
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
def append(self, key, value):
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(
|
||||
"Only accept string as a key instead of {:}".format(type(key))
|
||||
)
|
||||
if not isinstance(value, Space):
|
||||
raise ValueError("Invalid type of value: {:}".format(type(value)))
|
||||
# if value.determined:
|
||||
# raise ValueError("Can not attach a determined value: {:}".format(value))
|
||||
self._attributes[key] = value
|
||||
|
||||
def xrepr(self, depth=0) -> Text:
|
||||
strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
|
||||
for key, value in self._attributes.items():
|
||||
strs.append(key + " = " + value.xrepr(depth + 1))
|
||||
strs.append(")")
|
||||
if len(strs) == 2:
|
||||
return "".join(strs)
|
||||
else:
|
||||
space = " "
|
||||
xstrs = (
|
||||
[strs[0]]
|
||||
+ [space * (depth + 1) + x for x in strs[1:-1]]
|
||||
+ [space * depth + strs[-1]]
|
||||
)
|
||||
return ",\n".join(xstrs)
|
||||
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
node = VirtualNode(id(self))
|
||||
for key, value in self._attributes.items():
|
||||
if not value.determined:
|
||||
node.append(value.abstract(reuse_last))
|
||||
self._last_abstract = node
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
node = VirtualNode(None, self._value)
|
||||
for key, value in self._attributes.items():
|
||||
node.append(key, value.random(recursion, reuse_last))
|
||||
self._last_sample = node # record the last sample
|
||||
return node
|
||||
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
for key, value in self._attributes.items():
|
||||
value.clean_last_sample()
|
||||
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
for key, value in self._attributes.items():
|
||||
value.clean_last_abstract()
|
||||
|
||||
def has(self, x) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
if value.has(x):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._attributes
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._attributes[key]
|
||||
|
||||
@property
|
||||
def determined(self) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
if not value.determined:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, VirtualNode):
|
||||
return False
|
||||
for key, value in self._attributes.items():
|
||||
if not key in other:
|
||||
return False
|
||||
if value != other[key]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Categorical(Space):
|
||||
"""A space contains the categorical values.
|
||||
It can be a nested space, which means that the candidate in this space can also be a search space.
|
||||
"""
|
||||
|
||||
def __init__(self, *data, default: Optional[int] = None):
|
||||
super(Categorical, self).__init__()
|
||||
self._candidates = [*data]
|
||||
self._default = default
|
||||
assert self._default is None or 0 <= self._default < len(
|
||||
self._candidates
|
||||
), "default >= {:}".format(len(self._candidates))
|
||||
assert len(self) > 0, "Please provide at least one candidate"
|
||||
|
||||
@property
|
||||
def candidates(self):
|
||||
return self._candidates
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def determined(self):
|
||||
if len(self) == 1:
|
||||
return (
|
||||
not isinstance(self._candidates[0], Space)
|
||||
or self._candidates[0].determined
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._candidates[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._candidates)
|
||||
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space):
|
||||
candidate.clean_last_sample()
|
||||
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space):
|
||||
candidate.clean_last_abstract()
|
||||
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
if self.determined:
|
||||
result = VirtualNode(id(self), self)
|
||||
else:
|
||||
# [TO-IMPROVE]
|
||||
data = []
|
||||
for candidate in self.candidates:
|
||||
if isinstance(candidate, Space):
|
||||
data.append(candidate.abstract())
|
||||
else:
|
||||
data.append(VirtualNode(id(candidate), candidate))
|
||||
result = Categorical(*data, default=self._default)
|
||||
self._last_abstract = result
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
sample = random.choice(self._candidates)
|
||||
if recursion and isinstance(sample, Space):
|
||||
sample = sample.random(recursion, reuse_last)
|
||||
if isinstance(sample, VirtualNode):
|
||||
sample = sample.copy()
|
||||
else:
|
||||
sample = VirtualNode(None, sample)
|
||||
self._last_sample = sample
|
||||
return self._last_sample
|
||||
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
|
||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||
)
|
||||
return xrepr
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space) and candidate.has(x):
|
||||
return True
|
||||
elif candidate == x:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Categorical):
|
||||
return False
|
||||
if len(self) != len(other):
|
||||
return False
|
||||
if self.default != other.default:
|
||||
return False
|
||||
for index in range(len(self)):
|
||||
if self.__getitem__(index) != other[index]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Integer(Categorical):
|
||||
"""A space contains the integer values."""
|
||||
|
||||
def __init__(self, lower: int, upper: int, default: Optional[int] = None):
|
||||
if not isinstance(lower, int) or not isinstance(upper, int):
|
||||
raise ValueError(
|
||||
"The lower [{:}] and uppwer [{:}] must be int.".format(lower, upper)
|
||||
)
|
||||
data = list(range(lower, upper + 1))
|
||||
self._raw_lower = lower
|
||||
self._raw_upper = upper
|
||||
self._raw_default = default
|
||||
if default is not None and (default < lower or default > upper):
|
||||
raise ValueError("The default value [{:}] is out of range.".format(default))
|
||||
default = data.index(default)
|
||||
super(Integer, self).__init__(*data, default=default)
|
||||
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._raw_lower,
|
||||
upper=self._raw_upper,
|
||||
default=self._raw_default,
|
||||
)
|
||||
return xrepr
|
||||
|
||||
|
||||
np_float_types = (np.float16, np.float32, np.float64)
|
||||
np_int_types = (
|
||||
np.uint8,
|
||||
np.int8,
|
||||
np.uint16,
|
||||
np.int16,
|
||||
np.uint32,
|
||||
np.int32,
|
||||
np.uint64,
|
||||
np.int64,
|
||||
)
|
||||
|
||||
|
||||
class Continuous(Space):
|
||||
"""A space contains the continuous values."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower: float,
|
||||
upper: float,
|
||||
default: Optional[float] = None,
|
||||
log: bool = False,
|
||||
eps: float = _EPS,
|
||||
):
|
||||
super(Continuous, self).__init__()
|
||||
self._lower = lower
|
||||
self._upper = upper
|
||||
self._default = default
|
||||
self._log_scale = log
|
||||
self._eps = eps
|
||||
|
||||
@property
|
||||
def lower(self):
|
||||
return self._lower
|
||||
|
||||
@property
|
||||
def upper(self):
|
||||
return self._upper
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def use_log(self):
|
||||
return self._log_scale
|
||||
|
||||
@property
|
||||
def eps(self):
|
||||
return self._eps
|
||||
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
self._last_abstract = self.copy()
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
del recursion
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
if self._log_scale:
|
||||
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
||||
sample = math.exp(sample)
|
||||
else:
|
||||
sample = random.uniform(self._lower, self._upper)
|
||||
self._last_sample = VirtualNode(None, sample)
|
||||
return self._last_sample
|
||||
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._lower,
|
||||
upper=self._upper,
|
||||
default=self._default,
|
||||
log=self._log_scale,
|
||||
)
|
||||
return xrepr
|
||||
|
||||
def convert(self, x):
|
||||
if isinstance(x, np_float_types) and x.size == 1:
|
||||
return float(x), True
|
||||
elif isinstance(x, np_int_types) and x.size == 1:
|
||||
return float(x), True
|
||||
elif isinstance(x, int):
|
||||
return float(x), True
|
||||
elif isinstance(x, float):
|
||||
return float(x), True
|
||||
else:
|
||||
return None, False
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
converted_x, success = self.convert(x)
|
||||
return success and self.lower <= converted_x <= self.upper
|
||||
|
||||
@property
|
||||
def determined(self):
|
||||
return abs(self.lower - self.upper) <= self._eps
|
||||
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Continuous):
|
||||
return False
|
||||
if self is other:
|
||||
return True
|
||||
else:
|
||||
return (
|
||||
self.lower == other.lower
|
||||
and self.upper == other.upper
|
||||
and self.default == other.default
|
||||
and self.use_log == other.use_log
|
||||
and self.eps == other.eps
|
||||
)
|
Reference in New Issue
Block a user