Update GeMOSA v4
This commit is contained in:
@@ -127,6 +127,10 @@ class SyntheticDEnv(data.Dataset):
|
||||
targets = torch.from_numpy(targets)
|
||||
else:
|
||||
targets = torch.Tensor(targets)
|
||||
if dataset.dtype == torch.float64:
|
||||
dataset = dataset.float()
|
||||
if targets.dtype == torch.float64:
|
||||
targets = targets.float()
|
||||
return torch.Tensor([timestamp]), (dataset, targets)
|
||||
|
||||
def __len__(self):
|
||||
|
Reference in New Issue
Block a user