Fix bugs
This commit is contained in:
@@ -63,7 +63,7 @@ class SyntheticDEnv(data.Dataset):
|
||||
dataset = np.random.multivariate_normal(
|
||||
mean_list, cov_matrix, size=self._num_per_task
|
||||
)
|
||||
return index, torch.Tensor(dataset)
|
||||
return timestamp, torch.Tensor(dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._timestamp_generator)
|
||||
|
@@ -8,9 +8,10 @@ from .synthetic_env import SyntheticDEnv
|
||||
|
||||
|
||||
def create_example_v1(
|
||||
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||
timestamp_config=None,
|
||||
num_per_task=5000,
|
||||
):
|
||||
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
||||
|
||||
|
@@ -1,8 +1,10 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
||||
#####################################################
|
||||
# To be finished.
|
||||
#
|
||||
import os, sys, time, torch
|
||||
from typing import import Optional, Text, Callable
|
||||
from typing import Optional, Text, Callable
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
@@ -60,9 +62,10 @@ def procedure(
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
eval_metric,
|
||||
mode: Text,
|
||||
print_freq: int = 100,
|
||||
logger_fn: Callable = None
|
||||
logger_fn: Callable = None,
|
||||
):
|
||||
data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode.lower() == "train":
|
||||
@@ -90,7 +93,7 @@ def procedure(
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
metrics =
|
||||
metrics = eval_metric(logits.data, targets.data)
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#####################################################
|
||||
import abc
|
||||
|
||||
|
||||
def obtain_accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
@@ -20,7 +21,6 @@ def obtain_accuracy(output, target, topk=(1,)):
|
||||
|
||||
|
||||
class EvaluationMetric(abc.ABC):
|
||||
|
||||
def __init__(self):
|
||||
self._total_metrics = 0
|
||||
|
||||
|
Reference in New Issue
Block a user