Update the sync data v1

This commit is contained in:
D-X-Y
2021-05-24 13:06:10 +08:00
parent da2575cc6c
commit 3ee0d348af
17 changed files with 228 additions and 274 deletions

View File

@@ -17,10 +17,10 @@ from .math_base_funcs import QuarticFunc
class ConstantFunc(FitFunc):
"""The constant function: f(x) = c."""
def __init__(self, constant=None):
def __init__(self, constant=None, xstr="x"):
param = dict()
param[0] = constant
super(ConstantFunc, self).__init__(0, None, param)
super(ConstantFunc, self).__init__(0, None, param, xstr)
def __call__(self, x):
self.check_valid()
@@ -37,6 +37,34 @@ class ConstantFunc(FitFunc):
class ComposedSinFunc(FitFunc):
    """A sine wave with a scale, a frequency factor and an offset.

    Evaluates f(x) = a * sin(b * x) + c, where a, b and c are read from
    ``self._params`` under the keys 0, 1 and 2 respectively.
    """

    def __init__(self, params, xstr="x"):
        """Hand ``params`` and ``xstr`` to FitFunc with freedom 3 and no points."""
        super().__init__(3, None, params, xstr)

    def __call__(self, x):
        """Return a * sin(b * x) + c for the input ``x``.

        ``check_valid()`` runs first; ``math.sin`` is used for the sine.
        """
        self.check_valid()
        scale, freq, offset = (self._params[key] for key in (0, 1, 2))
        return scale * math.sin(freq * x) + offset

    def _getitem(self, x, weights):
        """Not supported by this function; always raises NotImplementedError."""
        raise NotImplementedError

    def __repr__(self):
        """Render the formula with the current parameters and variable name."""
        template = "{name}({a} * sin({b} * {x}) + {c})"
        fields = dict(
            name=self.__class__.__name__,
            a=self._params[0],
            b=self._params[1],
            c=self._params[2],
            x=self.xstr,
        )
        return template.format(**fields)
class ComposedSinFuncV2(FitFunc):
"""The composed sin function that outputs:
f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
- the amplitude scale is a quadratic function of x
@@ -44,7 +72,7 @@ class ComposedSinFunc(FitFunc):
"""
def __init__(self, **kwargs):
super(ComposedSinFunc, self).__init__(0, None)
super(ComposedSinFuncV2, self).__init__(0, None)
self.fit(**kwargs)
def __call__(self, x):

View File

@@ -5,15 +5,13 @@ import math
import abc
import copy
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
class FitFunc(abc.ABC):
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, freedom: int, list_of_points=None, params=None):
def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
self._params = dict()
for i in range(freedom):
self._params[i] = None
@@ -24,6 +22,7 @@ class FitFunc(abc.ABC):
self.fit(list_of_points=list_of_points)
if params is not None:
self.set(params)
self._xstr = str(xstr)
def set(self, params):
self._params = copy.deepcopy(params)
@@ -33,6 +32,13 @@ class FitFunc(abc.ABC):
if value is None:
raise ValueError("The {:} is None".format(key))
@property
def xstr(self):
    """The variable name stored in ``self._xstr`` (read-only view)."""
    label = self._xstr
    return label
def reset_xstr(self, xstr):
    """Replace the stored variable name with ``str(xstr)``."""
    label = str(xstr)
    self._xstr = label
@abc.abstractmethod
def __call__(self, x):
raise NotImplementedError
@@ -106,8 +112,8 @@ class FitFunc(abc.ABC):
class LinearFunc(FitFunc):
"""The linear function that outputs f(x) = a * x + b."""
def __init__(self, list_of_points=None, params=None):
super(LinearFunc, self).__init__(2, list_of_points, params)
def __init__(self, list_of_points=None, params=None, xstr="x"):
super(LinearFunc, self).__init__(2, list_of_points, params, xstr)
def __call__(self, x):
self.check_valid()
@@ -117,18 +123,19 @@ class LinearFunc(FitFunc):
return weights[0] * x + weights[1]
def __repr__(self):
return "{name}({a} * x + {b})".format(
return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
x=self.xstr,
)
class QuadraticFunc(FitFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None, params=None):
super(QuadraticFunc, self).__init__(3, list_of_points, params)
def __init__(self, list_of_points=None, params=None, xstr="x"):
super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)
def __call__(self, x):
self.check_valid()
@@ -138,11 +145,12 @@ class QuadraticFunc(FitFunc):
return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format(
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
@@ -165,12 +173,13 @@ class CubicFunc(FitFunc):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self):
return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format(
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
x=self.xstr,
)

