Finalize example vis codes

This commit is contained in:
D-X-Y
2021-04-27 20:09:37 +08:00
parent 77cab08d60
commit 5eb18e8adb
8 changed files with 98 additions and 61 deletions

View File

@@ -5,7 +5,8 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
from .math_adv_funcs import DynamicQuadraticFunc, ConstantFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc
from .math_adv_funcs import ComposedSinFunc
from .synthetic_utils import TimeStamp

View File

@@ -14,41 +14,6 @@ from .math_base_funcs import QuadraticFunc
from .math_base_funcs import QuarticFunc
class DynamicQuadraticFunc(FitFunc):
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
The a, b, and c is a function of timestamp.
"""
def __init__(self, list_of_points=None):
super(DynamicQuadraticFunc, self).__init__(3, list_of_points)
self._timestamp = None
def __call__(self, x, timestamp=None):
self.check_valid()
if timestamp is None:
timestamp = self._timestamp
a = self._params[0](timestamp)
b = self._params[1](timestamp)
c = self._params[2](timestamp)
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
return a * x * x + b * x + c
def _getitem(self, x, weights):
raise NotImplementedError
def set_timestamp(self, timestamp):
self._timestamp = timestamp
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
)
class ConstantFunc(FitFunc):
"""The constant function: f(x) = c."""

View File

@@ -13,20 +13,20 @@ import torch.utils.data as data
class FitFunc(abc.ABC):
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, freedom: int, list_of_points=None, _params=None):
def __init__(self, freedom: int, list_of_points=None, params=None):
self._params = dict()
for i in range(freedom):
self._params[i] = None
self._freedom = freedom
if list_of_points is not None and _params is not None:
raise ValueError("list_of_points and _params can not be set simultaneously")
if list_of_points is not None and params is not None:
raise ValueError("list_of_points and params can not be set simultaneously")
if list_of_points is not None:
self.fit(list_of_points=list_of_points)
if _params is not None:
self.set(_params)
if params is not None:
self.set(params)
def set(self, _params):
self._params = copy.deepcopy(_params)
def set(self, params):
self._params = copy.deepcopy(params)
def check_valid(self):
for key, value in self._params.items():

View File

@@ -0,0 +1,66 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import copy
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
from .math_base_funcs import FitFunc
class DynamicFunc(FitFunc):
"""The dynamic quadratic function, where each param is a function."""
def __init__(self, freedom: int, params=None):
super(DynamicFunc, self).__init__(freedom, None, params)
self._timestamp = None
def __call__(self, x, timestamp=None):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
def set_timestamp(self, timestamp):
self._timestamp = timestamp
def noise_call(self, x, timestamp=None, std=0.1):
clean_y = self.__call__(x, timestamp)
if isinstance(clean_y, np.ndarray):
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
else:
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
return noise_y
class DynamicQuadraticFunc(DynamicFunc):
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
The a, b, and c is a function of timestamp.
"""
def __init__(self, params=None):
super(DynamicQuadraticFunc, self).__init__(3, params)
def __call__(self, x, timestamp=None):
self.check_valid()
if timestamp is None:
timestamp = self._timestamp
a = self._params[0](timestamp)
b = self._params[1](timestamp)
c = self._params[2](timestamp)
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
return a * x * x + b * x + c
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
timestamp=self._timestamp,
)

View File

@@ -41,6 +41,11 @@ class SyntheticDEnv(data.Dataset):
self._mean_functors = mean_functors
self._cov_functors = cov_functors
self._oracle_map = None
def set_oracle_map(self, functor):
self._oracle_map = functor
def __iter__(self):
self._iter_num = 0
return self
@@ -63,7 +68,11 @@ class SyntheticDEnv(data.Dataset):
dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task
)
return timestamp, torch.Tensor(dataset)
if self._oracle_map is None:
return timestamp, torch.Tensor(dataset)
else:
targets = self._oracle_map.noise_call(dataset, timestamp)
return timestamp, (torch.Tensor(dataset), torch.Tensor(targets))
def __len__(self):
return len(self._timestamp_generator)

View File

@@ -1,8 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
from .math_adv_funcs import DynamicQuadraticFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
from .synthetic_env import SyntheticDEnv
@@ -11,7 +12,6 @@ def create_example_v1(
timestamp_config=None,
num_per_task=5000,
):
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
mean_generator = ComposedSinFunc()
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
@@ -32,4 +32,6 @@ def create_example_v1(
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
)
function.set(function_param)
dynamic_env.set_oracle_map(copy.deepcopy(function))
return dynamic_env, function