Refine lib -> xautodl

This commit is contained in:
D-X-Y
2021-05-19 08:10:42 +00:00
parent bd407ac4dc
commit 1c6c3e7166
12 changed files with 83 additions and 53 deletions

29
tests/test_loader.py Normal file
View File

@@ -0,0 +1,29 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest tests/test_loader.py -s #
#####################################################
import unittest
import tempfile
import torch
from xautodl.datasets import get_datasets
def test_simple():
xdir = tempfile.mkdtemp()
train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1)
print(train_data)
print(valid_data)
xloader = torch.utils.data.DataLoader(
train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True
)
print(xloader)
print(next(iter(xloader)))
for i, data in enumerate(xloader):
print(i)
test_simple()

23
tests/test_tas.py Normal file
View File

@@ -0,0 +1,23 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
from xautodl.models.shape_searchs.SoftSelect import ChannelWiseInter
class TestTASFunc(unittest.TestCase):
"""Test the TAS function."""
def test_channel_interplation(self):
tensors = torch.rand((16, 128, 7, 7))
for oc in range(200, 210):
out_v1 = ChannelWiseInter(tensors, oc, "v1")
out_v2 = ChannelWiseInter(tensors, oc, "v2")
assert (out_v1 == out_v2).any().item() == 1
for oc in range(48, 160):
out_v1 = ChannelWiseInter(tensors, oc, "v1")
out_v2 = ChannelWiseInter(tensors, oc, "v2")
assert (out_v1 == out_v2).any().item() == 1