LFNA ok on the valid data
This commit is contained in:
@@ -66,11 +66,6 @@ class SyntheticDEnv(data.Dataset):
|
||||
self._cov_functors = cov_functors
|
||||
|
||||
self._oracle_map = None
|
||||
self._seq_length = None
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
return self._seq_length
|
||||
|
||||
@property
|
||||
def min_timestamp(self):
|
||||
@@ -84,14 +79,12 @@ class SyntheticDEnv(data.Dataset):
|
||||
def timestamp_interval(self):
|
||||
return self._timestamp_generator.interval
|
||||
|
||||
def random_timestamp(self):
|
||||
return (
|
||||
random.random() * (self.max_timestamp - self.min_timestamp)
|
||||
+ self.min_timestamp
|
||||
)
|
||||
|
||||
def reset_max_seq_length(self, seq_length):
|
||||
self._seq_length = seq_length
|
||||
def random_timestamp(self, min_timestamp=None, max_timestamp=None):
|
||||
if min_timestamp is None:
|
||||
min_timestamp = self.min_timestamp
|
||||
if max_timestamp is None:
|
||||
max_timestamp = self.max_timestamp
|
||||
return random.random() * (max_timestamp - min_timestamp) + min_timestamp
|
||||
|
||||
def get_timestamp(self, index):
|
||||
if index is None:
|
||||
@@ -119,19 +112,7 @@ class SyntheticDEnv(data.Dataset):
|
||||
def __getitem__(self, index):
|
||||
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
|
||||
index, timestamp = self._timestamp_generator[index]
|
||||
if self._seq_length is None:
|
||||
return self.__call__(timestamp)
|
||||
else:
|
||||
noise = (
|
||||
random.random() * self.timestamp_interval * self._timestamp_noise_scale
|
||||
)
|
||||
timestamps = [
|
||||
timestamp + i * self.timestamp_interval + noise
|
||||
for i in range(self._seq_length)
|
||||
]
|
||||
# xdata = [self.__call__(timestamp) for timestamp in timestamps]
|
||||
# return zip_sequence(xdata)
|
||||
return self.seq_call(timestamps)
|
||||
return self.__call__(timestamp)
|
||||
|
||||
def seq_call(self, timestamps):
|
||||
with torch.no_grad():
|
||||
|
Reference in New Issue
Block a user