Prototype MAML
This commit is contained in:
@@ -6,29 +6,32 @@ import torch.nn as nn
|
||||
|
||||
|
||||
def additive_func(A, B):
|
||||
assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
|
||||
C = min(A.size(1), B.size(1))
|
||||
if A.size(1) == B.size(1):
|
||||
return A + B
|
||||
elif A.size(1) < B.size(1):
|
||||
out = B.clone()
|
||||
out[:,:C] += A
|
||||
return out
|
||||
else:
|
||||
out = A.clone()
|
||||
out[:,:C] += B
|
||||
return out
|
||||
assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format(
|
||||
A.size(), B.size()
|
||||
)
|
||||
C = min(A.size(1), B.size(1))
|
||||
if A.size(1) == B.size(1):
|
||||
return A + B
|
||||
elif A.size(1) < B.size(1):
|
||||
out = B.clone()
|
||||
out[:, :C] += A
|
||||
return out
|
||||
else:
|
||||
out = A.clone()
|
||||
out[:, :C] += B
|
||||
return out
|
||||
|
||||
|
||||
def change_key(key, value):
|
||||
def func(m):
|
||||
if hasattr(m, key):
|
||||
setattr(m, key, value)
|
||||
return func
|
||||
def func(m):
|
||||
if hasattr(m, key):
|
||||
setattr(m, key, value)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def parse_channel_info(xstring):
|
||||
blocks = xstring.split(' ')
|
||||
blocks = [x.split('-') for x in blocks]
|
||||
blocks = [[int(_) for _ in x] for x in blocks]
|
||||
return blocks
|
||||
blocks = xstring.split(" ")
|
||||
blocks = [x.split("-") for x in blocks]
|
||||
blocks = [[int(_) for _ in x] for x in blocks]
|
||||
return blocks
|
||||
|
Reference in New Issue
Block a user