23 lines
632 B
Python
23 lines
632 B
Python
import torchvision
|
|
from torchvision.datasets import MNIST
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
def getMnistLoader(config, *, root="./data", train=True, download=True, shuffle=True):
    '''
    Get MNIST dataset's loader

    Inputs:
        config (configparser.ConfigParser): must contain a ``[unet]`` section
            with a ``batch_size`` entry whose value parses with ``int``.
        root (str): directory the MNIST files are read from / downloaded to.
            Defaults to "./data", the location previously hard-coded here.
        train (bool): load the training split when True (the previous
            behaviour), the test split when False.
        download (bool): passed to ``MNIST``; when True the dataset is
            downloaded into ``root`` if it is not already there.
        shuffle (bool): passed to ``DataLoader``; True (the previous
            behaviour) reshuffles the samples every epoch.

    Outputs:
        loader (torch.utils.data.DataLoader): iterates over MNIST in batches
            of ``batch_size``; each image goes through ``ToTensor``.

    Raises:
        KeyError: if the ``unet`` section or its ``batch_size`` key is missing.
        ValueError: if ``batch_size`` is not an integer string.
    '''
    # ConfigParser stores every value as a string, so convert explicitly.
    batch_size = int(config['unet']['batch_size'])

    # Compose keeps room for extra transforms (e.g. normalisation) later.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    data = MNIST(root, train=train, download=download, transform=transform)
    loader = DataLoader(data, batch_size=batch_size, shuffle=shuffle)
    return loader
|