我正在使用 TensorDataset 从 numpy 数组创建数据集:
# Convert the numpy arrays to PyTorch tensors in one step: np.array stacks
# (and copies) the per-sample arrays, so no per-sample tensors are created.
# Assumes all samples share the same shape (torch.stack required that too).
X_train = torch.from_numpy(np.array(X_train))
y_train = torch.from_numpy(np.array(y_train))
# reshape into [N, C, H, W] = [N, 1, 28, 28] and cast to float32
X_train = X_train.reshape((-1, 1, 28, 28)).float()
# create dataset and DataLoaders
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
如何对 TensorDataset 应用数据增强(transforms)?
例如,使用 ImageFolder 时,我可以把变换作为参数传入:torchvision.datasets.ImageFolder(root, transform=...)。
根据 PyTorch 团队一位成员的这条回复(this reply),TensorDataset 默认不支持变换。有其他替代方法吗?
如果需要更多代码来说明问题,请随时告诉我。
最佳答案
默认情况下,TensorDataset 不支持变换(transforms),但我们可以创建自定义类来添加这一功能。不过,正如我前面提到的,大多数变换都是为 PIL.Image 设计的。无论如何,下面是一个非常简单的 MNIST 示例,只使用了几个非常简单的示意性(dummy)变换。MNIST 的 csv 文件可以在这里(here)获取。
代码:
import numpy as np
import torch
from torch.utils.data import Dataset,TensorDataset
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Import the MNIST dataset from a CSV file and convert it to torch tensors.
# Assumes each non-empty line is "label,pixel_0,...,pixel_783" with no header
# row (a header would fail the int() label parse below).
with open('mnist_train.csv', 'r', encoding='utf-8') as f:
    # Split every line once; skip blank lines (e.g. a trailing newline).
    rows = [line.strip().split(',') for line in f if line.strip()]

# Images: drop the label column and reshape to [N, C, H, W] = [N, 1, 28, 28],
# so each dataset item is a [1, 28, 28] image. Its first dimension then
# matches the N labels, and it has the layout make_grid and the flip
# transforms expect.
X_train = np.array([[float(value) for value in row[1:]] for row in rows])
X_train = X_train.reshape((-1, 1, 28, 28))
X_train = torch.tensor(X_train)

# Labels: parse the whole first field (not only its first character).
y_train = np.array([int(row[0]) for row in rows])
y_train = y_train.reshape(y_train.shape[0], 1)
y_train = torch.tensor(y_train)

del rows
class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms.

    Args:
        tensors: sequence of tensors that share the same first dimension.
            Only ``tensors[0]`` (inputs) and ``tensors[1]`` (targets) are
            returned by ``__getitem__``.
        transform: optional callable applied to each input sample
            ``tensors[0][index]``; targets are returned unchanged.

    Raises:
        ValueError: if fewer than two tensors are given, or if their first
            dimensions differ.
    """

    def __init__(self, tensors, transform=None):
        if len(tensors) < 2:
            raise ValueError("tensors must contain at least an input and a target tensor")
        # Explicit check instead of assert, so validation also runs under python -O.
        size = tensors[0].size(0)
        if any(tensor.size(0) != size for tensor in tensors):
            raise ValueError("Size mismatch between tensors")
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(input, target)`` for ``index``, transforming the input."""
        x = self.tensors[0][index]
        if self.transform is not None:
            x = self.transform(x)
        y = self.tensors[1][index]
        return x, y

    def __len__(self):
        """Number of samples (size of the first dimension)."""
        return self.tensors[0].size(0)
def imshow(img, title='', figsize=(10, 10)):
    """Plot the image batch.

    Args:
        img: image tensor in [C, H, W] layout (e.g. the grid returned by
            ``torchvision.utils.make_grid``); it is transposed to [H, W, C]
            for matplotlib.
        title: figure title.
        figsize: matplotlib figure size in inches; defaults to the previous
            fixed ``(10, 10)``.
    """
    plt.figure(figsize=figsize)
    plt.title(title)
    # detach()/cpu() let tensors that require grad or live on a GPU be
    # converted to numpy as well; a plain CPU tensor is unaffected.
    array = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    # NOTE: for 3-channel float input matplotlib ignores cmap and clips values
    # to [0, 1]; images with 0-255 pixel values should be scaled first.
    plt.imshow(array, cmap='gray')
    plt.show()
# Dataset w/o any transformations
train_dataset_normal = CustomTensorDataset(tensors=(X_train, y_train), transform=None)
train_loader = torch.utils.data.DataLoader(train_dataset_normal, batch_size=16)
# Display only the first batch, 4 images per grid row.
for x, y in train_loader:
    imshow(torchvision.utils.make_grid(x, 4), title='Normal')
    break  # we need just one batch
# Let's add some transforms
# Dataset with flipping transformations
def vflip(tensor):
    """Flips tensor vertically.

    Flips along the second-to-last dimension, which is the height when the
    last two dimensions are (H, W). This works for [H, W], [C, H, W] and
    batched [B, C, H, W] tensors. For [C, H, W] input it is identical to
    ``tensor.flip(1)``.
    """
    return tensor.flip(-2)
def hflip(tensor):
    """Flips tensor horizontally.

    Flips along the last dimension, which is the width when the last two
    dimensions are (H, W). This works for [H, W], [C, H, W] and batched
    [B, C, H, W] tensors. For [C, H, W] input it is identical to
    ``tensor.flip(2)``.
    """
    return tensor.flip(-1)
# Dataset whose images are flipped vertically; show its first batch.
train_dataset_vf = CustomTensorDataset(tensors=(X_train, y_train), transform=vflip)
train_loader = torch.utils.data.DataLoader(train_dataset_vf, batch_size=16)
for x, y in train_loader:
    imshow(torchvision.utils.make_grid(x, 4), title='Vertical flip')
    break  # we need just one batch

# Dataset whose images are flipped horizontally; show its first batch.
train_dataset_hf = CustomTensorDataset(tensors=(X_train, y_train), transform=hflip)
train_loader = torch.utils.data.DataLoader(train_dataset_hf, batch_size=16)
for x, y in train_loader:
    imshow(torchvision.utils.make_grid(x, 4), title='Horizontal flip')
    break  # we need just one batch
输出: