Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

datasetmaker.py 1.6KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from torch.utils.data import Dataset, DataLoader
  2. from torchvision import transforms, datasets
  3. import os
  4. import sys
  5. class TrainDataClass(Dataset):
  6. def __init__(self, image_path, transform=None):
  7. super(TrainDataClass, self).__init__()
  8. self.data = datasets.ImageFolder(os.path.join(image_path,'train'), transform)
  9. def __getitem__(self,idx):
  10. x,y = self.data[idx]
  11. return x,y
  12. def __len__(self):
  13. return len(self.data)
  14. class TestDataClass(Dataset):
  15. def __init__(self, image_path, transform=None):
  16. super(TestDataClass, self).__init__()
  17. self.data = datasets.ImageFolder(os.path.join(image_path,'test'), transform)
  18. def __getitem__(self,idx):
  19. x,y = self.data[idx]
  20. return x,y
  21. def __len__(self):
  22. return len(self.data)
  23. #if __name__ == "__main__":
  24. #NOTE: no resize transform is needed because the raw images were already shrunk to MobileNetV2's input size
  25. #TODO: PLEASE ERASE WHEN REFACTORING!!
  26. #dataset_root = 'C:\\Users\\atari\\workspace\\chouse_train\\datasets\\201118\\middle'
  27. #classes = tuple(os.listdir(os.path.join(dataset_root,'train')))
  28. #
  29. #tsfm=transforms.Compose([
  30. # transforms.ToTensor()
  31. #])
  32. #
  33. #train = TrainDataClass(dataset_root, tsfm)
  34. #trainloader = DataLoader(train, batch_size=512, shuffle=True)
  35. #
  36. #test = TestDataClass(dataset_root, tsfm)
  37. #testloader = DataLoader(test, batch_size=512, shuffle=True)
  38. #
  39. #
  40. #dataiter = iter(trainloader)
  41. #images,labels = dataiter.next()
  42. #print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(512)))