import os
import sys

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets


class _ImageFolderSplit(Dataset):
    """Dataset over one split sub-directory (e.g. ``train``) of an image root.

    Wraps ``torchvision.datasets.ImageFolder`` built on
    ``os.path.join(image_path, split)`` and forwards indexing and length to it.

    Args:
        image_path: Dataset root directory containing the split sub-directories.
        split: Name of the sub-directory to load (``"train"`` / ``"test"``).
        transform: Optional transform passed straight through to ImageFolder.
    """

    def __init__(self, image_path, split, transform=None):
        super().__init__()
        self.data = datasets.ImageFolder(os.path.join(image_path, split), transform=transform)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair that ImageFolder yields for index ``idx``."""
        # Unpacking enforces that the underlying item is exactly a 2-tuple.
        x, y = self.data[idx]
        return x, y

    def __len__(self):
        """Return the number of samples found in the split directory."""
        return len(self.data)


class TrainDataClass(_ImageFolderSplit):
    """Images under ``<image_path>/train``, loaded with ImageFolder."""

    def __init__(self, image_path, transform=None):
        super().__init__(image_path, "train", transform)


class TestDataClass(_ImageFolderSplit):
    """Images under ``<image_path>/test``, loaded with ImageFolder."""

    def __init__(self, image_path, transform=None):
        super().__init__(image_path, "test", transform)


def _demo(dataset_root, batch_size=512):
    """Load one shuffled training batch and print its ground-truth class names.

    Only ``ToTensor`` is applied. Per the original author's note, the raw
    images were already resized to MobileNetV2's input size, so no resizing
    transform is needed.
    """
    tsfm = transforms.Compose([
        transforms.ToTensor(),
    ])

    train = TrainDataClass(dataset_root, tsfm)
    test = TestDataClass(dataset_root, tsfm)
    trainloader = DataLoader(train, batch_size=batch_size, shuffle=True)

    # ImageFolder assigns label indices from its sorted class list; raw
    # os.listdir order is arbitrary and could mislabel the printout.
    classes = train.data.classes

    print('train samples: %d, test samples: %d' % (len(train), len(test)))
    images, labels = next(iter(trainloader))
    # Iterate the real labels: the last (or only) batch may be < batch_size.
    print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels.tolist()))


if __name__ == "__main__":
    if len(sys.argv) != 2:
        sys.exit("usage: %s DATASET_ROOT" % sys.argv[0])
    _demo(sys.argv[1])