Refine Transformer
This commit is contained in:
@@ -8,6 +8,9 @@
|
||||
import os, sys, time, torch
|
||||
import pickle
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
root_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||
|
||||
from xautodl.trade_models.quant_transformer import QuantTransformer
|
||||
|
||||
@@ -17,7 +20,7 @@ def test_create():
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
quant_model = QuantTransformer(GPU=0)
|
||||
temp_dir = lib_dir / ".." / "tests" / ".pytest_cache"
|
||||
temp_dir = root_dir / "tests" / ".pytest_cache"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_file = temp_dir / "quant-model.pkl"
|
||||
with temp_file.open("wb") as f:
|
||||
@@ -30,7 +33,7 @@ def test_create():
|
||||
|
||||
|
||||
def test_load():
|
||||
temp_file = lib_dir / ".." / "tests" / ".pytest_cache" / "quant-model.pkl"
|
||||
temp_file = root_dir / "tests" / ".pytest_cache" / "quant-model.pkl"
|
||||
with temp_file.open("rb") as f:
|
||||
model = pickle.load(f)
|
||||
print(model.model)
|
||||
|
Reference in New Issue
Block a user