Complete xlayers.rearrange
This commit is contained in:
@@ -21,6 +21,8 @@ from .super_utils import ShapeContainer
|
||||
BEST_DIR_KEY = "best_model_dir"
|
||||
BEST_NAME_KEY = "best_model_name"
|
||||
BEST_SCORE_KEY = "best_model_score"
|
||||
ENABLE_CANDIDATE = 0
|
||||
DISABLE_CANDIDATE = 1
|
||||
|
||||
|
||||
class SuperModule(abc.ABC, nn.Module):
|
||||
@@ -32,6 +34,7 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
self._abstract_child = None
|
||||
self._verbose = False
|
||||
self._meta_info = {}
|
||||
self._candidate_mode = DISABLE_CANDIDATE
|
||||
|
||||
def set_super_run_type(self, super_run_type):
|
||||
def _reset_super_run(m):
|
||||
@@ -65,6 +68,20 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
)
|
||||
self._abstract_child = abstract_child
|
||||
|
||||
def enable_candidate(self):
|
||||
def _enable_candidate(m):
|
||||
if isinstance(m, SuperModule):
|
||||
m._candidate_mode = ENABLE_CANDIDATE
|
||||
|
||||
self.apply(_enable_candidate)
|
||||
|
||||
def disable_candidate(self):
|
||||
def _disable_candidate(m):
|
||||
if isinstance(m, SuperModule):
|
||||
m._candidate_mode = DISABLE_CANDIDATE
|
||||
|
||||
self.apply(_disable_candidate)
|
||||
|
||||
def get_w_container(self):
|
||||
container = TensorContainer()
|
||||
for name, param in self.named_parameters():
|
||||
@@ -191,9 +208,11 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
if self.super_run_type == SuperRunMode.FullModel:
|
||||
outputs = self.forward_raw(*inputs)
|
||||
elif self.super_run_type == SuperRunMode.Candidate:
|
||||
if self._candidate_mode == DISABLE_CANDIDATE:
|
||||
raise ValueError("candidate mode is disabled")
|
||||
outputs = self.forward_candidate(*inputs)
|
||||
else:
|
||||
raise ModeError(
|
||||
raise ValueError(
|
||||
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
|
||||
)
|
||||
if self.verbose:
|
||||
|
Reference in New Issue
Block a user