Add simple spaces
This commit is contained in:
@@ -5,3 +5,4 @@
|
||||
#####################################################
|
||||
|
||||
from .basic_space import Categorical
|
||||
from .basic_space import Continuous
|
||||
|
@@ -3,12 +3,15 @@
|
||||
#####################################################
|
||||
|
||||
import abc
|
||||
import math
|
||||
import random
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Space(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def random(self):
|
||||
def random(self, recursion=True):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -17,8 +20,12 @@ class Space(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class Categorical(Space):
|
||||
def __init__(self, *data):
|
||||
def __init__(self, *data, default: Optional[int] = None):
|
||||
self._candidates = [*data]
|
||||
self._default = default
|
||||
assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format(
|
||||
len(self._candidates)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._candidates[index]
|
||||
@@ -27,7 +34,50 @@ class Categorical(Space):
|
||||
return len(self._candidates)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name:}(candidates={cs:})".format(name=self.__class__.__name__, cs=self._candidates)
|
||||
return "{name:}(candidates={cs:}, default_index={default:})".format(
|
||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||
)
|
||||
|
||||
def random(self):
|
||||
return random.choice(self._candidates)
|
||||
def random(self, recursion=True):
|
||||
sample = random.choice(self._candidates)
|
||||
if recursion and isinstance(sample, Space):
|
||||
return sample.random(recursion)
|
||||
else:
|
||||
return sample
|
||||
|
||||
|
||||
class Continuous(Space):
|
||||
def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False):
|
||||
self._lower = lower
|
||||
self._upper = upper
|
||||
self._default = default
|
||||
self._log_scale = log
|
||||
|
||||
@property
|
||||
def lower(self):
|
||||
return self._lower
|
||||
|
||||
@property
|
||||
def upper(self):
|
||||
return self._upper
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
def __repr__(self):
|
||||
return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._lower,
|
||||
upper=self._upper,
|
||||
default=self._default,
|
||||
log=self._log_scale,
|
||||
)
|
||||
|
||||
def random(self, recursion=True):
|
||||
del recursion
|
||||
if self._log_scale:
|
||||
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
||||
return math.exp(sample)
|
||||
else:
|
||||
return random.uniform(self._lower, self._upper)
|
||||
|
Reference in New Issue
Block a user