Refine lib -> xautodl
This commit is contained in:
29
tests/test_loader.py
Normal file
29
tests/test_loader.py
Normal file
@@ -0,0 +1,29 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
# pytest tests/test_loader.py -s #
|
||||
#####################################################
|
||||
import unittest
|
||||
import tempfile
|
||||
import torch
|
||||
|
||||
from xautodl.datasets import get_datasets
|
||||
|
||||
|
||||
def test_simple():
|
||||
xdir = tempfile.mkdtemp()
|
||||
train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1)
|
||||
print(train_data)
|
||||
print(valid_data)
|
||||
|
||||
xloader = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True
|
||||
)
|
||||
print(xloader)
|
||||
print(next(iter(xloader)))
|
||||
|
||||
for i, data in enumerate(xloader):
|
||||
print(i)
|
||||
|
||||
|
||||
test_simple()
|
23
tests/test_tas.py
Normal file
23
tests/test_tas.py
Normal file
@@ -0,0 +1,23 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xautodl.models.shape_searchs.SoftSelect import ChannelWiseInter
|
||||
|
||||
|
||||
class TestTASFunc(unittest.TestCase):
|
||||
"""Test the TAS function."""
|
||||
|
||||
def test_channel_interplation(self):
|
||||
tensors = torch.rand((16, 128, 7, 7))
|
||||
|
||||
for oc in range(200, 210):
|
||||
out_v1 = ChannelWiseInter(tensors, oc, "v1")
|
||||
out_v2 = ChannelWiseInter(tensors, oc, "v2")
|
||||
assert (out_v1 == out_v2).any().item() == 1
|
||||
for oc in range(48, 160):
|
||||
out_v1 = ChannelWiseInter(tensors, oc, "v1")
|
||||
out_v2 = ChannelWiseInter(tensors, oc, "v2")
|
||||
assert (out_v1 == out_v2).any().item() == 1
|
Reference in New Issue
Block a user