放牧代码和思想
专注自然语言处理、机器学习算法
    愛しさ 優しさ すべて投げ出してもいい

谷歌深度学习公开课任务 1: notMNIST

目录

这是谷歌在优达学城上开的公开课,感觉就是TensorFlow的宣传片。肉眼观测难度较低,作为入门第一课快速过掉也许还行。课程概述中说:

我们将教授你如何训练和优化基本神经网络、卷积神经网络和长短期记忆网络。你将通过项目和任务接触完整的机器学习系统 TensorFlow。你将学习解决一系列曾经以为非常具有挑战性的新问题,并在你用深度学习方法轻松解决这些问题的过程中更好地了解人工智能的复杂属性。

我们与 Google 的首席科学家兼 Google 智囊团技术经理 Vincent Vanhoucke 联合开发了本课内容。

**注: 本课内容是机器学习工程师纳米学位的一部分。经过本课的学习,我们可视为你已完成了机器学习的初步课程,且至少已熟悉了监督学习方法。

但愿这门课不要退化为TensorFlow的API文档。

优达学城视频上会有一些交互元素,结合IPythonNotebook,看上去是一种新颖的教学方式。但在编程任务中,我就不太喜欢IPythonNotebook了。就像PHP中插入html一样,IPythonNotebook允许将Python和Markdown组织在一起,虽然有这种功能,但实际上会使代码乱糟糟,不是个好东西。更何况,为了写个脚本我还要装个Docker下个镜像,我烦不烦啊。所以我决定用传统的Python脚本来完成所有的编程任务,绘图则用传统的matplotlib。

我对Python的期许是简洁而不是效率,我不精通Numpy,有时候写了一长串代码后发现高手一句话搞定了。所以对这些编程任务只是心里想想大概要怎么实现,然后跳过Google“Numpy如何实现…”的步骤,直接看别人写的片段。然后在其基础上修修改改,使其符合我的逻辑。这样虽然少了些锻炼机会,但给我省了不少查文档踩陷阱的时间。

上面这些说明构成了本系列的宗旨——速攻、不恋战。另外,重申博客的一贯宗旨——服务于个人备忘,不作任何保证,请不要过于期待。

全部代码托管在:https://github.com/hankcs/udacity-deep-learning 

谷歌官方的仓库:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/1_notmnist.ipynb 

我使用如下命令转换为.py文件:

jupyter nbconvert --to python *.ipynb

任务简介

这次任务是做些数据预处理,为以后打下基础。整个课程围绕着OCR任务进行,于是首先需要熟悉熟悉OCR数据集。

notMNIST

名字向经典的MNIST致敬,不同之处在于MNIST是手写字母,而notMNIST是各种字体的字母OCR:

nmn.png

其难度也比MNIST上了一个档次,据作者讲:

logistic regression on top of stacked auto-encoder with fine-tuning gets about 89% accuracy whereas same approach gives got 98% on MNIST. 

下载解压的代码已经给出了,直接运行脚本会得到:

hankcs.com 2016-11-13 下午2.13.04.png

旧版本的脚本可能下载到不同的数据,可能影响到后面的accuracy。

题目1:可视化数据

这些数据到底长什么样?请可视化一些数据(有些图片无法读取,不用介意,忽略即可)。

用matplotlib.image的imread可以很方便地达到目的,我撸了一串不是很好看的代码:

import random
import matplotlib.image as mpimg


def plot_samples(data_folders, sample_size, title=None):
    fig = plt.figure()
    if title: fig.suptitle(title, fontsize=16, fontweight='bold')
    for folder in data_folders:
        image_files = os.listdir(folder)
        image_sample = random.sample(image_files, sample_size)
        for image in image_sample:
            image_file = os.path.join(folder, image)
            ax = fig.add_subplot(len(data_folders), sample_size, sample_size * data_folders.index(folder) +
                                 image_sample.index(image) + 1)
            image = mpimg.imread(image_file)
            ax.imshow(image)
            ax.set_axis_off()

    plt.show()


plot_samples(train_folders, 10, 'Train Folders')
plot_samples(test_folders, 10, 'Test Folders')

得到:

train folder.png

test folder.png

直观感觉这个数据不好搞,有些字体特奇怪。另外,数据中还有不少噪音,比如全蓝底的图片。

题目2:预处理后的可视化

谷歌提供了代码将图片转为ndarray矩阵并序列化到pickle中,请证明处理后的ndarray数据依然是正确的。

神奇的Python,用imshow不需要改代码依然可以直接显示:

