feat: dataset
This commit is contained in:
parent
de44cad219
commit
bf905e9e03
37
train/dataset.py
Normal file
37
train/dataset.py
Normal file
@ -0,0 +1,37 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class Cifar10Dataset(Dataset):
    """CIFAR-10 training dataset loaded from the original pickled batch files.

    Reads every file whose name contains ``data_batch`` under ``data_dir``
    (the standard ``cifar-10-batches-py`` layout), normalizes pixel values to
    [0, 1], and exposes images as float arrays of shape (3, 32, 32).
    """

    def __init__(self, data_dir):
        """Load all training batches from *data_dir*.

        Args:
            data_dir: Directory containing the CIFAR-10 ``data_batch_*``
                pickle files.
        """
        self.imgs = []
        self.labels = []
        # sorted() makes sample order deterministic; os.listdir order is
        # filesystem-dependent.
        for file in sorted(os.listdir(data_dir)):
            if 'data_batch' not in file:
                continue
            batch = self.unpickle(f'{data_dir}/{file}')
            self.labels += batch[b'labels']

            # Each row is 3072 values laid out channel-major (the R, G, B
            # 32x32 planes back to back), so a single reshape is equivalent
            # to the per-pixel copy loop — but runs at C speed.
            values = np.array(batch[b'data']) / 255.0
            self.imgs.append(values.reshape(-1, 3, 32, 32))

        if self.imgs:
            self.imgs = np.concatenate(self.imgs)
        else:
            # No batch files found: empty dataset instead of a cryptic
            # np.concatenate error.
            self.imgs = np.zeros((0, 3, 32, 32))
        print(f"load images : {self.imgs.shape}")
        print(f"load labels : {len(self.labels)}")

    def unpickle(self, file):
        """Load one pickled CIFAR-10 batch.

        Returns a dict keyed by bytes (b'data', b'labels'), per the official
        CIFAR-10 distribution (pickled with bytes keys).
        """
        import pickle
        with open(file, 'rb') as fo:
            return pickle.load(fo, encoding='bytes')

    def __getitem__(self, index):
        """Return ``(image, label)``; image is a (3, 32, 32) float array in [0, 1]."""
        return self.imgs[index], self.labels[index]

    def __len__(self):
        """Number of images loaded."""
        return len(self.imgs)
|
||||
Loading…
Reference in New Issue
Block a user