Python PyTorch 如何获取 MNIST 数据
作者:深色風信子
这篇文章主要介绍了Python PyTorch 如何获取 MNIST 数据,通过示例代码介绍了PyTorch 保存 MNIST 数据,PyTorch 显示 MNIST 数据的操作方法,感兴趣的朋友跟随小编一起看看吧
1 PyTorch 获取 MNIST 数据
import torch import numpy as np import matplotlib.pyplot as plt # type: ignore from torchvision import datasets, transforms def mnist_get(): print(torch.__version__) # 定义数据转换 transform = transforms.Compose([ transforms.ToTensor(), # 将图像转换为张量 transforms.Normalize((0.5,), (0.5,)) # 归一化图像数据 ]) # 获取数据 train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform) test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform) # 训练数据 train_image = train_data.data.numpy() train_label = train_data.targets.numpy() # 测试数据 test_image = test_data.data.numpy() test_label = test_data.targets.numpy()
2 PyTorch 保存 MNIST 数据
import torch import numpy as np import matplotlib.pyplot as plt # type: ignore from torchvision import datasets, transforms def mnist_save(mnist_path): print(torch.__version__) # 定义数据转换 transform = transforms.Compose([ transforms.ToTensor(), # 将图像转换为张量 transforms.Normalize((0.5,), (0.5,)) # 归一化图像数据 ]) # 获取数据 train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform) test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform) # 训练数据 train_image = train_data.data.numpy() train_label = train_data.targets.numpy() # 测试数据 test_image = test_data.data.numpy() test_label = test_data.targets.numpy() np.savez(mnist_path, train_data=train_image, train_label=train_label, test_data=test_image, test_label=test_label) mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz' mnist_save(mnist_path)
3 PyTorch 显示 MNIST 数据
import torch import numpy as np import matplotlib.pyplot as plt # type: ignore from torchvision import datasets, transforms def mnist_show(mnist_path): data = np.load(mnist_path) image = data['train_data'][0:100] label = data['train_label'].reshape(-1, ) plt.figure(figsize = (10, 10)) for i in range(100): print('%f, %f' % (i, label[i])) plt.subplot(10, 10, i + 1) plt.imshow(image[i]) plt.show() mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz' mnist_show(mnist_path)
到此这篇关于Python PyTorch 获取 MNIST 数据的文章就介绍到这了,更多相关Python PyTorch 获取 MNIST 数据内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!