Update scripts
This commit is contained in:
@@ -59,7 +59,7 @@ def to_layer(config, embed_dim, depth):
|
||||
def extend_transformer_settings(alg2configs, name):
|
||||
config = copy.deepcopy(alg2configs[name])
|
||||
for i in range(6):
|
||||
for j in [24, 32, 48, 64]:
|
||||
for j in [6, 12, 24, 32, 48, 64]:
|
||||
for k in [0, 0.1]:
|
||||
alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer(
|
||||
to_pos_drop(config, k), j, i
|
||||
@@ -104,7 +104,7 @@ def retrieve_configs():
|
||||
idx, len(alg2configs), alg, path
|
||||
)
|
||||
)
|
||||
alg2configs = extend_transformer_settings(alg2configs, "TSF-A")
|
||||
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
||||
return alg2configs
|
||||
|
||||
|
||||
@@ -156,7 +156,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--alg",
|
||||
type=str,
|
||||
choices=list(alg2paths.keys()),
|
||||
choices=list(alg2configs.keys()),
|
||||
required=True,
|
||||
help="The algorithm name.",
|
||||
)
|
||||
|
Reference in New Issue
Block a user