Refine lib -> xautodl
This commit is contained in:
@@ -6,7 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from models.cell_operations import OPS
|
||||
from xautodl.models.cell_operations import OPS
|
||||
|
||||
|
||||
# Cell for NAS-Bench-201
|
||||
|
@@ -4,6 +4,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
|
||||
|
||||
|
||||
|
@@ -9,11 +9,11 @@ import torch
|
||||
__all__ = ["get_model"]
|
||||
|
||||
|
||||
from xlayers.super_core import SuperSequential
|
||||
from xlayers.super_core import SuperLinear
|
||||
from xlayers.super_core import SuperDropout
|
||||
from xlayers.super_core import super_name2norm
|
||||
from xlayers.super_core import super_name2activation
|
||||
from xautodl.xlayers.super_core import SuperSequential
|
||||
from xautodl.xlayers.super_core import SuperLinear
|
||||
from xautodl.xlayers.super_core import SuperDropout
|
||||
from xautodl.xlayers.super_core import super_name2norm
|
||||
from xautodl.xlayers.super_core import super_name2activation
|
||||
|
||||
|
||||
def get_model(config: Dict[Text, Any], **kwargs):
|
||||
|
@@ -7,8 +7,7 @@ import os, sys, time, torch
|
||||
from typing import Optional, Text, Callable
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
from log_utils import time_string
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
|
@@ -4,8 +4,7 @@
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
from log_utils import time_string
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
|
@@ -15,6 +15,6 @@ def obtain_accuracy(output, target, topk=(1,)):
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
@@ -4,7 +4,7 @@
|
||||
import os, time, copy, torch, pathlib
|
||||
|
||||
# modules in AutoDL
|
||||
import xautodl.datasets
|
||||
from xautodl import datasets
|
||||
from xautodl.config_utils import load_config
|
||||
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
|
@@ -8,7 +8,6 @@ import pprint
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
from log_utils import pickle_load
|
||||
import qlib
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
|
@@ -2,8 +2,9 @@
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from models import change_key
|
||||
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from xautodl.models import change_key
|
||||
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
@@ -4,8 +4,8 @@
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter, time_string
|
||||
from models import change_key
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from xautodl.models import change_key
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
|
@@ -5,7 +5,7 @@ import os, sys, time, torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter, time_string
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
|
@@ -16,7 +16,7 @@ def prepare_seed(rand_seed):
|
||||
|
||||
def prepare_logger(xargs):
|
||||
args = copy.deepcopy(xargs)
|
||||
from log_utils import Logger
|
||||
from xautodl.log_utils import Logger
|
||||
|
||||
logger = Logger(args.save_dir, args.rand_seed)
|
||||
logger.log("Main Function with logger : {:}".format(logger))
|
||||
|
Reference in New Issue
Block a user