From ad34af99135bad2d92942ff688d6b3766e653f80 Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Thu, 9 Jan 2020 22:26:23 +1100
Subject: [PATCH] update code styles

---
 BASELINE.md                                   | 66 +++++++++----------
 CONTRIBUTING.md                               |  2 +-
 NAS-Bench-102.md                              | 10 +--
 exps/NAS-Bench-102/visualize.py               |  7 +-
 exps/algos/BOHB.py                            |  1 -
 exps/algos/DARTS-V1.py                        | 20 ++++--
 exps/algos/DARTS-V2.py                        |  7 +-
 exps/algos/GDAS.py                            |  4 +-
 exps/vis/test.py                              | 37 ++++++++++-
 lib/datasets/SearchDatasetWrap.py             | 32 ++++++---
 lib/datasets/landmark_utils/point_meta.py     |  2 +-
 lib/models/cell_infers/cells.py               |  1 -
 lib/models/cell_searchs/genotypes.py          |  2 +-
 lib/models/cell_searchs/search_cells.py       |  2 +-
 .../shape_infers/InferCifarResNet_width.py    |  2 +-
 lib/nas_102_api/api.py                        |  3 +
 lib/procedures/starts.py                      |  2 +-
 lib/utils/__init__.py                         |  2 +
 lib/{xvision => utils}/affine_utils.py        |  7 --
 lib/utils/flop_benchmark.py                   |  2 +-
 lib/utils/gpu_manager.py                      |  2 +-
 lib/utils/nas_utils.py                        | 52 +++++++++++++++
 lib/xvision/__init__.py                       |  1 -
 scripts-search/algos/DARTS-V2.sh              |  1 +
 scripts/prepare.sh                            |  2 +-
 scripts/tas-infer-train.sh                    |  4 +-
 26 files changed, 192 insertions(+), 81 deletions(-)
 rename lib/{xvision => utils}/affine_utils.py (95%)
 create mode 100644 lib/utils/nas_utils.py
 delete mode 100644 lib/xvision/__init__.py

diff --git a/BASELINE.md b/BASELINE.md
index d7d923a..7fd5162 100644
--- a/BASELINE.md
+++ b/BASELINE.md
@@ -40,39 +40,39 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_
 
 ## Performance on ImageNet
 
