diff --git a/UNet/data.py b/UNet/data.py
new file mode 100644
index 0000000..09952b4
--- /dev/null
+++ b/UNet/data.py
@@ -0,0 +1,58 @@
+"""VOC2007 segmentation dataset that yields (image, mask) tensor pairs."""
+import os
+
+from PIL import Image
+from torch.utils.data import Dataset
+from utils import *
+from torchvision import transforms
+
+# ToTensor turns a PIL image into a float tensor (C, H, W) scaled to [0, 1].
+transform = transforms.Compose([
+    transforms.ToTensor()
+])
+
+
+# use VOC2007 Dataset
+class MyDataset(Dataset):
+    """Pair each mask in ``SegmentationClass`` with its ``JPEGImages`` photo."""
+
+    def __init__(self, path):
+        """Index the mask files found under ``path/SegmentationClass``.
+
+        Args:
+            path: Root directory of the VOC2007 tree.
+        """
+        self.path = path
+        # os.listdir gives no ordering guarantee and also returns stray
+        # entries such as ``.DS_Store``; keep only the masks, sorted.
+        self.name = sorted(
+            entry
+            for entry in os.listdir(os.path.join(path, 'SegmentationClass'))
+            if entry.lower().endswith('.png')
+        )
+
+    def __len__(self):
+        """Return the number of (image, mask) pairs."""
+        return len(self.name)
+
+    def __getitem__(self, index):
+        """Return the ``(image, mask)`` tensors of the sample at ``index``."""
+        segment_name = self.name[index]  # xx.png
+        segment_path = os.path.join(
+            self.path, 'SegmentationClass', segment_name)
+        # Swap only the extension: str.replace('png', 'jpg') would also
+        # rewrite a 'png' that occurs inside the file stem.
+        stem = os.path.splitext(segment_name)[0]
+        image_path = os.path.join(self.path, 'JPEGImages', stem + '.jpg')
+        # Nearest-neighbour keeps the mask colours exact; an interpolating
+        # filter would blend them along class boundaries.
+        segment_image = keep_image_size_open(
+            segment_path, resample=Image.NEAREST)
+        image = keep_image_size_open(image_path)
+        return transform(image), transform(segment_image)
+
+
+if __name__ == '__main__':
+    data = MyDataset('/Users/hanzhangma/Document/DataSet/VOC2007')
+    print(data[0][0].shape)  # shape of the first sample's image tensor
+    print(data[0][1].shape)  # shape of the first sample's mask tensor
diff --git a/UNet/net.py b/UNet/net.py
new file mode 100644
index 0000000..e82f695
--- /dev/null
+++ b/UNet/net.py
@@ -0,0 +1,116 @@
+"""U-Net building blocks and the assembled network."""
+from torch import nn
+from torch.nn import functional as F
+from torch import randn
+import torch
+
+
+class Conv_Block(nn.Module):
+    """Two stacked 3x3 conv -> batch norm -> dropout -> LeakyReLU stages."""
+
+    def __init__(self, in_channel, out_channel):
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1,
+                      padding=1, padding_mode='reflect', bias=False),
+            nn.BatchNorm2d(out_channel),
+            nn.Dropout2d(0.3),
+            nn.LeakyReLU(),
+            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1,
+                      padding=1, padding_mode='reflect', bias=False),
+            nn.BatchNorm2d(out_channel),
+            nn.Dropout2d(0.3),
+            nn.LeakyReLU()
+        )
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class DownSample(nn.Module):
+    """Downsample with a stride-2 3x3 convolution; channels are unchanged."""
+
+    def __init__(self, channel):
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Conv2d(channel, channel, 3, 2, 1,
+                      padding_mode='reflect', bias=False),
+            nn.BatchNorm2d(channel),
+            nn.LeakyReLU()
+        )
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class UpSample(nn.Module):
+    """Upsample, halve the channels with a 1x1 conv, concatenate the skip."""
+
+    def __init__(self, channel):
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Conv2d(channel, channel // 2, 1, 1)
+        )
+
+    def forward(self, x, feature_map):
+        # Resize to the skip connection's spatial size rather than a fixed
+        # factor of 2: the results are identical for even sizes, and odd
+        # sizes no longer make torch.cat fail on a shape mismatch.
+        up = F.interpolate(x, size=feature_map.shape[2:], mode='nearest')
+        out = self.layer(up)
+        return torch.cat((out, feature_map), dim=1)
+
+
+class UNet(nn.Module):
+    """Four-level encoder/decoder U-Net with a sigmoid output.
+
+    Args:
+        in_channels: Number of channels of the input image.
+        out_channels: Number of channels of the predicted map.
+    """
+
+    def __init__(self, in_channels=3, out_channels=3):
+        super().__init__()
+        self.c1 = Conv_Block(in_channels, 64)
+        self.d1 = DownSample(64)
+        self.c2 = Conv_Block(64, 128)
+        self.d2 = DownSample(128)
+        self.c3 = Conv_Block(128, 256)
+        self.d3 = DownSample(256)
+        self.c4 = Conv_Block(256, 512)
+        self.d4 = DownSample(512)
+        self.c5 = Conv_Block(512, 1024)
+
+        self.u1 = UpSample(1024)
+        self.c6 = Conv_Block(1024, 512)
+        self.u2 = UpSample(512)
+        self.c7 = Conv_Block(512, 256)
+        self.u3 = UpSample(256)
+        self.c8 = Conv_Block(256, 128)
+        self.u4 = UpSample(128)
+        self.c9 = Conv_Block(128, 64)
+
+        self.out = nn.Conv2d(64, out_channels, 3, 1, 1)
+        self.Th = nn.Sigmoid()
+
+    def forward(self, x):
+        # Encoder: keep every level's output for the skip connections.
+        R1 = self.c1(x)
+        R2 = self.c2(self.d1(R1))
+        R3 = self.c3(self.d2(R2))
+        R4 = self.c4(self.d3(R3))
+        R5 = self.c5(self.d4(R4))
+
+        # Decoder: each step upsamples and concatenates the matching skip.
+        O1 = self.c6(self.u1(R5, R4))
+        O2 = self.c7(self.u2(O1, R3))
+        O3 = self.c8(self.u3(O2, R2))
+        O4 = self.c9(self.u4(O3, R1))
+
+        return self.Th(self.out(O4))
+
+
+if __name__ == '__main__':
+    x = randn(2, 3, 256, 256)
+    net = UNet()
+    print(net(x).shape)
diff --git a/UNet/train.py b/UNet/train.py
new file mode 100644
index 0000000..0d723d0
--- /dev/null
+++ b/UNet/train.py
@@ -0,0 +1,67 @@
+"""Train the U-Net on VOC2007 and dump preview images as it goes."""
+import os
+
+import torch
+from torch import nn
+from torch import optim
+from torch.utils.data import DataLoader
+from data import *
+from net import *
+
+from torchvision.utils import save_image
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+weight_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/UNet/params/unet.pth'
+data_path = r'/Users/hanzhangma/Document/DataSet/VOC2007'
+# 'UNet' matches the directory used by weight_path; the former 'Unet'
+# spelling is a different directory on a case-sensitive filesystem.
+save_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/UNet/train_image'
+
+if __name__ == '__main__':
+    # torch.save and save_image do not create missing directories.
+    os.makedirs(os.path.dirname(weight_path), exist_ok=True)
+    os.makedirs(save_path, exist_ok=True)
+
+    data_loader = DataLoader(MyDataset(data_path), batch_size=4, shuffle=True)
+
+    net = UNet().to(device)
+    if os.path.exists(weight_path):
+        # map_location lets a checkpoint written on a GPU load on a
+        # CPU-only machine.
+        net.load_state_dict(torch.load(weight_path, map_location=device))
+        print('successful load weight!')
+    else:
+        print('Failed on load weight!')
+
+    opt = optim.Adam(net.parameters())
+    loss_fun = nn.BCELoss()
+
+    epoch = 1
+
+    # Trains until the process is interrupted.
+    while True:
+        for i, (image, segment_image) in enumerate(data_loader):
+            image, segment_image = image.to(device), segment_image.to(device)
+
+            out_image = net(image)
+            train_loss = loss_fun(out_image, segment_image)
+
+            opt.zero_grad()
+            train_loss.backward()
+            opt.step()  # apply the gradient update to the weights
+
+            if i % 5 == 0:
+                print(f'{epoch} -- {i} -- train loss ===>> {train_loss.item()}')
+
+            if i % 50 == 0:
+                torch.save(net.state_dict(), weight_path)
+
+            # Preview: input, target and prediction of the batch's first sample.
+            _image = image[0]
+            _segment_image = segment_image[0]
+            _out_image = out_image[0]
+
+            img = torch.stack([_image, _segment_image, _out_image], dim=0)
+            save_image(img, f'{save_path}/{i}.png')
+
+        epoch += 1
diff --git a/UNet/utils.py b/UNet/utils.py
new file mode 100644
index 0000000..00906dd
--- /dev/null
+++ b/UNet/utils.py
@@ -0,0 +1,28 @@
+"""Image-loading helper shared by the dataset code."""
+from PIL import Image
+
+
+def keep_image_size_open(path, size=(256, 256), resample=None):
+    """Open an image, pad it to a square and resize it to ``size``.
+
+    The image is pasted into the top-left corner of a black square whose
+    side equals the image's longer edge, so the resize does not distort
+    the aspect ratio.
+
+    Args:
+        path: Path of the image file to open.
+        size: ``(width, height)`` of the returned image.
+        resample: Optional Pillow resampling filter for the resize; when
+            omitted, Pillow's own default is used.
+
+    Returns:
+        A new RGB ``PIL.Image.Image`` of the requested size.
+    """
+    # The context manager closes the file once its pixels are copied.
+    with Image.open(path) as img:
+        side = max(img.size)
+        mask = Image.new('RGB', (side, side), (0, 0, 0))
+        mask.paste(img, (0, 0))
+    if resample is None:
+        return mask.resize(size)
+    return mask.resize(size, resample=resample)