Update codes
This commit is contained in:
@@ -19,7 +19,7 @@ class SyntheticDEnv(data.Dataset):
|
||||
mean_functors: List[data.Dataset],
|
||||
cov_functors: List[List[data.Dataset]],
|
||||
num_per_task: int = 5000,
|
||||
time_stamp_config: Optional[Dict] = None,
|
||||
timestamp_config: Optional[Dict] = None,
|
||||
mode: Optional[str] = None,
|
||||
):
|
||||
self._ndim = len(mean_functors)
|
||||
@@ -31,12 +31,12 @@ class SyntheticDEnv(data.Dataset):
|
||||
cov_functor
|
||||
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
|
||||
self._num_per_task = num_per_task
|
||||
if time_stamp_config is None:
|
||||
time_stamp_config = dict(mode=mode)
|
||||
if timestamp_config is None:
|
||||
timestamp_config = dict(mode=mode)
|
||||
else:
|
||||
time_stamp_config["mode"] = mode
|
||||
timestamp_config["mode"] = mode
|
||||
|
||||
self._timestamp_generator = TimeStamp(**time_stamp_config)
|
||||
self._timestamp_generator = TimeStamp(**timestamp_config)
|
||||
|
||||
self._mean_functors = mean_functors
|
||||
self._cov_functors = cov_functors
|
||||
|
@@ -2,21 +2,23 @@
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
|
||||
from .math_base_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
|
||||
|
||||
def create_example_v1(timestamps=50, num_per_task=5000):
|
||||
def create_example_v1(
|
||||
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||
num_per_task=5000,
|
||||
):
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
||||
std_generator.set_transform(lambda x: x + 1)
|
||||
|
||||
dynamic_env = SyntheticDEnv(
|
||||
[mean_generator],
|
||||
[[std_generator]],
|
||||
num_per_task=num_per_task,
|
||||
time_stamp_config=dict(num=timestamps),
|
||||
timestamp_config=timestamp_config,
|
||||
)
|
||||
|
||||
function = DynamicQuadraticFunc()
|
||||
|
Reference in New Issue
Block a user