| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms, datasets
- import os
- import sys
-
class TrainDataClass(Dataset):
    """Dataset wrapping an ``ImageFolder`` laid out under ``<image_path>/<split>``.

    Parameters
    ----------
    image_path : str
        Root directory containing per-split subdirectories (e.g. ``train/``).
    transform : callable, optional
        Transform applied to each loaded image (forwarded to ``ImageFolder``).
    split : str, optional
        Subdirectory name to load. Defaults to ``'train'`` so existing
        callers are unaffected.
    """

    def __init__(self, image_path, transform=None, split='train'):
        super().__init__()
        # ImageFolder infers class labels from the subdirectory names
        # under os.path.join(image_path, split).
        self.data = datasets.ImageFolder(os.path.join(image_path, split), transform)

    def __getitem__(self, idx):
        # ImageFolder items are already (image, label) pairs.
        x, y = self.data[idx]
        return x, y

    def __len__(self):
        return len(self.data)
-
class TestDataClass(Dataset):
    """Dataset wrapping an ``ImageFolder`` laid out under ``<image_path>/<split>``.

    Parameters
    ----------
    image_path : str
        Root directory containing per-split subdirectories (e.g. ``test/``).
    transform : callable, optional
        Transform applied to each loaded image (forwarded to ``ImageFolder``).
    split : str, optional
        Subdirectory name to load. Defaults to ``'test'`` so existing
        callers are unaffected.
    """

    def __init__(self, image_path, transform=None, split='test'):
        super().__init__()
        # ImageFolder infers class labels from the subdirectory names
        # under os.path.join(image_path, split).
        self.data = datasets.ImageFolder(os.path.join(image_path, split), transform)

    def __getitem__(self, idx):
        # ImageFolder items are already (image, label) pairs.
        x, y = self.data[idx]
        return x, y

    def __len__(self):
        return len(self.data)
-
-
-
-
- #if __name__ == "__main__":
- # NOTE: no resizing transform is needed here because the raw images were already shrunk to MobileNetV2's input size.
- # TODO: please erase this commented-out smoke test when refactoring!
- #dataset_root = 'C:\\Users\\atari\\workspace\\chouse_train\\datasets\\201118\\middle'
- #classes = tuple(os.listdir(os.path.join(dataset_root,'train')))
- #
- #tsfm=transforms.Compose([
- # transforms.ToTensor()
- #])
- #
- #train = TrainDataClass(dataset_root, tsfm)
- #trainloader = DataLoader(train, batch_size=512, shuffle=True)
- #
- #test = TestDataClass(dataset_root, tsfm)
- #testloader = DataLoader(test, batch_size=512, shuffle=True)
- #
- #
- #dataiter = iter(trainloader)
- #images, labels = next(dataiter)  # iterator .next() was removed; use the builtin next()
- #print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(512)))
|