Complete Super Linear
This commit is contained in:
@@ -12,5 +12,6 @@ from .basic_space import VirtualNode
|
||||
from .basic_op import has_categorical
|
||||
from .basic_op import has_continuous
|
||||
from .basic_op import is_determined
|
||||
from .basic_op import get_determined_value
|
||||
from .basic_op import get_min
|
||||
from .basic_op import get_max
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from spaces.basic_space import Space
|
||||
from spaces.basic_space import VirtualNode
|
||||
from spaces.basic_space import Integer
|
||||
from spaces.basic_space import Continuous
|
||||
from spaces.basic_space import Categorical
|
||||
@@ -26,6 +27,20 @@ def is_determined(space_or_value):
|
||||
return True
|
||||
|
||||
|
||||
def get_determined_value(space_or_value):
|
||||
if not is_determined(space_or_value):
|
||||
raise ValueError("This input is not determined: {:}".format(space_or_value))
|
||||
if isinstance(space_or_value, Space):
|
||||
if isinstance(space_or_value, Continuous):
|
||||
return space_or_value.lower
|
||||
elif isinstance(space_or_value, Categorical):
|
||||
return get_determined_value(space_or_value[0])
|
||||
else: # VirtualNode
|
||||
return space_or_value.value
|
||||
else:
|
||||
return space_or_value
|
||||
|
||||
|
||||
def get_max(space_or_value):
|
||||
if isinstance(space_or_value, Integer):
|
||||
return max(space_or_value.candidates)
|
||||
|
@@ -23,7 +23,7 @@ class Space(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractproperty
|
||||
def xrepr(self, indent=0) -> Text:
|
||||
def xrepr(self, prefix="") -> Text:
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self) -> Text:
|
||||
@@ -67,17 +67,27 @@ class VirtualNode(Space):
|
||||
self._value = value
|
||||
self._attributes = OrderedDict()
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
def append(self, key, value):
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(
|
||||
"Only accept string as a key instead of {:}".format(type(key))
|
||||
)
|
||||
if not isinstance(value, Space):
|
||||
raise ValueError("Invalid type of value: {:}".format(type(value)))
|
||||
# if value.determined:
|
||||
# raise ValueError("Can not attach a determined value: {:}".format(value))
|
||||
self._attributes[key] = value
|
||||
|
||||
def xrepr(self, indent=0) -> Text:
|
||||
strs = [self.__class__.__name__ + "("]
|
||||
def xrepr(self, prefix=" ") -> Text:
|
||||
strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
|
||||
for key, value in self._attributes.items():
|
||||
strs.append(value.xrepr(indent + 2) + ",")
|
||||
strs.append(value.xrepr(prefix + " " + key + " = "))
|
||||
strs.append(")")
|
||||
return "\n".join(strs)
|
||||
return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs)
|
||||
|
||||
def abstract(self) -> Space:
|
||||
node = VirtualNode(id(self))
|
||||
@@ -87,7 +97,10 @@ class VirtualNode(Space):
|
||||
return node
|
||||
|
||||
def random(self, recursion=True):
|
||||
raise NotImplementedError
|
||||
node = VirtualNode(None, self._value)
|
||||
for key, value in self._attributes.items():
|
||||
node.append(key, value.random(recursion))
|
||||
return node
|
||||
|
||||
def has(self, x) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
@@ -101,6 +114,7 @@ class VirtualNode(Space):
|
||||
def __getitem__(self, key):
|
||||
return self._attributes[key]
|
||||
|
||||
@property
|
||||
def determined(self) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
if not value.determined(x):
|
||||
@@ -165,20 +179,22 @@ class Categorical(Space):
|
||||
data.append(candidate.abstract())
|
||||
else:
|
||||
data.append(VirtualNode(id(candidate), candidate))
|
||||
return Categorical(*data, self._default)
|
||||
return Categorical(*data, default=self._default)
|
||||
|
||||
def random(self, recursion=True):
|
||||
sample = random.choice(self._candidates)
|
||||
if recursion and isinstance(sample, Space):
|
||||
return sample.random(recursion)
|
||||
sample = sample.random(recursion)
|
||||
if isinstance(sample, VirtualNode):
|
||||
return sample.copy()
|
||||
else:
|
||||
return sample
|
||||
return VirtualNode(None, sample)
|
||||
|
||||
def xrepr(self, indent=0):
|
||||
def xrepr(self, prefix=""):
|
||||
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
|
||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||
)
|
||||
return " " * indent + xrepr
|
||||
return prefix + xrepr
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
@@ -219,14 +235,14 @@ class Integer(Categorical):
|
||||
default = data.index(default)
|
||||
super(Integer, self).__init__(*data, default=default)
|
||||
|
||||
def xrepr(self, indent=0):
|
||||
def xrepr(self, prefix=""):
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._raw_lower,
|
||||
upper=self._raw_upper,
|
||||
default=self._raw_default,
|
||||
)
|
||||
return " " * indent + xrepr
|
||||
return prefix + xrepr
|
||||
|
||||
|
||||
np_float_types = (np.float16, np.float32, np.float64)
|
||||
@@ -286,11 +302,12 @@ class Continuous(Space):
|
||||
del recursion
|
||||
if self._log_scale:
|
||||
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
||||
return math.exp(sample)
|
||||
sample = math.exp(sample)
|
||||
else:
|
||||
return random.uniform(self._lower, self._upper)
|
||||
sample = random.uniform(self._lower, self._upper)
|
||||
return VirtualNode(None, sample)
|
||||
|
||||
def xrepr(self, indent=0):
|
||||
def xrepr(self, prefix=""):
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._lower,
|
||||
@@ -298,7 +315,7 @@ class Continuous(Space):
|
||||
default=self._default,
|
||||
log=self._log_scale,
|
||||
)
|
||||
return " " * indent + xrepr
|
||||
return prefix + xrepr
|
||||
|
||||
def convert(self, x):
|
||||
if isinstance(x, np_float_types) and x.size == 1:
|
||||
|
Reference in New Issue
Block a user