Reformulate the synthetic codes

This commit is contained in:
D-X-Y
2021-04-22 23:08:43 +08:00
parent 78ca90459c
commit 731458f890
11 changed files with 568 additions and 362 deletions

View File

@@ -4,5 +4,7 @@
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset
from .synthetic_adaptive_environment import QuadraticFunc, CubicFunc, QuarticFunc
from .synthetic_adaptive_environment import SynAdaptiveEnv
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
from .math_base_funcs import DynamicQuadraticFunc
from .synthetic_utils import SinGenerator, ConstantGenerator
from .synthetic_env import SyntheticDEnv

View File

@@ -176,93 +176,31 @@ class QuarticFunc(FitFunc):
)
class SynAdaptiveEnv(data.Dataset):
"""The synethtic dataset for adaptive environment.
class DynamicQuadraticFunc(FitFunc):
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c."""
- x in [0, 1]
- y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
- where
- the amplitude scale is a quadratic function of x
- the period-phase-shift is another quadratic function of x
def __init__(self, list_of_points=None):
super(DynamicQuadraticFunc, self).__init__(3, list_of_points)
self._timestamp = None
"""
def __init__(
self,
num: int = 100,
num_sin_phase: int = 7,
min_amplitude: float = 1,
max_amplitude: float = 4,
phase_shift: float = 0,
mode: Optional[str] = None,
):
self._amplitude_scale = QuadraticFunc(
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
def __getitem__(self, x):
self.check_valid()
return (
self._params[0][self._timestamp] * x * x
+ self._params[1][self._timestamp] * x
+ self._params[2][self._timestamp]
)
self._num_sin_phase = num_sin_phase
self._interval = 1.0 / (float(num) - 1)
self._total_num = num
def _getitem(self, x, weights):
raise NotImplementedError
fitting_data = []
temp_max_scalar = 2 ** (num_sin_phase - 1)
for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar
next_value = (2 ** (i + 1)) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
self._period_phase_shift = QuarticFunc(fitting_data)
# Training Set 60%
num_of_train = int(self._total_num * 0.6)
# Validation Set 20%
num_of_valid = int(self._total_num * 0.2)
# Test Set 20%
num_of_set = self._total_num - num_of_train - num_of_valid
all_indexes = list(range(self._total_num))
if mode is None:
self._indexes = all_indexes
elif mode.lower() in ("train", "training"):
self._indexes = all_indexes[:num_of_train]
elif mode.lower() in ("valid", "validation"):
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
elif mode.lower() in ("test", "testing"):
self._indexes = all_indexes[num_of_train + num_of_valid :]
else:
raise ValueError("Unkonwn mode of {:}".format(mode))
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
position = self._interval * index
value = self._amplitude_scale[position] * math.sin(
self._period_phase_shift[position]
)
return index, position, value
def __len__(self):
return len(self._indexes)
def set_timestamp(self, timestamp):
self._timestamp = timestamp
def __repr__(self):
return (
"{name}({cur_num:}/{total} elements,\n"
"amplitude={amplitude},\n"
"period_phase_shift={period_phase_shift})".format(
name=self.__class__.__name__,
cur_num=self._total_num,
total=len(self),
amplitude=self._amplitude_scale,
period_phase_shift=self._period_phase_shift,
)
return "{name}(y = {a} * x^2 + {b} * x + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
)

View File

@@ -0,0 +1,81 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import math
import abc
import numpy as np
from typing import List, Optional
import torch
import torch.utils.data as data
from .synthetic_utils import UnifiedSplit
class SyntheticDEnv(UnifiedSplit, data.Dataset):
"""The synethtic dynamic environment."""
def __init__(
self,
mean_generators: List[data.Dataset],
cov_generators: List[List[data.Dataset]],
num_per_task: int = 5000,
mode: Optional[str] = None,
):
self._ndim = len(mean_generators)
assert self._ndim == len(
cov_generators
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_generators))
for cov_generator in cov_generators:
assert self._ndim == len(
cov_generator
), "length does not match {:} vs. {:}".format(
self._ndim, len(cov_generator)
)
self._num_per_task = num_per_task
self._total_num = len(mean_generators[0])
for mean_generator in mean_generators:
assert self._total_num == len(mean_generator)
for cov_generator in cov_generators:
for cov_g in cov_generator:
assert self._total_num == len(cov_g)
self._mean_generators = mean_generators
self._cov_generators = cov_generators
UnifiedSplit.__init__(self, self._total_num, mode)
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
mean_list = [generator[index][-1] for generator in self._mean_generators]
cov_matrix = [
[cov_gen[index][-1] for cov_gen in cov_generator]
for cov_generator in self._cov_generators
]
dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task
)
return index, torch.Tensor(dataset)
def __len__(self):
return len(self._indexes)
def __repr__(self):
return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task})".format(
name=self.__class__.__name__,
cur_num=len(self),
total=self._total_num,
ndim=self._ndim,
num_per_task=self._num_per_task,
)

View File

@@ -0,0 +1,157 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
from .math_base_funcs import QuadraticFunc, QuarticFunc
class UnifiedSplit:
"""A class to unify the split strategy."""
def __init__(self, total_num, mode):
# Training Set 60%
num_of_train = int(total_num * 0.6)
# Validation Set 20%
num_of_valid = int(total_num * 0.2)
# Test Set 20%
num_of_set = total_num - num_of_train - num_of_valid
all_indexes = list(range(total_num))
if mode is None:
self._indexes = all_indexes
elif mode.lower() in ("train", "training"):
self._indexes = all_indexes[:num_of_train]
elif mode.lower() in ("valid", "validation"):
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
elif mode.lower() in ("test", "testing"):
self._indexes = all_indexes[num_of_train + num_of_valid :]
else:
raise ValueError("Unkonwn mode of {:}".format(mode))
self._mode = mode
@property
def mode(self):
return self._mode
class SinGenerator(UnifiedSplit, data.Dataset):
"""The synethtic generator for the dynamically changing environment.
- x in [0, 1]
- y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
- where
- the amplitude scale is a quadratic function of x
- the period-phase-shift is another quadratic function of x
"""
def __init__(
self,
num: int = 100,
num_sin_phase: int = 7,
min_amplitude: float = 1,
max_amplitude: float = 4,
phase_shift: float = 0,
mode: Optional[str] = None,
):
self._amplitude_scale = QuadraticFunc(
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
)
self._num_sin_phase = num_sin_phase
self._interval = 1.0 / (float(num) - 1)
self._total_num = num
fitting_data = []
temp_max_scalar = 2 ** (num_sin_phase - 1)
for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar
next_value = (2 ** (i + 1)) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
self._period_phase_shift = QuarticFunc(fitting_data)
UnifiedSplit.__init__(self, self._total_num, mode)
self._transform = lambda x: x
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def set_transform(self, transform):
self._transform = transform
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
position = self._interval * index
value = self._amplitude_scale[position] * math.sin(
self._period_phase_shift[position]
)
return index, position, self._transform(value)
def __len__(self):
return len(self._indexes)
def __repr__(self):
return (
"{name}({cur_num:}/{total} elements,\n"
"amplitude={amplitude},\n"
"period_phase_shift={period_phase_shift})".format(
name=self.__class__.__name__,
cur_num=len(self),
total=self._total_num,
amplitude=self._amplitude_scale,
period_phase_shift=self._period_phase_shift,
)
)
class ConstantGenerator(UnifiedSplit, data.Dataset):
"""The constant generator."""
def __init__(
self,
num: int = 100,
constant: float = 0.1,
mode: Optional[str] = None,
):
self._total_num = num
self._constant = constant
UnifiedSplit.__init__(self, self._total_num, mode)
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
return index, index, self._constant
def __len__(self):
return len(self._indexes)
def __repr__(self):
return "{name}({cur_num:}/{total} elements)".format(
name=self.__class__.__name__,
cur_num=len(self),
total=self._total_num,
)