def load_and_display_pickle(datasets, sample_size, title=None):
    fig = plt.figure()
    if title: fig.suptitle(title, fontsize=16, fontweight='bold')
    num_of_images = []
    for pickle_file in datasets:
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
            print('Total images in', pickle_file, ':', len(data))

            for index, image in enumerate(data):
                if index == sample_size: break
                ax = fig.add_subplot(len(datasets), sample_size, sample_size * datasets.index(pickle_file) +
                                     index + 1)
                ax.imshow(image)
                ax.set_axis_off()
                ax.imshow(image)

            num_of_images.append(len(data))

    balance_check(num_of_images)
    plt.show()
    return num_of_images

其中,balance_check在下一题中实现。

题目3:均衡校验

分类问题希望每个类的实例数目均衡,请检查这一点。

def generate_fake_label(sizes):
    labels = np.ndarray(sum(sizes), dtype=np.int32)
    start = 0
    end = 0
    for label, size in enumerate(sizes):
        start = end
        end += size
        for j in range(start, end):
            labels[j] = label
    return labels


def plot_balance():
    fig, ax = plt.subplots(1, 2)
    bins = np.arange(train_labels.min(), train_labels.max() + 2)
    ax[0].hist(train_labels, bins=bins)
    ax[0].set_xticks((bins[:-1] + bins[1:]) / 2, [chr(k) for k in range(ord("A"), ord("J") + 1)])
    ax[0].set_title("Training data")

    bins = np.arange(test_labels.min(), test_labels.max() + 2)
    ax[1].hist(test_labels, bins=bins)
    ax[1].set_xticks((bins[:-1] + bins[1:]) / 2, [chr(k) for k in range(ord("A"), ord("J") + 1)])
    ax[1].set_title("Test data")
    plt.show()


def mean(numbers):
    return float(sum(numbers)) / max(len(numbers), 1)


def balance_check(sizes):
    mean_val = mean(sizes)
    print('mean of # images :', mean_val)
    for i in sizes:
        if abs(i - mean_val) > 0.1 * mean_val:
            print("Too much or less images")
        else:
            print("Well balanced", i)

test_labels = generate_fake_label(load_and_display_pickle(test_datasets, 10, 'Test Datasets'))
train_labels = generate_fake_label(load_and_display_pickle(train_datasets, 10, 'Train Datasets'))

plot_balance()

为了使plot_balance能够兼容还未生成的label数组,我不得不写了个蹩脚的generate_fake_label,最后利用hist绘制直方图。

这两步得到:

traindataset.png

testdataset.png

balance1.png

非常均衡了。

题目4:打乱

谷歌将训练集拆分为训练集和开发集,并打乱了所有数据的顺序。请证明打乱后数据依然正确。

没什么好说的,可视化跟第2步类似:

def plot_sample_dataset(dataset, labels, title):
    plt.suptitle(title, fontsize=16, fontweight='bold')
    items = random.sample(range(len(labels)), 12)
    for i, item in enumerate(items):
        plt.subplot(3, 4, i + 1)
        plt.axis('off')
        plt.title(chr(ord('A') + labels[item]))
        plt.imshow(dataset[item])
    plt.show()

plot_sample_dataset(train_dataset, train_labels, 'train dataset suffled')
plot_sample_dataset(valid_dataset, valid_labels, 'valid dataset suffled')
plot_sample_dataset(test_dataset, test_labels, 'test dataset suffled')
plot_balance()

得到

train-shuffled.png

valid-shuffled.png

test-shuffled.png

balance2.png

这次完全均衡了。

题目5:重复样本

数据集中含有许多重复图片(文件名不同,但内容相同),这可能导致训练集、开发集、测试集中有一些样本是一摸一样的,这种现象会使准确率虚高,请计算重复样本数量,并去掉它们。

本来我是准备一个一个数组比较的,但速度太慢了,从Arn-O那里学到了先计算hash,然后再比较的技巧:

import hashlib


def extract_overlap_hash_where(dataset_1, dataset_2):
    dataset_hash_1 = np.array([hashlib.sha256(img).hexdigest() for img in dataset_1])
    dataset_hash_2 = np.array([hashlib.sha256(img).hexdigest() for img in dataset_2])
    overlap = {}
    for i, hash1 in enumerate(dataset_hash_1):
        duplicates = np.where(dataset_hash_2 == hash1)
        if len(duplicates[0]):
            overlap[i] = duplicates[0]
    return overlap


def display_overlap(overlap, source_dataset, target_dataset):
    overlap = {k: v for k, v in overlap.items() if len(v) >= 3}
    item = random.choice(list(overlap.keys()))
    imgs = np.concatenate(([source_dataset[item]], target_dataset[overlap[item][0:7]]))
    plt.suptitle(item)
    for i, img in enumerate(imgs):
        plt.subplot(2, 4, i + 1)
        plt.axis('off')
        plt.imshow(img)

    plt.show()


def sanitize(dataset_1, dataset_2, labels_1):
    dataset_hash_1 = np.array([hashlib.sha256(img).hexdigest() for img in dataset_1])
    dataset_hash_2 = np.array([hashlib.sha256(img).hexdigest() for img in dataset_2])
    overlap = []  # list of indexes
    for i, hash1 in enumerate(dataset_hash_1):
        duplicates = np.where(dataset_hash_2 == hash1)
        if len(duplicates[0]):
            overlap.append(i)
    return np.delete(dataset_1, overlap, 0), np.delete(labels_1, overlap, None)


overlap_test_train = extract_overlap_hash_where(test_dataset, train_dataset)
print('Number of overlaps:', len(overlap_test_train.keys()))
display_overlap(overlap_test_train, test_dataset, train_dataset)

test_dataset_sanit, test_labels_sanit = sanitize(test_dataset, train_dataset, test_labels)
print('Overlapping images removed from test_dataset: ', len(test_dataset) - len(test_dataset_sanit))
valid_dataset_sanit, valid_labels_sanit = sanitize(valid_dataset, train_dataset, valid_labels)
print('Overlapping images removed from valid_dataset: ', len(valid_dataset) - len(valid_dataset_sanit))
print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_labels_sanit.shape, valid_labels_sanit.shape)
print('Testing:', test_dataset_sanit.shape, test_labels_sanit.shape)
pickle_file_sanit = 'notMNIST_sanit.pickle'

try:
    f = open(pickle_file_sanit, 'wb')
    save = {
        'train_dataset': train_dataset,
        'train_labels': train_labels,
        'valid_dataset': valid_dataset_sanit,
        'valid_labels': valid_labels_sanit,
        'test_dataset': test_dataset_sanit,
        'test_labels': test_labels_sanit,
    }
    pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
    f.close()
except Exception as e:
    print('Unable to save data to', pickle_file, ':', e)
    raise

statinfo = os.stat(pickle_file_sanit)
print('Compressed pickle size:', statinfo.st_size)

得到

Number of overlaps: 1324
Overlapping images removed from test_dataset:  1324
Overlapping images removed from valid_dataset:  1067
Training: (200000, 28, 28) (200000,)
Validation: (8933,) (8933,)
Testing: (8676, 28, 28) (8676,)

题目6:基线模型

有许多ML类库提供开箱即用的分类器,请挑选一个作为基线模型,并在不同训练集规模下评测性能。

这里用sklearn.linear_model 的 LogisticRegression:

def disp_sample_dataset(dataset, labels, title=None):
    fig = plt.figure()
    if title: fig.suptitle(title, fontsize=16, fontweight='bold')
    items = random.sample(range(len(labels)), 8)
    for i, item in enumerate(items):
        plt.subplot(2, 4, i + 1)
        plt.axis('off')
        plt.title(chr(ord('A') + labels[item]))
        plt.imshow(dataset[item])
    plt.show()


def train_and_predict(sample_size):
    regr = LogisticRegression()
    X_train = train_dataset[:sample_size].reshape(sample_size, 784)
    y_train = train_labels[:sample_size]
    regr.fit(X_train, y_train)

    X_test = test_dataset.reshape(test_dataset.shape[0], 28 * 28)
    y_test = test_labels

    pred_labels = regr.predict(X_test)

    print('Accuracy:', regr.score(X_test, y_test), 'when sample_size=', sample_size)
    disp_sample_dataset(test_dataset, pred_labels, 'sample_size=' + str(sample_size))


for sample_size in [50, 100, 1000, 5000, len(train_dataset)]:
    train_and_predict(sample_size)

得到:

Accuracy: 0.509 when sample_size= 50
Accuracy: 0.6966 when sample_size= 100
Accuracy: 0.8333 when sample_size= 1000
Accuracy: 0.8511 when sample_size= 5000

其中全集训练太耗时了,没跑出结果。事实上,5000个训练实例已经能拿到85%的accuracy,对开箱即用的sklearn而言相当可以了。

可视化一些预测结果吧:

50.png

100.png

1000.png

5000.png

Reference

https://github.com/Arn-O/udacity-deep-learning/blob/master/1_notmnist.ipynb 

知识共享许可协议 知识共享署名-非商业性使用-相同方式共享码农场 » 谷歌深度学习公开课任务 1: notMNIST

评论 3

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址
  1. #2

    hankcs,问个问题

    为什么在验证归一化图像的时候,将pickle文件可视化后会出现汉字还有一些奇奇怪怪的字母?

    李彬7年前 (2017-04-10)回复
  2. #1

    楼主,代码中url地址能连上去吗?我跑代码的时候提示 Error 101 : Network is unreachable..可能是在内地链接不上谷歌的

    Lam7年前 (2017-02-14)回复
    • 我也run过,同样的错误,我的解决办法是将数据集下载下来即可

      李彬7年前 (2017-04-10)回复

我的作品

HanLP自然语言处理包《自然语言处理入门》