Distributed-Training-Example/train/dataset.py

44 lines
1.4 KiB
Python

'''
Please implement a Dataset that inherits from the PyTorch Dataset class
so that our trainer can load the data from /dataset_dir.
'''
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
class Cifar10Dataset(Dataset):
    """In-memory dataset built from pickled ``data_batch`` files.

    Every file in ``data_dir`` whose name contains ``'data_batch'`` is
    unpickled and read through its ``b'data'`` and ``b'labels'`` entries.
    Each row of ``b'data'`` is treated as three consecutive 32*32 planes
    and reshaped to ``(3, 32, 32)``; values are divided by 255.0
    (assumes 8-bit pixel data, so results land in [0, 1] -- confirm
    against the files actually used).

    Attributes set by ``__init__``:
        imgs: float64 array of shape ``(N, 3, 32, 32)``.
        labels: list with one label per image, in the same order as ``imgs``.
    """

    # Per-image layout: (channels, height, width).
    IMAGE_SHAPE = (3, 32, 32)

    def __init__(self, data_dir):
        """Load every ``data_batch`` file found directly inside ``data_dir``.

        Args:
            data_dir: Directory containing the pickled batch files.

        Raises:
            ValueError: If ``data_dir`` contains no file whose name
                includes ``'data_batch'``.
        """
        self.labels = []
        chunks = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order identical across runs, machines and worker processes.
        for file_name in sorted(os.listdir(data_dir)):
            if 'data_batch' not in file_name:
                continue
            batch = self.unpickle(os.path.join(data_dir, file_name))
            self.labels.extend(batch[b'labels'])
            # Scale, then force float64 to match the array dtype this class
            # has always exposed.
            values = (np.asarray(batch[b'data']) / 255.0).astype(
                np.float64, copy=False
            )
            # A row is laid out plane after plane, so a single reshape yields
            # (count, 3, 32, 32) without looping over images and channels.
            chunks.append(values.reshape((len(values),) + self.IMAGE_SHAPE))
        if not chunks:
            # np.concatenate([]) would also raise ValueError, but with a
            # message that does not point at the real cause.
            raise ValueError(f"no 'data_batch' files found in {data_dir!r}")
        self.imgs = np.concatenate(chunks)
        print(f"load images : {self.imgs.shape}")
        print(f"load labels : {len(self.labels)}")

    def unpickle(self, file):
        """Return the object stored in the pickle file at path ``file``.

        Uses ``encoding='bytes'``, which is why the loaded batch is indexed
        with bytes keys such as ``b'data'``.
        """
        import pickle

        # NOTE(security): pickle.load can execute arbitrary code from the
        # file; only point this at batch files from a trusted source.
        with open(file, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        return batch

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.imgs[index], self.labels[index]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.imgs)