Add visualize codes for Q

This commit is contained in:
D-X-Y
2021-04-11 21:45:20 +08:00
parent e777f38233
commit 0e2dd13762
16 changed files with 570 additions and 125 deletions

View File

@@ -21,7 +21,11 @@ def get_conv2D_Wmats(tensor: np.ndarray) -> List[np.ndarray]:
"""
mats = []
N, M, imax, jmax = tensor.shape
assert N + M >= imax + jmax, "invalid tensor shape detected: {}x{} (NxM), {}x{} (i,j)".format(N, M, imax, jmax)
assert (
N + M >= imax + jmax
), "invalid tensor shape detected: {}x{} (NxM), {}x{} (i,j)".format(
N, M, imax, jmax
)
for i in range(imax):
for j in range(jmax):
w = tensor[:, :, i, j]
@@ -58,7 +62,17 @@ def glorot_norm_fix(w, n, m, rf_size):
return w
def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix):
def analyze_weights(
weights,
min_size,
max_size,
alphas,
lognorms,
spectralnorms,
softranks,
normalize,
glorot_fix,
):
results = OrderedDict()
count = len(weights)
if count == 0:
@@ -94,12 +108,16 @@ def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms
lambda0 = None
if M < min_size:
summary = "Weight matrix {}/{} ({},{}): Skipping: too small (<{})".format(i + 1, count, M, N, min_size)
summary = "Weight matrix {}/{} ({},{}): Skipping: too small (<{})".format(
i + 1, count, M, N, min_size
)
cur_res["summary"] = summary
continue
elif max_size > 0 and M > max_size:
summary = "Weight matrix {}/{} ({},{}): Skipping: too big (testing) (>{})".format(
i + 1, count, M, N, max_size
summary = (
"Weight matrix {}/{} ({},{}): Skipping: too big (testing) (>{})".format(
i + 1, count, M, N, max_size
)
)
cur_res["summary"] = summary
continue
@@ -153,7 +171,9 @@ def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms
cur_res["lognormX"] = lognormX
summary.append(
"Weight matrix {}/{} ({},{}): LogNorm: {} ; LogNormX: {}".format(i + 1, count, M, N, lognorm, lognormX)
"Weight matrix {}/{} ({},{}): LogNorm: {} ; LogNormX: {}".format(
i + 1, count, M, N, lognorm, lognormX
)
)
if softranks:
@@ -163,8 +183,10 @@ def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms
cur_res["softrank"] = softrank
cur_res["softranklog"] = softranklog
cur_res["softranklogratio"] = softranklogratio
summary += "{}. Softrank: {}. Softrank log: {}. Softrank log ratio: {}".format(
summary, softrank, softranklog, softranklogratio
summary += (
"{}. Softrank: {}. Softrank log: {}. Softrank log ratio: {}".format(
summary, softrank, softranklog, softranklogratio
)
)
cur_res["summary"] = "\n".join(summary)
return results
@@ -209,7 +231,17 @@ def compute_details(results):
metrics_stats.append("{}_compound_avg".format(metric))
columns = (
["layer_id", "layer_type", "N", "M", "layer_count", "slice", "slice_count", "level", "comment"]
[
"layer_id",
"layer_type",
"N",
"M",
"layer_count",
"slice",
"slice_count",
"level",
"comment",
]
+ [*metrics]
+ metrics_stats
)
@@ -351,7 +383,15 @@ def analyze(
else:
weights = get_conv2D_Wmats(module.weight.cpu().detach().numpy())
results = analyze_weights(
weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix
weights,
min_size,
max_size,
alphas,
lognorms,
spectralnorms,
softranks,
normalize,
glorot_fix,
)
results["id"] = index
results["type"] = type(module)