feat: dataset
This commit is contained in:
parent
de44cad219
commit
bf905e9e03
37
train/dataset.py
Normal file
@@ -0,0 +1,37 @@
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader


class Cifar10Dataset(Dataset):
    # Loads the CIFAR-10 python-version batch files (data_batch_*) found in
    # data_dir into memory as float arrays of shape (N, 3, 32, 32).
    def __init__(self, data_dir):
        self.imgs = []
        self.labels = []
        for file in os.listdir(data_dir):
            if 'data_batch' in file:
                batch = self.unpickle(f'{data_dir}/{file}')

                length = len(batch[b'data'])
                self.labels += batch[b'labels']

                # read image data: each row holds 3072 values laid out as the
                # red plane, then green, then blue (32*32 each);
                # dividing by 255.0 scales pixel values to [0, 1]
                values = np.array(batch[b'data']) / 255.0
                imgs = np.zeros((length, 3, 32, 32))
                for index in range(length):
                    for channel in range(3):
                        imgs[index][channel] = values[index][32*32*channel : 32*32*(channel+1)].reshape((32, 32))
                self.imgs.append(imgs)
        self.imgs = np.concatenate(self.imgs)
        print(f"loaded images : {self.imgs.shape}")
        print(f"loaded labels : {len(self.labels)}")

    def unpickle(self, file):
        # CIFAR-10 batches are pickled dicts keyed by byte strings (b'data', b'labels')
        import pickle
        with open(file, 'rb') as fo:
            batch_dict = pickle.load(fo, encoding='bytes')
        return batch_dict

    def __getitem__(self, index):
        return self.imgs[index], self.labels[index]

    def __len__(self):
        return len(self.imgs)
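A minimal usage sketch (not part of the diff above), assuming the CIFAR-10 python-version batch files have been extracted to a local ./cifar-10-batches-py directory and that the new file is importable as train.dataset; the directory path and batch size are illustrative:

from torch.utils.data import DataLoader

from train.dataset import Cifar10Dataset

# build the dataset (loads every data_batch_* file into memory)
dataset = Cifar10Dataset('./cifar-10-batches-py')

# wrap it in the DataLoader that dataset.py already imports
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# default collation turns the numpy images and int labels into tensors
imgs, labels = next(iter(loader))
print(imgs.shape, labels.shape)  # expected: [64, 3, 32, 32] and [64]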