import os
import pickle

import numpy as np
from torch.utils.data import Dataset, DataLoader


class Cifar10Dataset(Dataset):

    def __init__(self, data_dir):
        self.imgs = []
        self.labels = []
        # Each CIFAR-10 training batch file ('data_batch_1' .. 'data_batch_5')
        # is a pickled dict holding b'data' (10000 x 3072 uint8) and b'labels'.
        for file in os.listdir(data_dir):
            if 'data_batch' in file:
                batch = self.unpickle(os.path.join(data_dir, file))

                length = len(batch[b'data'])
                self.labels += batch[b'labels']

                # Read the image data: scale pixel values to [0, 1] and unpack
                # each 3072-value row (1024 red, 1024 green, 1024 blue) into a
                # 3x32x32 image.
                values = np.array(batch[b'data']) / 255.0
                imgs = np.zeros((length, 3, 32, 32))
                for index in range(length):
                    for channel in range(3):
                        imgs[index][channel] = values[index][32*32*channel : 32*32*(channel+1)].reshape((32, 32))
                self.imgs.append(imgs)
        self.imgs = np.concatenate(self.imgs)
        print(f"loaded images: {self.imgs.shape}")
        print(f"loaded labels: {len(self.labels)}")

    def unpickle(self, file):
        # CIFAR-10 batch files are pickled with Python 2, so load with
        # encoding='bytes'; all dict keys come back as bytes objects.
        with open(file, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        return batch

    def __getitem__(self, index):
        return self.imgs[index], self.labels[index]

    def __len__(self):
        return len(self.imgs)
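

# A minimal usage sketch (not part of the original class): it assumes the
# extracted python-version CIFAR-10 batches live in './cifar-10-batches-py'
# (a hypothetical path) and simply wraps the dataset in the DataLoader
# imported above to check the batch shapes.
if __name__ == '__main__':
    dataset = Cifar10Dataset('./cifar-10-batches-py')
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    imgs, labels = next(iter(loader))
    print(imgs.shape, labels.shape)  # e.g. torch.Size([64, 3, 32, 32]) torch.Size([64])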