Re-org GeMOSA codes

This commit is contained in:
D-X-Y
2021-05-27 11:17:57 +08:00
parent a507f8dd94
commit 8961215416
8 changed files with 82 additions and 162 deletions

View File

@@ -1,6 +1,3 @@
import math
import random
from typing import List, Optional, Dict
import torch
import torch.utils.data as data
@@ -43,6 +40,18 @@ class SyntheticDEnv(data.Dataset):
self._oracle_map = oracle_map
self._num_per_task = num_per_task
self._noise = noise
self._meta_info = dict()
def set_regression(self):
self._meta_info["task"] = "regression"
def set_classification(self, num_classes):
self._meta_info["task"] = "classification"
self._meta_info["num_classes"] = int(num_classes)
@property
def meta_info(self):
return self._meta_info
@property
def min_timestamp(self):
@@ -60,13 +69,6 @@ class SyntheticDEnv(data.Dataset):
def mode(self):
return self._time_generator.mode
def random_timestamp(self, min_timestamp=None, max_timestamp=None):
if min_timestamp is None:
min_timestamp = self.min_timestamp
if max_timestamp is None:
max_timestamp = self.max_timestamp
return random.random() * (max_timestamp - min_timestamp) + min_timestamp
def get_timestamp(self, index):
if index is None:
timestamps = []