python-TensorDataset上的PyTorch转换

前端之家收集整理的这篇文章主要介绍了python-TensorDataset上的PyTorch转换。前端之家小编觉得挺不错的，现在分享给大家，也给大家做个参考。

我正在使用TensorDataset从numpy数组创建数据集.

# Convert the numpy data to PyTorch tensors: build one tensor per sample,
# then stack them along a new leading (sample) dimension.
x_samples = [torch.from_numpy(np.array(sample)) for sample in X_train]
X_train = torch.stack(x_samples)
y_samples = [torch.from_numpy(np.array(label)) for label in y_train]
y_train = torch.stack(y_samples)
del x_samples, y_samples

# Reshape so each sample is [C, H, W] = [1, 28, 28] and cast to float32.
X_train = X_train.reshape((-1, 1, 28, 28)).float()

# Wrap the tensors in a dataset and a batching DataLoader.
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)

如何对TensorDataset应用数据增强(transforms)？

例如,使用ImageFolder,我可以将转换指定为torchvision.datasets.ImageFolder(root,transform = …)的参数之一.

根据一位PyTorch团队成员的这条回复（原文链接未保留），TensorDataset默认不支持transforms。有没有其他替代方法？

如果需要更多代码来说明问题，请随时告诉我。

最佳答案
默认情况下，TensorDataset不支持变换（transforms），但我们可以编写一个自定义的Dataset类来添加这个选项。不过，正如我之前提到的，大多数torchvision变换都是为PIL.Image设计的，而这里每个样本是一个张量。无论如何，下面是一个非常简单的MNIST示例，只使用了非常简单的示意性变换（上下/左右翻转）。示例所用的MNIST csv文件可以在这里下载（原文链接未保留）。

代码：

import numpy as np
import torch
from torch.utils.data import Dataset,TensorDataset

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

# Import the MNIST dataset from a CSV file and convert it to torch tensors.
# NOTE(review): assumes every row is "label,pixel_1,...,pixel_784" with no
# header line — confirm against the CSV file actually used.
mnist_train = np.loadtxt('mnist_train.csv', delimiter=',', ndmin=2)

# Images: drop the label column and reshape to [N, C, H, W] = [N, 1, 28, 28].
# Each sample must be [C, H, W]: vflip/hflip flip dims 1/2 of a sample, and
# make_grid + imshow expect a batch of [C, H, W] images. A plain (-1, 28)
# reshape would break both.
X_train = mnist_train[:, 1:].reshape((-1, 1, 28, 28))
X_train = torch.tensor(X_train)

# Labels: the whole first field parsed as an integer (not just its first
# character), shaped [N, 1].
y_train = mnist_train[:, 0].astype(int).reshape(-1, 1)
y_train = torch.tensor(y_train)

del mnist_train


class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms.

    Wraps tensors that share the same first (sample) dimension. Indexing
    returns the ``index``-th slice of every tensor, with ``transform`` applied
    only to the slice taken from the first tensor (typically the inputs,
    while the second tensor holds the labels).

    Args:
        tensors: sequence of tensors whose ``size(0)`` values must all match.
        transform: optional callable applied to each sample of ``tensors[0]``.

    Raises:
        ValueError: if ``tensors`` is empty or the first dimensions differ.
    """

    def __init__(self, tensors, transform=None):
        if len(tensors) == 0:
            raise ValueError("CustomTensorDataset needs at least one tensor")
        n_samples = tensors[0].size(0)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if any(tensor.size(0) != n_samples for tensor in tensors):
            raise ValueError("Size mismatch between tensors")
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[0][index]

        if self.transform is not None:
            x = self.transform(x)

        # For the usual (inputs, labels) pair this is exactly (x, y); any
        # additional tensors are returned as well instead of being dropped.
        return (x, *(tensor[index] for tensor in self.tensors[1:]))

    def __len__(self):
        return self.tensors[0].size(0)


def imshow(img, title=''):
    """Display one [C, H, W] image tensor, such as the output of make_grid.

    Opens a new 10x10 figure titled ``title``, moves the channel axis last
    (the layout matplotlib expects) and draws the image with a gray colormap.
    """
    plt.figure(figsize=(10, 10))
    plt.title(title)
    channels_last = np.transpose(img.numpy(), (1, 2, 0))
    plt.imshow(channels_last, cmap='gray')
    plt.show()


# Dataset without any transformations: show its first batch as a reference.
train_dataset_normal = CustomTensorDataset(tensors=(X_train, y_train), transform=None)
train_loader = torch.utils.data.DataLoader(train_dataset_normal, batch_size=16)

for x, y in train_loader:
    grid = torchvision.utils.make_grid(x, 4)  # 4 images per row
    imshow(grid, title='Normal')
    break  # a single batch is enough


# Let's add some transforms

# Dataset with flipping transformations

def vflip(tensor):
    """Flip an image tensor vertically (reverse the order of its rows).

    The height axis is taken to be the second-to-last dimension, so this works
    for ``[C, H, W]`` samples as well as ``[H, W]`` images and ``[N, C, H, W]``
    batches. For a 3-D ``[C, H, W]`` tensor it is identical to ``flip(1)``.

    Returns a new tensor; ``torch.flip`` copies, so the input is not modified.
    """
    return tensor.flip(-2)


def hflip(tensor):
    """Flip an image tensor horizontally (reverse the order of its columns).

    The width axis is taken to be the last dimension, so this works for
    ``[C, H, W]`` samples as well as ``[H, W]`` images and ``[N, C, H, W]``
    batches. For a 3-D ``[C, H, W]`` tensor it is identical to ``flip(2)``.

    Returns a new tensor; ``torch.flip`` copies, so the input is not modified.
    """
    return tensor.flip(-1)


# Dataset with vertical flips: show its first batch.
# (Reconstructed from scraping-damaged lines, following the unflipped example.)
train_dataset_vf = CustomTensorDataset(tensors=(X_train, y_train), transform=vflip)
train_loader = torch.utils.data.DataLoader(train_dataset_vf, batch_size=16)

for i, data in enumerate(train_loader):
    x, y = data
    imshow(torchvision.utils.make_grid(x, 4), title='Vertical flip')
    break  # we need just one batch


# Dataset with horizontal flips: show its first batch.
train_dataset_hf = CustomTensorDataset(tensors=(X_train, y_train), transform=hflip)
train_loader = torch.utils.data.DataLoader(train_dataset_hf, batch_size=16)

for i, data in enumerate(train_loader):
    x, y = data
    imshow(torchvision.utils.make_grid(x, 4), title='Horizontal flip')
    break  # we need just one batch

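输出（原答案此处为三张结果图，下面的 norm、vert、horz 只是残留的图片说明，依次对应 Normal、Vertical flip、Horizontal flip 三次 imshow 的显示结果）：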

norm


vert


horz

猜你在找的Python相关文章