View File

@@ -6,3 +6,4 @@ from .math_dynamic_funcs import DynamicLinearFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc
from .math_adv_funcs import ComposedSinFunc
from .math_dynamic_generator import GaussianDGenerator

View File

@@ -15,20 +15,19 @@ from .math_base_funcs import FitFunc
class DynamicFunc(FitFunc):
"""The dynamic quadratic function, where each param is a function."""
def __init__(self, freedom: int, params=None):
super(DynamicFunc, self).__init__(freedom, None, params)
self._timestamp = None
def __init__(self, freedom: int, params=None, xstr="x"):
if params is not None:
for param in params:
param.reset_xstr("t") if isinstance(param, FitFunc) else None
super(DynamicFunc, self).__init__(freedom, None, params, xstr)
def __call__(self, x, timestamp=None):
def __call__(self, x, timestamp):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
def set_timestamp(self, timestamp):
self._timestamp = timestamp
def noise_call(self, x, timestamp=None, std=0.1):
def noise_call(self, x, timestamp, std):
clean_y = self.__call__(x, timestamp)
if isinstance(clean_y, np.ndarray):
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
@@ -42,13 +41,10 @@ class DynamicLinearFunc(DynamicFunc):
The a and b is a function of timestamp.
"""
def __init__(self, params=None):
super(DynamicLinearFunc, self).__init__(3, params)
def __init__(self, params=None, xstr="x"):
super(DynamicLinearFunc, self).__init__(3, params, xstr)
def __call__(self, x, timestamp=None):
self.check_valid()
if timestamp is None:
timestamp = self._timestamp
def __call__(self, x, timestamp):
a = self._params[0](timestamp)
b = self._params[1](timestamp)
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
@@ -56,11 +52,11 @@ class DynamicLinearFunc(DynamicFunc):
return a * x + b
def __repr__(self):
return "{name}({a} * x + {b}, timestamp={timestamp})".format(
return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
timestamp=self._timestamp,
x=self.xstr,
)

View File

@@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import abc
import numpy as np
def assert_list_tuple(x):
    """Assert that ``x`` is a list or a tuple and return its length.

    Raises:
        AssertionError: if ``x`` is neither a list nor a tuple.
    """
    is_sequence = isinstance(x, (list, tuple))
    assert is_sequence
    return len(x)
class DynamicGenerator(abc.ABC):
    """Base class for generators that are called as ``generator(time, num)``.

    Subclasses are expected to override ``__call__``; the base version only
    raises NotImplementedError.
    """

    def __init__(self):
        """Start with an unknown dimensionality (``_ndim`` is None)."""
        self._ndim = None

    def __call__(self, time, num):
        """Placeholder to be overridden; always raises NotImplementedError."""
        raise NotImplementedError
class GaussianDGenerator(DynamicGenerator):
    """Sample points from a Gaussian whose parameters are functions of time.

    Args:
        mean_functors: list/tuple of callables; ``mean_functors[i](time)``
            gives the i-th entry of the mean vector.
        cov_functors: list/tuple of list/tuples of callables forming an
            ndim x ndim grid; the absolute value of
            ``cov_functors[i][j](time)`` is used as entry (i, j) of the
            covariance matrix.
        trunc: ``None`` to disable clipping, or a ``(low, high)`` list/tuple
            with ``low < high`` to which every sampled value is clipped.

    Raises:
        AssertionError: if a container has the wrong type, the sizes do not
            agree with ``len(mean_functors)``, or ``trunc`` is not a valid pair.
    """

    def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)):
        super(GaussianDGenerator, self).__init__()
        self._ndim = assert_list_tuple(mean_functors)
        # Check the container type before comparing sizes, so a wrong type is
        # reported by the list/tuple assertion rather than by len().
        num_rows = assert_list_tuple(cov_functors)
        assert self._ndim == num_rows, "length does not match {:} vs. {:}".format(
            self._ndim, num_rows
        )
        for cov_functor in cov_functors:
            num_cols = assert_list_tuple(cov_functor)
            assert self._ndim == num_cols, "length does not match {:} vs. {:}".format(
                self._ndim, num_cols
            )
        # Validate trunc only when it is given. The previous unconditional
        # isinstance assertion rejected trunc=None, which made the
        # "no truncation" branches here and in __call__ unreachable.
        if trunc is not None:
            assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1]
        self._mean_functors = mean_functors
        self._cov_functors = cov_functors
        self._trunc = trunc

    def __call__(self, time, num):
        """Draw ``num`` samples for the given ``time``.

        The mean vector and covariance matrix are rebuilt from the functors
        on every call and passed to ``np.random.multivariate_normal`` with
        ``size=num``. When truncation is enabled, the samples are clipped in
        place to ``[trunc[0], trunc[1]]`` before being returned.
        """
        mean_list = [functor(time) for functor in self._mean_functors]
        cov_matrix = [
            [abs(cov_gen(time)) for cov_gen in cov_functor]
            for cov_functor in self._cov_functors
        ]
        values = np.random.multivariate_normal(mean_list, cov_matrix, size=num)
        if self._trunc is not None:
            np.clip(values, self._trunc[0], self._trunc[1], out=values)
        return values

    def __repr__(self):
        return "{name}({ndim} dims, trunc={trunc})".format(
            name=self.__class__.__name__, ndim=self._ndim, trunc=self._trunc
        )

View File

@@ -1,13 +1,14 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
import math
from .synthetic_utils import TimeStamp
from .synthetic_env import EnvSampler
from .synthetic_env import SyntheticDEnv
from .math_core import LinearFunc
from .math_core import DynamicLinearFunc
from .math_core import DynamicQuadraticFunc
from .math_core import ConstantFunc, ComposedSinFunc
from .math_core import GaussianDGenerator
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
@@ -17,42 +18,21 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
if version == "v1":
mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1)
elif version == "v2":
mean_generator = ComposedSinFunc()
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)
else:
raise ValueError("Unknown version: {:}".format(version))
dynamic_env = SyntheticDEnv(
[mean_generator],
[[std_generator]],
num_per_task=num_per_task,
timestamp_config=dict(
min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode
),
)
if version == "v1":
function = DynamicLinearFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
amplitude_scale=ConstantFunc(3.0),
num_sin_phase=9,
sin_speed_use_power=False,
data_generator = GaussianDGenerator(
[mean_generator], [[std_generator]], (-2, 2)
)
function_param[1] = ConstantFunc(constant=0.9)
elif version == "v2":
function = DynamicQuadraticFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
time_generator = TimeStamp(
min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode
)
function_param[1] = ConstantFunc(constant=0.9)
function_param[2] = ComposedSinFunc(
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
oracle_map = DynamicLinearFunc(
params={
0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
}
)
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
else:
raise ValueError("Unknown version: {:}".format(version))
function.set(function_param)
# dynamic_env.set_oracle_map(copy.deepcopy(function))
dynamic_env.set_oracle_map(function)
return dynamic_env

View File

@@ -1,15 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import math
import random
import numpy as np
from typing import List, Optional, Dict
import torch
import torch.utils.data as data
from .synthetic_utils import TimeStamp
def is_list_tuple(x):
    """Report whether ``x`` is an instance of tuple or of list."""
    accepted_kinds = (tuple, list)
    return any(isinstance(x, kind) for kind in accepted_kinds)
@@ -38,46 +32,33 @@ class SyntheticDEnv(data.Dataset):
def __init__(
self,
mean_functors: List[data.Dataset],
cov_functors: List[List[data.Dataset]],
data_generator,
oracle_map,
time_generator,
num_per_task: int = 5000,
timestamp_config: Optional[Dict] = None,
mode: Optional[str] = None,
timestamp_noise_scale: float = 0.3,
noise: float = 0.1,
):
self._ndim = len(mean_functors)
assert self._ndim == len(
cov_functors
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors))
for cov_functor in cov_functors:
assert self._ndim == len(
cov_functor
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
self._data_generator = data_generator
self._time_generator = time_generator
self._oracle_map = oracle_map
self._num_per_task = num_per_task
if timestamp_config is None:
timestamp_config = dict(mode=mode)
elif "mode" not in timestamp_config:
timestamp_config["mode"] = mode
self._timestamp_generator = TimeStamp(**timestamp_config)
self._timestamp_noise_scale = timestamp_noise_scale
self._mean_functors = mean_functors
self._cov_functors = cov_functors
self._oracle_map = None
self._noise = noise
@property
def min_timestamp(self):
return self._timestamp_generator.min_timestamp
return self._time_generator.min_timestamp
@property
def max_timestamp(self):
return self._timestamp_generator.max_timestamp
return self._time_generator.max_timestamp
@property
def timestamp_interval(self):
return self._timestamp_generator.interval
def time_interval(self):
return self._time_generator.interval
@property
def mode(self):
return self._time_generator.mode
def random_timestamp(self, min_timestamp=None, max_timestamp=None):
if min_timestamp is None:
@@ -89,16 +70,13 @@ class SyntheticDEnv(data.Dataset):
def get_timestamp(self, index):
if index is None:
timestamps = []
for index in range(len(self._timestamp_generator)):
timestamps.append(self._timestamp_generator[index][1])
for index in range(len(self._time_generator)):
timestamps.append(self._time_generator[index][1])
return tuple(timestamps)
else:
index, timestamp = self._timestamp_generator[index]
index, timestamp = self._time_generator[index]
return timestamp
def set_oracle_map(self, functor):
self._oracle_map = functor
def __iter__(self):
self._iter_num = 0
return self
@@ -111,7 +89,7 @@ class SyntheticDEnv(data.Dataset):
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index, timestamp = self._timestamp_generator[index]
index, timestamp = self._time_generator[index]
return self.__call__(timestamp)
def seq_call(self, timestamps):
@@ -122,52 +100,24 @@ class SyntheticDEnv(data.Dataset):
return zip_sequence(xdata)
def __call__(self, timestamp):
mean_list = [functor(timestamp) for functor in self._mean_functors]
cov_matrix = [
[abs(cov_gen(timestamp)) for cov_gen in cov_functor]
for cov_functor in self._cov_functors
]
dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task
dataset = self._data_generator(timestamp, self._num_per_task)
targets = self._oracle_map.noise_call(dataset, timestamp, self._noise)
return torch.Tensor([timestamp]), (
torch.Tensor(dataset),
torch.Tensor(targets),
)
if self._oracle_map is None:
return torch.Tensor([timestamp]), torch.Tensor(dataset)
else:
targets = self._oracle_map.noise_call(dataset, timestamp)
return torch.Tensor([timestamp]), (
torch.Tensor(dataset),
torch.Tensor(targets),
)
def __len__(self):
return len(self._timestamp_generator)
return len(self._time_generator)
def __repr__(self):
return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format(
name=self.__class__.__name__,
cur_num=len(self),
total=len(self._timestamp_generator),
total=len(self._time_generator),
ndim=self._ndim,
num_per_task=self._num_per_task,
xrange_min=self.min_timestamp,
xrange_max=self.max_timestamp,
mode=self._timestamp_generator.mode,
mode=self.mode,
)
class EnvSampler:
    """Iterate over shuffled, fixed-size batches of environment indexes.

    The index list ``0 .. len(env) - 1`` is repeated ``enlarge`` times. Each
    pass over the sampler shuffles that list in place and yields
    ``len(indexes) // batch`` consecutive slices of ``batch`` indexes; any
    leftover indexes at the end are not yielded.
    """

    def __init__(self, env, batch, enlarge):
        self._indexes = list(range(len(env))) * enlarge
        self._batch = batch
        self._iterations = len(self._indexes) // self._batch

    def __iter__(self):
        # Reshuffle on every pass so each epoch sees a new ordering.
        random.shuffle(self._indexes)
        for step in range(self._iterations):
            start = step * self._batch
            stop = start + self._batch
            yield self._indexes[start:stop]

    def __len__(self):
        """Number of batches produced by one pass."""
        return self._iterations

View File

@@ -1,72 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
from .synthetic_env import SyntheticDEnv
def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"):
    """Build a synthetic example by dispatching on ``indicator``.

    Args:
        timestamp_config: forwarded unchanged to the selected builder.
        num_per_task: forwarded unchanged to the selected builder.
        indicator: ``"v1"`` selects ``create_example_v1`` and ``"v2"``
            selects ``create_example_v2``.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: if ``indicator`` is neither ``"v1"`` nor ``"v2"``.
    """
    if indicator == "v1":
        return create_example_v1(timestamp_config, num_per_task)
    if indicator == "v2":
        return create_example_v2(timestamp_config, num_per_task)
    # The message previously misspelled "Unknown" as "Unkonwn".
    raise ValueError("Unknown indicator: {:}".format(indicator))
def create_example_v1(
    timestamp_config=None,
    num_per_task=5000,
):
    """Build the "v1" environment together with its oracle function.

    The environment is a SyntheticDEnv with one mean functor and a 1x1 grid
    of covariance functors, all ComposedSinFunc instances. The oracle is a
    DynamicQuadraticFunc whose three parameters are set before it is attached.

    Returns:
        A ``(dynamic_env, function)`` pair. The environment receives a deep
        copy of ``function`` as its oracle map, so the two are independent.
    """
    dynamic_env = SyntheticDEnv(
        [ComposedSinFunc()],
        [[ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)]],
        num_per_task=num_per_task,
        timestamp_config=timestamp_config,
    )
    function = DynamicQuadraticFunc()
    function.set(
        {
            0: ComposedSinFunc(num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0),
            1: ConstantFunc(constant=0.9),
            2: ComposedSinFunc(num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9),
        }
    )
    oracle_copy = copy.deepcopy(function)
    dynamic_env.set_oracle_map(oracle_copy)
    return dynamic_env, function
def create_example_v2(
    timestamp_config=None,
    num_per_task=5000,
):
    """Build the "v2" environment together with its oracle function.

    The environment is a SyntheticDEnv whose single mean functor is
    ConstantFunc(0) and whose 1x1 covariance grid holds ConstantFunc(1). The
    oracle is a DynamicLinearFunc whose two parameters are set before it is
    attached.

    Returns:
        A ``(dynamic_env, function)`` pair. The environment receives a deep
        copy of ``function`` as its oracle map, so the two are independent.
    """
    dynamic_env = SyntheticDEnv(
        [ConstantFunc(0)],
        [[ConstantFunc(1)]],
        num_per_task=num_per_task,
        timestamp_config=timestamp_config,
    )
    function = DynamicLinearFunc()
    function.set(
        {
            0: ComposedSinFunc(
                amplitude_scale=ConstantFunc(1.0),
                period_phase_shift=ConstantFunc(1.0),
            ),
            1: ConstantFunc(constant=0.9),
        }
    )
    oracle_copy = copy.deepcopy(function)
    dynamic_env.set_oracle_map(oracle_copy)
    return dynamic_env, function

View File

@@ -13,11 +13,11 @@ class UnifiedSplit:
"""A class to unify the split strategy."""
def __init__(self, total_num, mode):
# Training Set 60%
num_of_train = int(total_num * 0.6)
# Validation Set 20%
num_of_valid = int(total_num * 0.2)
# Test Set 20%
# Training Set 65%
num_of_train = int(total_num * 0.65)
# Validation Set 05%
num_of_valid = int(total_num * 0.05)
# Test Set 30%
num_of_set = total_num - num_of_train - num_of_valid
all_indexes = list(range(total_num))
if mode is None:
@@ -28,6 +28,8 @@ class UnifiedSplit:
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
elif mode.lower() in ("test", "testing"):
self._indexes = all_indexes[num_of_train + num_of_valid :]
elif mode.lower() in ("trainval", "trainvalidation"):
self._indexes = all_indexes[: num_of_train + num_of_valid]
else:
raise ValueError("Unkonwn mode of {:}".format(mode))
self._all_indexes = all_indexes