import os
import pickle

import numpy as np
from torch.utils.data import Dataset, DataLoader


class Cifar10Dataset(Dataset):
    """In-memory dataset built from the pickled ``data_batch*`` files of a directory.

    Every file in ``data_dir`` whose name contains ``'data_batch'`` is
    unpickled and must provide two byte-string keys:

    * ``b'data'``   -- one flat row of ``3 * 32 * 32`` values per image,
      stored channel after channel, each channel in row-major order.
    * ``b'labels'`` -- one label per image.

    After construction ``imgs`` is a float64 array of shape
    ``(N, 3, 32, 32)`` holding the raw values divided by 255, and ``labels``
    is a flat list with one entry per image.
    """

    def __init__(self, data_dir):
        """Load and decode all batch files found in ``data_dir``.

        Args:
            data_dir: Directory that contains the pickled batch files.

        Raises:
            ValueError: If no file name in ``data_dir`` contains
                ``'data_batch'``.
        """
        super().__init__()
        self.imgs = []
        self.labels = []

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # sample order reproducible across runs and platforms.
        for file_name in sorted(os.listdir(data_dir)):
            if 'data_batch' not in file_name:
                continue

            batch = self.unpickle(os.path.join(data_dir, file_name))
            self.labels.extend(batch[b'labels'])

            # Scale by 255 (maps 8-bit pixel values into [0, 1] -- assumes the
            # stored data is uint8, TODO confirm). float64 is kept on purpose:
            # it is the dtype callers have always received from this class.
            values = np.asarray(batch[b'data'], dtype=np.float64) / 255.0

            # Each row is laid out as [channel 0 | channel 1 | channel 2], each
            # channel being 32 rows of 32 values, so one reshape yields the
            # (N, 3, 32, 32) layout without a per-image Python loop.
            self.imgs.append(values.reshape(len(values), 3, 32, 32))

        if not self.imgs:
            raise ValueError(
                f"no file with 'data_batch' in its name found in {data_dir!r}"
            )

        self.imgs = np.concatenate(self.imgs)

        print(f"load images : {self.imgs.shape}")
        print(f"load labels : {len(self.labels)}")

    def unpickle(self, file):
        """Return the object stored in the pickle file at path ``file``.

        ``encoding='bytes'`` keeps Python 2 ``str`` objects as ``bytes``,
        which is why the batch keys are accessed as ``b'data'`` and
        ``b'labels'``.
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # loading. Only point this at batch files from a trusted source.
        with open(file, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        return batch

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.imgs[index], self.labels[index]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.imgs)