LFNA ok on the valid data

This commit is contained in:
D-X-Y
2021-05-23 19:14:12 +00:00
parent 63a0361152
commit b1064e5a60
3 changed files with 27 additions and 55 deletions

View File

@@ -66,11 +66,6 @@ class SyntheticDEnv(data.Dataset):
self._cov_functors = cov_functors
self._oracle_map = None
self._seq_length = None
@property
def seq_length(self):
return self._seq_length
@property
def min_timestamp(self):
@@ -84,14 +79,12 @@ class SyntheticDEnv(data.Dataset):
def timestamp_interval(self):
return self._timestamp_generator.interval
def random_timestamp(self):
return (
random.random() * (self.max_timestamp - self.min_timestamp)
+ self.min_timestamp
)
def reset_max_seq_length(self, seq_length):
self._seq_length = seq_length
def random_timestamp(self, min_timestamp=None, max_timestamp=None):
if min_timestamp is None:
min_timestamp = self.min_timestamp
if max_timestamp is None:
max_timestamp = self.max_timestamp
return random.random() * (max_timestamp - min_timestamp) + min_timestamp
def get_timestamp(self, index):
if index is None:
@@ -119,19 +112,7 @@ class SyntheticDEnv(data.Dataset):
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index, timestamp = self._timestamp_generator[index]
if self._seq_length is None:
return self.__call__(timestamp)
else:
noise = (
random.random() * self.timestamp_interval * self._timestamp_noise_scale
)
timestamps = [
timestamp + i * self.timestamp_interval + noise
for i in range(self._seq_length)
]
# xdata = [self.__call__(timestamp) for timestamp in timestamps]
# return zip_sequence(xdata)
return self.seq_call(timestamps)
return self.__call__(timestamp)
def seq_call(self, timestamps):
with torch.no_grad():