Refine Transformer

This commit is contained in:
D-X-Y
2021-07-04 11:59:06 +00:00
parent 9136f33684
commit 11f313288a
10 changed files with 160 additions and 28 deletions

View File

@@ -8,6 +8,9 @@
import os, sys, time, torch
import pickle
import tempfile
from pathlib import Path
root_dir = (Path(__file__).parent / ".." / "..").resolve()
from xautodl.trade_models.quant_transformer import QuantTransformer
@@ -17,7 +20,7 @@ def test_create():
if not torch.cuda.is_available():
return
quant_model = QuantTransformer(GPU=0)
temp_dir = lib_dir / ".." / "tests" / ".pytest_cache"
temp_dir = root_dir / "tests" / ".pytest_cache"
temp_dir.mkdir(parents=True, exist_ok=True)
temp_file = temp_dir / "quant-model.pkl"
with temp_file.open("wb") as f:
@@ -30,7 +33,7 @@ def test_create():
def test_load():
temp_file = lib_dir / ".." / "tests" / ".pytest_cache" / "quant-model.pkl"
temp_file = root_dir / "tests" / ".pytest_cache" / "quant-model.pkl"
with temp_file.open("rb") as f:
model = pickle.load(f)
print(model.model)