-|      Model     | FLOPs (GB) | Params (M) | Top-1 Error | Top-5 Error |  Optimizer |
-|:--------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:|
-| ResNet-18      | 1.814      |  11.69     |   30.24     |   10.92     | Official   |
-| ResNet-18      | 1.814      |  11.69     |   29.97     |   10.43     | Step-120   |
-| ResNet-18      | 1.814      |  11.69     |   29.35     |   10.13     | Cosine-120 |
-| ResNet-18      | 1.814      |  11.69     |   29.45     |   10.25     | Cosine-120 B1024 |
-| ResNet-18      | 1.814      |  11.69     |   29.44     |   10.12     |Cosine-S-120|
-| ResNet-18 (DS) | 2.053      |  11.71     |   28.53     |   9.69      |Cosine-S-120|
-| ResNet-34      | 3.663      |  21.80     |   25.65     |   8.06      |Cosine-120  |
-| ResNet-34 (DS) | 3.903      |  21.82     |   25.05     |   7.67      |Cosine-S-120|
-| ResNet-50      | 4.089      |  25.56     |   23.85     |   7.13      | Official   |
-| ResNet-50      | 4.089      |  25.56     |   22.54     |   6.45      |Cosine-120  |
-| ResNet-50      | 4.089      |  25.56     |   22.71     |   6.38      |Cosine-120 B1024 |
-| ResNet-50      | 4.089      |  25.56     |   22.34     |   6.22      |Cosine-S-120|
-| ResNet-50 (DS) | 4.328      |  25.58     |   22.67     |   6.39      | Step-120   |
-| ResNet-50 (DS) | 4.328      |  25.58     |   21.94     |   6.23      | Cosine-120 |
-| ResNet-50 (DS) | 4.328      |  25.58     |   21.71     |   5.99      |Cosine-S-120|
-| ResNet-101     | 7.801      |  44.55     |   20.93     |   5.57      |Cosine-120  |
-| ResNet-101     | 7.801      |  44.55     |   20.92     |   5.58      |Cosine-120 B1024 |
-| ResNet-101 (DS)| 8.041      |  44.57     |   20.36     |   5.22      |Cosine-S-120|
-| ResNet-152     | 11.514     |  60.19     |   20.10     |   5.17      |Cosine-120 B1024 |
-| ResNet-152 (DS)| 11.753     |  60.21     |   19.83     |   5.02      |Cosine-S-120|
-| ResNet-200     | 15.007     |  64.67     |   20.06     |   4.98      |Cosine-S-120|
-| Next50-32x4d (DS)| 4.2      |  25.0      |   22.2      |     -       | Official   |
-| Next50-32x4d (DS)| 4.470    |  25.05     |   21.16     |   5.65      |Cosine-S-120|
-| MobileNet-V2   | 0.300      |  3.40      |   28.0      |     -       | Official   |
-| MobileNet-V2   | 0.300      |  3.50      |   27.92     |   9.50      | MobileFast |
-| MobileNet-V2   | 0.300      |  3.50      |   27.56     |   9.26      | MobileFast-Smooth |
-| ShuffleNet-V2 1.0| 0.146    |  2.28      |   30.6      |   11.1      | Official   |
-| ShuffleNet-V2 1.0| 0.145    |  2.28      |             |             |Cosine-S-120|
-| ShuffleNet-V2 1.5| 0.299    |            |   27.4      |     -       | Official   |
-| ShuffleNet-V2 1.5|          |            |             |             |Cosine-S-120|
-| ShuffleNet-V2 2.0|          |            |             |             |Cosine-S-120|
+|        Model      | FLOPs (GB) | Params (M) | Top-1 Error | Top-5 Error |  Optimizer |
+|:-----------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:|
+| ResNet-18         | 1.814      |  11.69     |   30.24     |   10.92     | Official   |
+| ResNet-18         | 1.814      |  11.69     |   29.97     |   10.43     | Step-120   |
+| ResNet-18         | 1.814      |  11.69     |   29.35     |   10.13     | Cosine-120 |
+| ResNet-18         | 1.814      |  11.69     |   29.45     |   10.25     | Cosine-120 B1024 |
+| ResNet-18         | 1.814      |  11.69     |   29.44     |   10.12     | Cosine-S-120 |
+| ResNet-18 (DS)    | 2.053      |  11.71     |   28.53     |   9.69      | Cosine-S-120 |
+| ResNet-34         | 3.663      |  21.80     |   25.65     |   8.06      | Cosine-120   |
+| ResNet-34 (DS)    | 3.903      |  21.82     |   25.05     |   7.67      | Cosine-S-120 |
+| ResNet-50         | 4.089      |  25.56     |   23.85     |   7.13      | Official     |
+| ResNet-50         | 4.089      |  25.56     |   22.54     |   6.45      | Cosine-120   |
+| ResNet-50         | 4.089      |  25.56     |   22.71     |   6.38      | Cosine-120 B1024 |
+| ResNet-50         | 4.089      |  25.56     |   22.34     |   6.22      | Cosine-S-120 |
+| ResNet-50 (DS)    | 4.328      |  25.58     |   22.67     |   6.39      | Step-120     |
+| ResNet-50 (DS)    | 4.328      |  25.58     |   21.94     |   6.23      | Cosine-120   |
+| ResNet-50 (DS)    | 4.328      |  25.58     |   21.71     |   5.99      | Cosine-S-120 |
+| ResNet-101        | 7.801      |  44.55     |   20.93     |   5.57      | Cosine-120   |
+| ResNet-101        | 7.801      |  44.55     |   20.92     |   5.58      | Cosine-120 B1024 |
+| ResNet-101 (DS)   | 8.041      |  44.57     |   20.36     |   5.22      | Cosine-S-120 |
+| ResNet-152        | 11.514     |  60.19     |   20.10     |   5.17      | Cosine-120 B1024 |
+| ResNet-152 (DS)   | 11.753     |  60.21     |   19.83     |   5.02      | Cosine-S-120 |
+| ResNet-200        | 15.007     |  64.67     |   20.06     |   4.98      | Cosine-S-120 |
+| Next50-32x4d (DS) | 4.2        |  25.0      |   22.2      |     -       | Official     |
+| Next50-32x4d (DS) | 4.470      |  25.05     |   21.16     |   5.65      | Cosine-S-120 |
+| MobileNet-V2      | 0.300      |  3.40      |   28.0      |     -       | Official     |
+| MobileNet-V2      | 0.300      |  3.50      |   27.92     |   9.50      | MobileFast   |
+| MobileNet-V2      | 0.300      |  3.50      |   27.56     |   9.26      | MobileFast-Smooth |
+| ShuffleNet-V2 1.0 | 0.146      |  2.28      |   30.6      |   11.1      | Official     |
+| ShuffleNet-V2 1.0 | 0.145      |  2.28      |             |             | Cosine-S-120 |
+| ShuffleNet-V2 1.5 | 0.299      |            |   27.4      |     -       | Official     |
+| ShuffleNet-V2 1.5 |            |            |             |             | Cosine-S-120 |
+| ShuffleNet-V2 2.0 |            |            |             |             | Cosine-S-120 |
 
 `DS` indicates deep-stem for the first convolutional layer.
 ```
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c057f42..45211f0 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -4,7 +4,7 @@
 
 The following is a set of guidelines for contributing to NAS-Projects.
 
-#### Table Of Contents
+## Table Of Contents
 
 [How Can I Contribute?](#how-can-i-contribute)
   * [Reporting Bugs](#reporting-bugs)
diff --git a/NAS-Bench-102.md b/NAS-Bench-102.md
index f14ce1e..530fba6 100644
--- a/NAS-Bench-102.md
+++ b/NAS-Bench-102.md
@@ -6,9 +6,9 @@ Each edge here is associated with an operation selected from a predefined operat
 For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
 
 In this Markdown file, we provide:
-- [How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
-- [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
-- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
+-	[How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
+-	[Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
+-	[10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
 
 Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
 
@@ -140,6 +140,8 @@ This command will train 390 architectures (id from 0 to 389) using the following
 | CIFAR-100       | train         | valid / test |
 | ImageNet-16-120 | train         | valid / test |
 
+Note that the above `train`, `valid`, and `test` indicate the proposed splits in our NAS-Bench-102, and they might be different with the original splits.
+
 3. calculate the latency, merge the results of all architectures, and simplify the results.
 (see commands in `output/NAS-BENCH-102-4/meta-node-4.cal-script.txt` which is automatically generated by step-1).
 ```
