Refine lib -> xautodl

This commit is contained in:
D-X-Y
2021-05-19 07:19:20 +00:00
parent bda202ce87
commit 5b9a028e60
19 changed files with 46 additions and 46 deletions

View File

@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
from copy import deepcopy
from models.cell_operations import OPS
from xautodl.models.cell_operations import OPS
# Cell for NAS-Bench-201

View File

@@ -4,6 +4,7 @@
import torch
import torch.nn as nn
from copy import deepcopy
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR

View File

@@ -9,11 +9,11 @@ import torch
__all__ = ["get_model"]
from xlayers.super_core import SuperSequential
from xlayers.super_core import SuperLinear
from xlayers.super_core import SuperDropout
from xlayers.super_core import super_name2norm
from xlayers.super_core import super_name2activation
from xautodl.xlayers.super_core import SuperSequential
from xautodl.xlayers.super_core import SuperLinear
from xautodl.xlayers.super_core import SuperDropout
from xautodl.xlayers.super_core import super_name2norm
from xautodl.xlayers.super_core import super_name2activation
def get_model(config: Dict[Text, Any], **kwargs):

View File

@@ -7,8 +7,7 @@ import os, sys, time, torch
from typing import Optional, Text, Callable
# modules in AutoDL
from log_utils import AverageMeter
from log_utils import time_string
from xautodl.log_utils import AverageMeter, time_string
from .eval_funcs import obtain_accuracy

View File

@@ -4,8 +4,7 @@
import os, sys, time, torch
# modules in AutoDL
from log_utils import AverageMeter
from log_utils import time_string
from xautodl.log_utils import AverageMeter, time_string
from .eval_funcs import obtain_accuracy

View File

@@ -15,6 +15,6 @@ def obtain_accuracy(output, target, topk=(1,)):
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res

View File

@@ -4,7 +4,7 @@
import os, time, copy, torch, pathlib
# modules in AutoDL
import xautodl.datasets
from xautodl import datasets
from xautodl.config_utils import load_config
from xautodl.procedures import prepare_seed, get_optim_scheduler
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time

View File

@@ -8,7 +8,6 @@ import pprint
import logging
from copy import deepcopy
from log_utils import pickle_load
import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R

View File

@@ -2,8 +2,9 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from models import change_key
from xautodl.log_utils import AverageMeter, time_string
from xautodl.models import change_key
from .eval_funcs import obtain_accuracy

View File

@@ -4,8 +4,8 @@
import os, sys, time, torch
# modules in AutoDL
from log_utils import AverageMeter, time_string
from models import change_key
from xautodl.log_utils import AverageMeter, time_string
from xautodl.models import change_key
from .eval_funcs import obtain_accuracy

View File

@@ -5,7 +5,7 @@ import os, sys, time, torch
import torch.nn.functional as F
# modules in AutoDL
from log_utils import AverageMeter, time_string
from xautodl.log_utils import AverageMeter, time_string
from .eval_funcs import obtain_accuracy

View File

@@ -16,7 +16,7 @@ def prepare_seed(rand_seed):
def prepare_logger(xargs):
args = copy.deepcopy(xargs)
from log_utils import Logger
from xautodl.log_utils import Logger
logger = Logger(args.save_dir, args.rand_seed)
logger.log("Main Function with logger : {:}".format(logger))