Update visualization codes
This commit is contained in:
29
lib/datasets/synthetic_core.py
Normal file
29
lib/datasets/synthetic_core.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import copy
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
||||
|
||||
|
||||
def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None):
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)
|
||||
dynamic_env = SyntheticDEnv(
|
||||
[mean_generator],
|
||||
[[std_generator]],
|
||||
num_per_task=num_per_task,
|
||||
timestamp_config=dict(
|
||||
min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode
|
||||
),
|
||||
)
|
||||
function = DynamicQuadraticFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
function_param[2] = ComposedSinFunc(
|
||||
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
|
||||
)
|
||||
function.set(function_param)
|
||||
dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||
return dynamic_env
|
@@ -43,6 +43,14 @@ class SyntheticDEnv(data.Dataset):
|
||||
|
||||
self._oracle_map = None
|
||||
|
||||
@property
|
||||
def min_timestamp(self):
|
||||
return self._timestamp_generator.min_timestamp
|
||||
|
||||
@property
|
||||
def max_timestamp(self):
|
||||
return self._timestamp_generator.max_timestamp
|
||||
|
||||
def set_oracle_map(self, functor):
|
||||
self._oracle_map = functor
|
||||
|
||||
@@ -61,7 +69,7 @@ class SyntheticDEnv(data.Dataset):
|
||||
index, timestamp = self._timestamp_generator[index]
|
||||
mean_list = [functor(timestamp) for functor in self._mean_functors]
|
||||
cov_matrix = [
|
||||
[cov_gen(timestamp) for cov_gen in cov_functor]
|
||||
[abs(cov_gen(timestamp)) for cov_gen in cov_functor]
|
||||
for cov_functor in self._cov_functors
|
||||
]
|
||||
|
||||
|
@@ -53,6 +53,14 @@ class TimeStamp(UnifiedSplit, data.Dataset):
|
||||
self._total_num = num
|
||||
UnifiedSplit.__init__(self, self._total_num, mode)
|
||||
|
||||
@property
|
||||
def min_timestamp(self):
|
||||
return self._min_timestamp
|
||||
|
||||
@property
|
||||
def max_timestamp(self):
|
||||
return self._max_timestamp
|
||||
|
||||
def __iter__(self):
|
||||
self._iter_num = 0
|
||||
return self
|
||||
|
Reference in New Issue
Block a user