@@ -167,7 +169,7 @@ If researchers can provide better results with different hyper-parameters, we ar
 
 **Note that** you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download)
 
-- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`
+- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`, where `cifar10` can be replaced with `cifar100` or `ImageNet16-120`.
 - [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
 - [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh     cifar10 -1`
 - [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh     cifar10 -1`
diff --git a/exps/NAS-Bench-102/visualize.py b/exps/NAS-Bench-102/visualize.py
index a41b53c..97be2f4 100644
--- a/exps/NAS-Bench-102/visualize.py
+++ b/exps/NAS-Bench-102/visualize.py
@@ -8,7 +8,6 @@ from tqdm import tqdm
 from collections import OrderedDict
 import numpy as np
 import torch
-import torch.nn as nn
 from pathlib import Path
 from collections import defaultdict
 import matplotlib
@@ -498,6 +497,8 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
 
   def get_accs(xdata):
     epochs, xresults = xdata['epoch'], []
+    metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
+    xresults.append( metrics['accuracy'] )
     for iepoch in range(epochs):
       genotype = xdata['genotypes'][iepoch]
       index = api.query_index_by_arch(genotype)
@@ -547,7 +548,6 @@ if __name__ == '__main__':
   #visualize_relative_ranking(vis_save_dir)
 
   api = API(args.api_path)
-  """
   for x_maxs in [50, 250]:
     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
@@ -555,11 +555,12 @@ if __name__ == '__main__':
     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
-  just_show(api)
   """
+  just_show(api)
   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3))
   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3))
