〔TensorFlow〕MNIST 数据集
2024年3月29日大约 3 分钟
MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自 250 个不同人手写的数字构成,其中 50%是高中学生,50%来自人口普查局(the Census Bureau)的工作人员。测试集(test set)也是同样比例的手写数字数据,但保证了测试集和训练集的作者集不相交。
MNIST 数据集一共有 7 万张图片,其中 6 万张是训练集,1 万张是测试集。每张图片是 28 × 28 的 0 − 9 的手写数字图片组成。每个图片是黑底白字的形式,黑底用 0 表示,白字用 0-1 之间的浮点数表示,越接近 1,颜色越白。为了简单起见,每个图像都被压平并转换为 784 个特征(28*28)的一维数字阵列。

MNIST 数据集不需要你手动去下载。
import tensorflow as tf
#加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
如果你想下载下来瞅瞅,那就把链接(http://yann.lecun.com/exdb/mnist/)拷贝到新的浏览器中。新打开一个浏览器,它不会提示输入用户名和密码。
文件名 | 数据集 | 大小 | 格式 |
---|---|---|---|
train-images-idx3-ubyte.gz | 训练集图像 | 9.45M | 二进制编码文件 |
train-labels-idx1-ubyte.gz | 训练集标签 | 28.2K | 二进制编码文件 |
t10k-images-idx3-ubyte.gz | 测试集图像 | 1.57M | 二进制编码文件 |
t10k-labels-idx1-ubyte.gz | 测试集标签 | 4.43K | 二进制编码文件 |
文件都是有格式的二进制编码文件,无法直接预览,使用如下代码(代码来自 CSDN)预览图像和标签,输出图像就是上面的图像。
import os
import gzip
import logging
import numpy as np
import matplotlib.pyplot as plt
logging.basicConfig(format="%(message)s", level=logging.DEBUG) # 设置Python日志管理工具的消息格式和显示级别
plt.rcParams["font.sans-serif"] = "SimHei" # 确保plt绘图正常显示中文
plt.rcParams["figure.figsize"] = [9, 10] # 设置plt绘图尺寸
def parse_mnist(minst_file_addr: str = None) -> np.array:
"""解析MNIST二进制文件, 并返回解析结果
输入参数:
minst_file: MNIST数据集的文件地址. 类型: 字符串.
返回值:
解析后的numpy数组
"""
if minst_file_addr is not None:
minst_file_name = os.path.basename(minst_file_addr) # 根据地址获取MNIST文件名字
with gzip.open(filename=minst_file_addr, mode="rb") as minst_file:
mnist_file_content = minst_file.read()
if "label" in minst_file_name: # 传入的为标签二进制编码文件地址
data = np.frombuffer(buffer=mnist_file_content, dtype=np.uint8, offset=8) # MNIST标签文件的前8个字节为描述性内容,直接从第九个字节开始读取标签,并解析
else: # 传入的为图片二进制编码文件地址
data = np.frombuffer(buffer=mnist_file_content, dtype=np.uint8, offset=16) # MNIST图片文件的前16个字节为描述性内容,直接从第九个字节开始读取标签,并解析
data = data.reshape(-1, 28, 28)
else:
logging.warning(msg="请传入MNIST文件地址!")
return data
if __name__ == "__main__":
train_imgs = parse_mnist(minst_file_addr="C:/Users/xxx/Downloads/train-images-idx3-ubyte.gz") # 训练集图像
train_labels = parse_mnist(minst_file_addr="C:/Users/xxx/Downloads/train-labels-idx1-ubyte.gz") # 训练集标签
# 可视化
fig, ax = plt.subplots(ncols=3, nrows=3)
ax[0, 0].imshow(train_imgs[0], cmap=plt.cm.gray)
ax[0, 0].set_title(f"[LOJOB]标签为{train_labels[0]}")
ax[0, 1].imshow(train_imgs[1], cmap=plt.cm.gray)
ax[0, 1].set_title(f"[LOJOB]标签为{train_labels[1]}")
ax[0, 2].imshow(train_imgs[2], cmap=plt.cm.gray)
ax[0, 2].set_title(f"[LOJOB]标签为{train_labels[2]}")
ax[1, 0].imshow(train_imgs[3], cmap=plt.cm.gray)
ax[1, 0].set_title(f"[LOJOB]标签为{train_labels[3]}")
ax[1, 1].imshow(train_imgs[4], cmap=plt.cm.gray)
ax[1, 1].set_title(f"[LOJOB]标签为{train_labels[4]}")
ax[1, 2].imshow(train_imgs[5], cmap=plt.cm.gray)
ax[1, 2].set_title(f"[LOJOB]标签为{train_labels[5]}")
ax[2, 0].imshow(train_imgs[6], cmap=plt.cm.gray)
ax[2, 0].set_title(f"[LOJOB]标签为{train_labels[6]}")
ax[2, 1].imshow(train_imgs[7], cmap=plt.cm.gray)
ax[2, 1].set_title(f"[LOJOB]标签为{train_labels[7]}")
ax[2, 2].imshow(train_imgs[8], cmap=plt.cm.gray)
ax[2, 2].set_title(f"[LOJOB]标签为{train_labels[8]}")
plt.show() # 显示绘图
print(plt.rcParams.keys())