跳至主要內容

〔TensorFlow〕MNIST 数据集

大林鸱大约 3 分钟深度学习计算机视觉自然语言处理

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自 250 个不同人手写的数字构成,其中 50%是高中学生,50%来自人口普查局(the Census Bureau)的工作人员。测试集(test set)也是同样比例的手写数字数据,但保证了测试集和训练集的作者集不相交。

MNIST 数据集一共有 7 万张图片,其中 6 万张是训练集,1 万张是测试集。每张图片是 28 × 28 的 0 − 9 的手写数字图片组成。每个图片是黑底白字的形式,黑底用 0 表示,白字用 0-1 之间的浮点数表示,越接近 1,颜色越白。为了简单起见,每个图像都被压平并转换为 784 个特征(28*28)的一维数字阵列。

MNIST 数据集
MNIST 数据集

MNIST 数据集不需要你手动去下载。

import tensorflow as tf

#加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

如果你想下载下来瞅瞅,那就把链接(http://yann.lecun.com/exdb/mnist/)拷贝到新的浏览器中。新打开一个浏览器,它不会提示输入用户名和密码。

文件名数据集大小格式
train-images-idx3-ubyte.gz训练集图像9.45M二进制编码文件
train-labels-idx1-ubyte.gz训练集标签28.2K二进制编码文件
t10k-images-idx3-ubyte.gz测试集图像1.57M二进制编码文件
t10k-labels-idx1-ubyte.gz测试集标签4.43K二进制编码文件

文件都是有格式的二进制编码文件,无法直接预览,使用如下代码(代码来自 CSDN)预览图像和标签,输出图像就是上面的图像。

import os
import gzip
import logging

import numpy as np
import matplotlib.pyplot as plt

logging.basicConfig(format="%(message)s", level=logging.DEBUG)  # 设置Python日志管理工具的消息格式和显示级别

plt.rcParams["font.sans-serif"] = "SimHei"  # 确保plt绘图正常显示中文
plt.rcParams["figure.figsize"] = [9, 10]  # 设置plt绘图尺寸

def parse_mnist(minst_file_addr: str = None) -> np.array:
    """解析MNIST二进制文件, 并返回解析结果
    输入参数:
        minst_file: MNIST数据集的文件地址. 类型: 字符串.

    返回值:
        解析后的numpy数组
    """
    if minst_file_addr is not None:
        minst_file_name = os.path.basename(minst_file_addr)  # 根据地址获取MNIST文件名字
        with gzip.open(filename=minst_file_addr, mode="rb") as minst_file:
            mnist_file_content = minst_file.read()
        if "label" in minst_file_name:  # 传入的为标签二进制编码文件地址
            data = np.frombuffer(buffer=mnist_file_content, dtype=np.uint8, offset=8)  # MNIST标签文件的前8个字节为描述性内容,直接从第九个字节开始读取标签,并解析
        else:  # 传入的为图片二进制编码文件地址
            data = np.frombuffer(buffer=mnist_file_content, dtype=np.uint8, offset=16)  # MNIST图片文件的前16个字节为描述性内容,直接从第九个字节开始读取标签,并解析
            data = data.reshape(-1, 28, 28)
    else:
        logging.warning(msg="请传入MNIST文件地址!")

    return data


if __name__ == "__main__":
    train_imgs = parse_mnist(minst_file_addr="C:/Users/xxx/Downloads/train-images-idx3-ubyte.gz")  # 训练集图像
    train_labels = parse_mnist(minst_file_addr="C:/Users/xxx/Downloads/train-labels-idx1-ubyte.gz")  # 训练集标签

    # 可视化
    fig, ax = plt.subplots(ncols=3, nrows=3)
    ax[0, 0].imshow(train_imgs[0], cmap=plt.cm.gray)
    ax[0, 0].set_title(f"[LOJOB]标签为{train_labels[0]}")
    ax[0, 1].imshow(train_imgs[1], cmap=plt.cm.gray)
    ax[0, 1].set_title(f"[LOJOB]标签为{train_labels[1]}")
    ax[0, 2].imshow(train_imgs[2], cmap=plt.cm.gray)
    ax[0, 2].set_title(f"[LOJOB]标签为{train_labels[2]}")
    ax[1, 0].imshow(train_imgs[3], cmap=plt.cm.gray)
    ax[1, 0].set_title(f"[LOJOB]标签为{train_labels[3]}")
    ax[1, 1].imshow(train_imgs[4], cmap=plt.cm.gray)
    ax[1, 1].set_title(f"[LOJOB]标签为{train_labels[4]}")
    ax[1, 2].imshow(train_imgs[5], cmap=plt.cm.gray)
    ax[1, 2].set_title(f"[LOJOB]标签为{train_labels[5]}")
    ax[2, 0].imshow(train_imgs[6], cmap=plt.cm.gray)
    ax[2, 0].set_title(f"[LOJOB]标签为{train_labels[6]}")
    ax[2, 1].imshow(train_imgs[7], cmap=plt.cm.gray)
    ax[2, 1].set_title(f"[LOJOB]标签为{train_labels[7]}")
    ax[2, 2].imshow(train_imgs[8], cmap=plt.cm.gray)
    ax[2, 2].set_title(f"[LOJOB]标签为{train_labels[8]}")
    plt.show()  # 显示绘图

    print(plt.rcParams.keys())
上次编辑于: