import os
import pickle

import numpy as np
from torch.utils.data import Dataset, DataLoader


class Cifar10Dataset(Dataset):
    """In-memory dataset built from pickled ``data_batch`` files.

    Every file in ``data_dir`` whose name contains ``data_batch`` is
    unpickled and its ``b'data'`` / ``b'labels'`` entries are loaded eagerly.

    Attributes:
        imgs: ``numpy`` array of shape ``(N, 3, 32, 32)`` holding the pixel
            values divided by 255.0 (float64).
        labels: list with the ``N`` labels, in the same order as ``imgs``.
    """

    def __init__(self, data_dir):
        """Load all batch files found in ``data_dir``.

        Args:
            data_dir: directory containing the pickled batch files.

        Raises:
            ValueError: if ``data_dir`` contains no ``data_batch`` file.
            OSError: if ``data_dir`` or one of its batch files cannot be read.
        """
        super().__init__()
        self.imgs = []
        self.labels = []

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample order reproducible across runs and machines.
        for file in sorted(os.listdir(data_dir)):
            if 'data_batch' not in file:
                continue
            batch = self.unpickle(os.path.join(data_dir, file))

            length = len(batch[b'data'])
            self.labels += batch[b'labels']

            # Each row stores three consecutive 32*32 planes (one per
            # channel), so a single C-order reshape yields (length, 3, 32, 32)
            # without looping over images and channels in Python.
            # Presumably the raw values are 8-bit, making the result [0, 1].
            values = np.array(batch[b'data']) / 255.0
            self.imgs.append(values.reshape(length, 3, 32, 32))

        if not self.imgs:
            # np.concatenate([]) would raise an opaque ValueError; keep the
            # exception type but say what actually went wrong.
            raise ValueError(f"no 'data_batch' files found in {data_dir!r}")

        self.imgs = np.concatenate(self.imgs)
        print(f"load images : {self.imgs.shape}")
        print(f"load labels : {len(self.labels)}")

    def unpickle(self, file):
        """Return the object stored in the pickle file at path ``file``.

        Strings pickled by Python 2 are decoded as ``bytes``, which is why
        the batch keys are accessed as ``b'data'`` and ``b'labels'``.
        """
        # SECURITY: pickle.load can execute arbitrary code while loading.
        # Only point this at batch files obtained from a trusted source.
        with open(file, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        return batch

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.imgs[index], self.labels[index]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.imgs)