+  """
diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py
index 244556e..4fc30b4 100644
--- a/exps/algos/BOHB.py
+++ b/exps/algos/BOHB.py
@@ -10,7 +10,6 @@ from copy import deepcopy
 from pathlib import Path
 import torch
 import torch.nn as nn
-from torch.distributions import Categorical
 lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
 from config_utils import load_config, dict2config, configure2str
diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py
index da3769a..11516e9 100644
--- a/exps/algos/DARTS-V1.py
+++ b/exps/algos/DARTS-V1.py
@@ -121,9 +121,19 @@ def main(xargs):
     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
   elif xargs.dataset == 'cifar100':
-    raise ValueError('not support yet : {:}'.format(xargs.dataset))
-  elif xargs.dataset.startswith('ImageNet16'):
-    raise ValueError('not support yet : {:}'.format(xargs.dataset))
+    cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
+    search_train_data = train_data
+    search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
+    search_data   = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
+    search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
+  elif xargs.dataset == 'ImageNet16-120':
+    imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
+    search_train_data = train_data
+    search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
+    search_data   = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
+    search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
   else:
     raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
@@ -168,7 +178,7 @@ def main(xargs):
     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
   else:
     logger.log("=> do not find the last-info file : {:}".format(last_info))
-    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
+    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
 
   # start training
   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
@@ -230,7 +240,7 @@ if __name__ == '__main__':
   parser.add_argument('--data_path',          type=str,   help='Path to dataset')
   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
   # channels and number-of-cells
-  parser.add_argument('--config_path',        type=str,   help='The config paths.')
+  parser.add_argument('--config_path',        type=str,   help='The config path.')
   parser.add_argument('--search_space_name',  type=str,   help='The search space name.')
   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.')
   parser.add_argument('--channel',            type=int,   help='The number of channels.')
diff --git a/exps/algos/DARTS-V2.py b/exps/algos/DARTS-V2.py
index 667de61..6f0a24e 100644
--- a/exps/algos/DARTS-V2.py
+++ b/exps/algos/DARTS-V2.py
@@ -181,8 +181,8 @@ def main(xargs):
     logger.log('Load split file from {:}'.format(split_Fpath))
   else:
     raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
-  config_path = 'configs/nas-benchmark/algos/DARTS.config'
-  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+  #config_path = 'configs/nas-benchmark/algos/DARTS.config'
+  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
   # To split data
   train_data_v2 = deepcopy(train_data)
   train_data_v2.transform = valid_data.transform
@@ -233,7 +233,7 @@ def main(xargs):
     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
   else:
     logger.log("=> do not find the last-info file : {:}".format(last_info))
-    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
+    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
 
   # start training
   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
@@ -297,6 +297,7 @@ if __name__ == '__main__':
   parser.add_argument('--data_path',          type=str,   help='Path to dataset')
   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
   # channels and number-of-cells
+  parser.add_argument('--config_path',        type=str,   help='The config path.')
   parser.add_argument('--search_space_name',  type=str,   help='The search space name.')
   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.')
   parser.add_argument('--channel',            type=int,   help='The number of channels.')
diff --git a/exps/algos/GDAS.py b/exps/algos/GDAS.py
index 84432b8..1209d7c 100644
--- a/exps/algos/GDAS.py
+++ b/exps/algos/GDAS.py
@@ -3,7 +3,7 @@
 ###########################################################################
 # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
 ###########################################################################
-import os, sys, time, glob, random, argparse
+import os, sys, time, random, argparse
 import numpy as np
 from copy import deepcopy
 import torch
@@ -11,7 +11,7 @@ import torch.nn as nn
 from pathlib import Path
 lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
-from config_utils import load_config, dict2config, configure2str
+from config_utils import load_config, dict2config
 from datasets     import get_datasets, SearchDataset
 from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
 from utils        import get_model_infos, obtain_accuracy
diff --git a/exps/vis/test.py b/exps/vis/test.py
index 99cd31a..a8d65aa 100644
--- a/exps/vis/test.py
+++ b/exps/vis/test.py
@@ -1,12 +1,14 @@
 # python ./exps/vis/test.py
 import os, sys, random
 from pathlib import Path
+from copy import deepcopy
 import torch
 import numpy as np
 from collections import OrderedDict
 lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
 
+from nas_102_api import NASBench102API as API
 
 def test_nas_api():
   from nas_102_api import ArchResults
@@ -72,7 +74,40 @@ def test_auto_grad():
     s_grads = torch.autograd.grad(grads, net.parameters())
     second_order_grads.append( s_grads )
 
+
+def test_one_shot_model(ckpath, use_train):
+  from models import get_cell_based_tiny_net, get_search_spaces
+  from datasets import get_datasets, SearchDataset
+  from config_utils import load_config, dict2config
+  from utils.nas_utils import evaluate_one_shot
+  use_train = int(use_train) > 0
+  #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
+  #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
+  print ('ckpath : {:}'.format(ckpath))
+  ckp = torch.load(ckpath)
+  xargs = ckp['args']
+  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
+  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
+  if xargs.dataset == 'cifar10':
+    cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
+    xvalid_data = deepcopy(train_data)
+    xvalid_data.transform = valid_data.transform
+    valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
+  else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet))
+  search_space = get_search_spaces('cell', xargs.search_space_name)
+  model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
+                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
+                              'space'    : search_space,
+                              'affine'   : False, 'track_running_stats': True}, None)
+  search_model = get_cell_based_tiny_net(model_config)
+  search_model.load_state_dict( ckp['search_model'] )
+  search_model = search_model.cuda()
+  api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth')
+  archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
+
+
 if __name__ == '__main__':
   #test_nas_api()
   #for i in range(200): plot('{:04d}'.format(i))
-  test_auto_grad()
+  #test_auto_grad()
+  test_one_shot_model(sys.argv[1], sys.argv[2])
diff --git a/lib/datasets/SearchDatasetWrap.py b/lib/datasets/SearchDatasetWrap.py
index d20dd1f..06c5191 100644
--- a/lib/datasets/SearchDatasetWrap.py
+++ b/lib/datasets/SearchDatasetWrap.py
@@ -9,16 +9,25 @@ class SearchDataset(data.Dataset):
 
   def __init__(self, name, data, train_split, valid_split, check=True):
     self.datasetname = name
-    self.data        = data
-    self.train_split = train_split.copy()
-    self.valid_split = valid_split.copy()
-    if check:
-      intersection = set(train_split).intersection(set(valid_split))
-      assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
+    if isinstance(data, (list, tuple)): # new type of SearchDataset
+      assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
+      self.train_data  = data[0]
+      self.valid_data  = data[1]
+      self.train_split = train_split.copy()
+      self.valid_split = valid_split.copy()
+      self.mode_str    = 'V2' # new mode 
+    else:
+      self.mode_str    = 'V1' # old mode 
+      self.data        = data
+      self.train_split = train_split.copy()
+      self.valid_split = valid_split.copy()
+      if check:
+        intersection = set(train_split).intersection(set(valid_split))
+        assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
     self.length      = len(self.train_split)
 
   def __repr__(self):
-    return ('{name}(name={datasetname}, train={tr_L}, valid={val_L})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split)))
+    return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
 
   def __len__(self):
     return self.length
@@ -27,6 +36,11 @@ class SearchDataset(data.Dataset):
     assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
     train_index = self.train_split[index]
     valid_index = random.choice( self.valid_split )
-    train_image, train_label = self.data[train_index]
-    valid_image, valid_label = self.data[valid_index]
+    if self.mode_str == 'V1':
+      train_image, train_label = self.data[train_index]
+      valid_image, valid_label = self.data[valid_index]
+    elif self.mode_str == 'V2':
+      train_image, train_label = self.train_data[train_index]
+      valid_image, valid_label = self.valid_data[valid_index]
+    else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
     return train_image, train_label, valid_image, valid_label
diff --git a/lib/datasets/landmark_utils/point_meta.py b/lib/datasets/landmark_utils/point_meta.py
index 3e24638..2091970 100644
--- a/lib/datasets/landmark_utils/point_meta.py
+++ b/lib/datasets/landmark_utils/point_meta.py
@@ -34,7 +34,7 @@ class PointMeta():
 
   def get_box(self, return_diagonal=False):
     if self.box is None: return None
-    if return_diagonal == False:
+    if not return_diagonal:
       return self.box.clone()
     else:
       W = (self.box[2]-self.box[0]).item()
diff --git a/lib/models/cell_infers/cells.py b/lib/models/cell_infers/cells.py
index ae26a79..d881cba 100644
--- a/lib/models/cell_infers/cells.py
+++ b/lib/models/cell_infers/cells.py
@@ -1,4 +1,3 @@
-import torch
 import torch.nn as nn
 from copy import deepcopy
 from ..cell_operations import OPS
diff --git a/lib/models/cell_searchs/genotypes.py b/lib/models/cell_searchs/genotypes.py
index f18714f..5ccc283 100644
--- a/lib/models/cell_searchs/genotypes.py
+++ b/lib/models/cell_searchs/genotypes.py
@@ -68,7 +68,7 @@ class Structure:
     for i, node_info in enumerate(self.nodes):
       sums = []
       for op, xin in node_info:
-        if op == 'none' or nodes[xin] == False: x = False
+        if op == 'none' or nodes[xin] is False: x = False
         else: x = True
         sums.append( x )
       nodes[i+1] = sum(sums) > 0
diff --git a/lib/models/cell_searchs/search_cells.py b/lib/models/cell_searchs/search_cells.py
index 121322e..a38486d 100644
--- a/lib/models/cell_searchs/search_cells.py
+++ b/lib/models/cell_searchs/search_cells.py
@@ -85,7 +85,7 @@ class SearchCell(nn.Module):
           candidates = self.edges[node_str]
           select_op  = random.choice(candidates)
           sops.append( select_op )
-          if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True
+          if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
         if has_non_zero: break
       inter_nodes = []
       for j, select_op in enumerate(sops):
diff --git a/lib/models/shape_infers/InferCifarResNet_width.py b/lib/models/shape_infers/InferCifarResNet_width.py
index f1539da..06dbfa6 100644
--- a/lib/models/shape_infers/InferCifarResNet_width.py
+++ b/lib/models/shape_infers/InferCifarResNet_width.py
@@ -1,4 +1,4 @@
-import math, torch
+import math
 import torch.nn as nn
 import torch.nn.functional as F
 from ..initialization import initialize_resnet
diff --git a/lib/nas_102_api/api.py b/lib/nas_102_api/api.py
index cf11df4..9b8617a 100644
--- a/lib/nas_102_api/api.py
+++ b/lib/nas_102_api/api.py
@@ -70,6 +70,9 @@ class NASBench102API(object):
   def __repr__(self):
     return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
 
+  def random(self):
+    return random.randint(0, len(self.meta_archs)-1)
+
   def query_index_by_arch(self, arch):
     if isinstance(arch, str):
       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
diff --git a/lib/procedures/starts.py b/lib/procedures/starts.py
index 3202b3f..b1b19d3 100644
--- a/lib/procedures/starts.py
+++ b/lib/procedures/starts.py
@@ -1,7 +1,7 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
-import os, sys, time, torch, random, PIL, copy, numpy as np
+import os, sys, torch, random, PIL, copy, numpy as np
 from os import path as osp
 from shutil  import copyfile
 
diff --git a/lib/utils/__init__.py b/lib/utils/__init__.py
index ff4419f..04c3bb6 100644
--- a/lib/utils/__init__.py
+++ b/lib/utils/__init__.py
@@ -1,3 +1,5 @@
 from .evaluation_utils import obtain_accuracy
 from .gpu_manager      import GPUManager
 from .flop_benchmark   import get_model_infos
+from .affine_utils     import normalize_points, denormalize_points
+from .affine_utils     import identity2affine, solve2theta, affine2image
diff --git a/lib/xvision/affine_utils.py b/lib/utils/affine_utils.py
similarity index 95%
rename from lib/xvision/affine_utils.py
rename to lib/utils/affine_utils.py
index a4e1a83..b122cb4 100644
--- a/lib/xvision/affine_utils.py
+++ b/lib/utils/affine_utils.py
@@ -1,10 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-#
-#
 # functions for affine transformation
 import math, torch
 import numpy as np
diff --git a/lib/utils/flop_benchmark.py b/lib/utils/flop_benchmark.py
index 749751e..133cf2c 100644
--- a/lib/utils/flop_benchmark.py
+++ b/lib/utils/flop_benchmark.py
@@ -1,4 +1,4 @@
-import copy, torch
+import torch
 import torch.nn as nn
 import numpy as np
 
diff --git a/lib/utils/gpu_manager.py b/lib/utils/gpu_manager.py
index 8b039de..520cff4 100644
--- a/lib/utils/gpu_manager.py
+++ b/lib/utils/gpu_manager.py
@@ -27,7 +27,7 @@ class GPUManager():
         find = False
         for gpu in all_gpus:
           if gpu['index'] == CUDA_VISIBLE_DEVICE:
-            assert find==False, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
+            assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
             find = True
             selected_gpus.append( gpu.copy() )
             selected_gpus[-1]['index'] = '{}'.format(idx)
diff --git a/lib/utils/nas_utils.py b/lib/utils/nas_utils.py
new file mode 100644
index 0000000..7d3ec49
--- /dev/null
+++ b/lib/utils/nas_utils.py
@@ -0,0 +1,52 @@
+# This file is for experimental usage
+import os, sys, torch, random
+import numpy as np
+from copy import deepcopy
+from tqdm import tqdm
+import torch.nn as nn
+
+from utils  import obtain_accuracy
+from models import CellStructure
+from log_utils import time_string
+
+def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
+  weights = deepcopy(model.state_dict())
+  model.train(cal_mode)
+  with torch.no_grad():
+    logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
+    archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
+    probs, accuracies, gt_accs = [], [], []
+    loader_iter = iter(xloader)
+    random.seed(seed)
+    random.shuffle(archs)
+    for idx, arch in enumerate(archs):
+      arch_index = api.query_index_by_arch( arch )
+      metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
+      gt_accs.append( metrics['valid-accuracy'] )
+      select_logits = []
+      for i, node_info in enumerate(arch.nodes):
+        for op, xin in node_info:
+          node_str = '{:}<-{:}'.format(i+1, xin)
+          op_index = model.op_names.index(op)
+          select_logits.append( logits[model.edge2index[node_str], op_index] )
+      cur_prob = sum(select_logits).item()
+      probs.append( cur_prob )
+    cor_prob = np.corrcoef(probs, gt_accs)[0,1]
+    print ('correlation for probabilities : {:}'.format(cor_prob))
+      
+    for idx, arch in enumerate(archs):
+      model.set_cal_mode('dynamic', arch)
+      try:
+        inputs, targets = next(loader_iter)
+      except:
+        loader_iter = iter(xloader)
+        inputs, targets = next(loader_iter)
+      _, logits = model(inputs.cuda())
+      _, preds  = torch.max(logits, dim=-1)
+      correct = (preds == targets.cuda() ).float()
+      accuracies.append( correct.mean().item() )
+      if idx != 0 and (idx % 300 == 0 or idx + 1 == len(archs) or idx == 10):
+        cor_accs = np.corrcoef(accuracies, gt_accs[:idx+1])[0,1]
+        print ('{:} {:03d}/{:03d} mode={:5s}, correlation : accs={:.4f}, arch={:}'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs, arch))
+  model.load_state_dict(weights)
+  return archs, probs, accuracies
diff --git a/lib/xvision/__init__.py b/lib/xvision/__init__.py
deleted file mode 100644
index 24a3af4..0000000
--- a/lib/xvision/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .affine_utils import normalize_points, denormalize_points
diff --git a/scripts-search/algos/DARTS-V2.sh b/scripts-search/algos/DARTS-V2.sh
index 2d21149..5b81268 100644
--- a/scripts-search/algos/DARTS-V2.sh
+++ b/scripts-search/algos/DARTS-V2.sh
@@ -33,6 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
 	--dataset ${dataset} --data_path ${data_path} \
 	--search_space_name ${space} \
+	--config_path configs/nas-benchmark/algos/DARTS.config \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
 	--workers 4 --print_freq 200 --rand_seed ${seed}
diff --git a/scripts/prepare.sh b/scripts/prepare.sh
index b968b70..3bf7150 100644
--- a/scripts/prepare.sh
+++ b/scripts/prepare.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 # bash ./scripts/prepare.sh
-datasets='cifar10 cifar100 imagenet-1k'
+#datasets='cifar10 cifar100 imagenet-1k'
 #ratios='0.5 0.8 0.9'
 ratios='0.5'
 save_dir=./.latent-data/splits
diff --git a/scripts/tas-infer-train.sh b/scripts/tas-infer-train.sh
index 8c282ea..034e3d3 100644
--- a/scripts/tas-infer-train.sh
+++ b/scripts/tas-infer-train.sh
@@ -33,7 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
 	--procedure    basic \
 	--save_dir     ${xsave_dir} \
 	--cutout_length -1 \
-	--batch_size 256 --rand_seed ${rseed} --workers 6 \
+	--batch_size ${batch} --rand_seed ${rseed} --workers 6 \
 	--eval_frequency 1 --print_freq 100 --print_freq_eval 200
 
 # KD training
@@ -47,5 +47,5 @@ OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \
 	--save_dir     ${xsave_dir} \
 	--KD_alpha 0.9 --KD_temperature 4 \
 	--cutout_length -1 \
-	--batch_size 256 --rand_seed ${rseed} --workers 6 \
+	--batch_size ${batch} --rand_seed ${rseed} --workers 6 \
 	--eval_frequency 1 --print_freq 100 --print_freq_eval 200