Re-org GeMOSA codes
This commit is contained in:
@@ -1,6 +1,3 @@
|
||||
import math
|
||||
import random
|
||||
from typing import List, Optional, Dict
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
@@ -43,6 +40,18 @@ class SyntheticDEnv(data.Dataset):
|
||||
self._oracle_map = oracle_map
|
||||
self._num_per_task = num_per_task
|
||||
self._noise = noise
|
||||
self._meta_info = dict()
|
||||
|
||||
def set_regression(self):
|
||||
self._meta_info["task"] = "regression"
|
||||
|
||||
def set_classification(self, num_classes):
|
||||
self._meta_info["task"] = "classification"
|
||||
self._meta_info["num_classes"] = int(num_classes)
|
||||
|
||||
@property
|
||||
def meta_info(self):
|
||||
return self._meta_info
|
||||
|
||||
@property
|
||||
def min_timestamp(self):
|
||||
@@ -60,13 +69,6 @@ class SyntheticDEnv(data.Dataset):
|
||||
def mode(self):
|
||||
return self._time_generator.mode
|
||||
|
||||
def random_timestamp(self, min_timestamp=None, max_timestamp=None):
|
||||
if min_timestamp is None:
|
||||
min_timestamp = self.min_timestamp
|
||||
if max_timestamp is None:
|
||||
max_timestamp = self.max_timestamp
|
||||
return random.random() * (max_timestamp - min_timestamp) + min_timestamp
|
||||
|
||||
def get_timestamp(self, index):
|
||||
if index is None:
|
||||
timestamps = []
|
||||
|
Reference in New Issue
Block a user