add UNet code
This commit is contained in:
31
UNet/data.py
Normal file
31
UNet/data.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from utils import *
|
||||
from torchvision import transforms
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
|
||||
#use VOC2007 Dataset
|
||||
class MyDataset(Dataset):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.name = os.listdir(os.path.join(path, 'SegmentationClass'))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.name)
|
||||
|
||||
def __getitem__(self, index):
|
||||
segment_name = self.name[index] #xx.png
|
||||
segment_path = os.path.join(self.path, 'SegmentationClass',segment_name)
|
||||
image_path = os.path.join(self.path,'JPEGImages', segment_name.replace('png','jpg'))
|
||||
segment_image = keep_image_size_open(segment_path)
|
||||
image = keep_image_size_open(image_path)
|
||||
return transform(image), transform(segment_image)
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = MyDataset('/Users/hanzhangma/Document/DataSet/VOC2007')
|
||||
print(data[0][0].shape) # print the size of image(0,0)
|
||||
print(data[0][1].shape) # print the size of image(0,1)
|
||||
Reference in New Issue
Block a user