Update yaml configs

This commit is contained in:
D-X-Y
2021-06-10 21:53:22 +08:00
parent 1a7440d2af
commit 9bf0fa5f04
21 changed files with 410 additions and 178 deletions

View File

@@ -3,10 +3,8 @@
#####################################################
# pytest tests/test_basic_space.py -s #
#####################################################
import sys, random
import random
import unittest
import pytest
from pathlib import Path
from xautodl.spaces import Categorical
from xautodl.spaces import Continuous

View File

@@ -3,12 +3,6 @@
#####################################################
# pytest ./tests/test_import.py #
#####################################################
import os, sys, time, torch
import pickle
import tempfile
from pathlib import Path
def test_import():
from xautodl import config_utils
from xautodl import datasets
@@ -19,6 +13,9 @@ def test_import():
from xautodl import spaces
from xautodl import trade_models
from xautodl import utils
from xautodl import xlayers
from xautodl import xmisc
from xautodl import xmmodels
print("Check all imports done")

View File

@@ -3,13 +3,11 @@
#####################################################
# pytest ./tests/test_super_att.py -s #
#####################################################
import sys, random
import random
import unittest
from parameterized import parameterized
from pathlib import Path
import torch
from xautodl import spaces
from xautodl.xlayers import super_core

View File

@@ -3,10 +3,9 @@
#####################################################
# pytest ./tests/test_super_container.py -s #
#####################################################
import sys, random
import random
import unittest
import pytest
from pathlib import Path
import torch
from xautodl import spaces

View File

@@ -3,7 +3,6 @@
#####################################################
# pytest ./tests/test_super_rearrange.py -s #
#####################################################
import sys
import unittest
import torch

View File

@@ -3,8 +3,8 @@
#####################################################
# pytest ./tests/test_super_vit.py -s #
#####################################################
import sys
import unittest
from parameterized import parameterized
import torch
from xautodl.xmodels import transformers
@@ -16,25 +16,28 @@ class TestSuperViT(unittest.TestCase):
def test_super_vit(self):
model = transformers.get_transformer("vit-base-16")
tensor = torch.rand((16, 3, 224, 224))
tensor = torch.rand((2, 3, 224, 224))
print("The tensor shape: {:}".format(tensor.shape))
# print(model)
outs = model(tensor)
print("The output tensor shape: {:}".format(outs.shape))
def test_imagenet(self):
name2config = transformers.name2config
print("There are {:} models in total.".format(len(name2config)))
for name, config in name2config.items():
if "cifar" in name:
tensor = torch.rand((16, 3, 32, 32))
else:
tensor = torch.rand((16, 3, 224, 224))
model = transformers.get_transformer(config)
outs = model(tensor)
size = count_parameters(model, "mb", True)
print(
"{:10s} : size={:.2f}MB, out-shape: {:}".format(
name, size, tuple(outs.shape)
)
@parameterized.expand(
[
["vit-cifar10-p4-d4-h4-c32", 32],
["vit-base-16", 224],
["vit-large-16", 224],
["vit-huge-14", 224],
]
)
def test_imagenet(self, name, resolution):
tensor = torch.rand((2, 3, resolution, resolution))
config = transformers.name2config[name]
model = transformers.get_transformer(config)
outs = model(tensor)
size = count_parameters(model, "mb", True)
print(
"{:10s} : size={:.2f}MB, out-shape: {:}".format(
name, size, tuple(outs.shape)
)
)