Complete xlayers.rearrange

This commit is contained in:
D-X-Y
2021-06-08 23:47:52 -07:00
parent f9bbf974de
commit 744ce97bc5
10 changed files with 218 additions and 96 deletions

View File

@@ -21,6 +21,8 @@ from .super_utils import ShapeContainer
BEST_DIR_KEY = "best_model_dir"
BEST_NAME_KEY = "best_model_name"
BEST_SCORE_KEY = "best_model_score"
ENABLE_CANDIDATE = 0
DISABLE_CANDIDATE = 1
class SuperModule(abc.ABC, nn.Module):
@@ -32,6 +34,7 @@ class SuperModule(abc.ABC, nn.Module):
self._abstract_child = None
self._verbose = False
self._meta_info = {}
self._candidate_mode = DISABLE_CANDIDATE
def set_super_run_type(self, super_run_type):
def _reset_super_run(m):
@@ -65,6 +68,20 @@ class SuperModule(abc.ABC, nn.Module):
)
self._abstract_child = abstract_child
def enable_candidate(self):
def _enable_candidate(m):
if isinstance(m, SuperModule):
m._candidate_mode = ENABLE_CANDIDATE
self.apply(_enable_candidate)
def disable_candidate(self):
def _disable_candidate(m):
if isinstance(m, SuperModule):
m._candidate_mode = DISABLE_CANDIDATE
self.apply(_disable_candidate)
def get_w_container(self):
container = TensorContainer()
for name, param in self.named_parameters():
@@ -191,9 +208,11 @@ class SuperModule(abc.ABC, nn.Module):
if self.super_run_type == SuperRunMode.FullModel:
outputs = self.forward_raw(*inputs)
elif self.super_run_type == SuperRunMode.Candidate:
if self._candidate_mode == DISABLE_CANDIDATE:
raise ValueError("candidate mode is disabled")
outputs = self.forward_candidate(*inputs)
else:
raise ModeError(
raise ValueError(
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
)
if self.verbose: