Updates
This commit is contained in:
@@ -19,6 +19,7 @@ from .super_utils import TensorContainer
|
||||
from .super_utils import ShapeContainer
|
||||
|
||||
BEST_DIR_KEY = "best_model_dir"
|
||||
BEST_NAME_KEY = "best_model_name"
|
||||
BEST_SCORE_KEY = "best_model_score"
|
||||
|
||||
|
||||
@@ -94,6 +95,9 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
self._meta_info[BEST_DIR_KEY] = str(xdir)
|
||||
Path(xdir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def set_best_name(self, xname):
|
||||
self._meta_info[BEST_NAME_KEY] = str(xname)
|
||||
|
||||
def save_best(self, score):
|
||||
if BEST_DIR_KEY not in self._meta_info:
|
||||
tempdir = tempfile.mkdtemp("-xlayers")
|
||||
@@ -102,10 +106,11 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
self._meta_info[BEST_SCORE_KEY] = None
|
||||
best_score = self._meta_info[BEST_SCORE_KEY]
|
||||
if best_score is None or best_score <= score:
|
||||
best_save_path = os.path.join(
|
||||
self._meta_info[BEST_DIR_KEY],
|
||||
"best-{:}.pth".format(self.__class__.__name__),
|
||||
best_save_name = self._meta_info.get(
|
||||
BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
|
||||
)
|
||||
|
||||
best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
|
||||
self._meta_info[BEST_SCORE_KEY] = score
|
||||
torch.save(self.state_dict(), best_save_path)
|
||||
return True, self._meta_info[BEST_SCORE_KEY]
|
||||
|
Reference in New Issue
Block a user