谷歌深度学习公开课任务 1: notMNIST

**注：**本课程是机器学习工程师纳米学位的一部分。学习本课程的前提是：你已完成机器学习入门课程，并且至少熟悉监督学习方法。

可以用以下命令将 notebook 导出为 Python 脚本：`jupyter nbconvert --to python *.ipynb`

任务简介

notMNIST

Logistic regression on top of a stacked auto-encoder with fine-tuning gets about 89% accuracy on notMNIST, whereas the same approach gets 98% on MNIST.

题目1：可视化数据

```import random
import matplotlib.image as mpimg

def plot_samples(data_folders, sample_size, title=None):
    """Show ``sample_size`` randomly chosen images from each folder in a grid.

    The grid has one row per folder (in ``data_folders`` order) and one column
    per sampled image.

    Args:
        data_folders: list of directory paths. Every entry listed by
            ``os.listdir`` is loaded with ``matplotlib.image.imread``, so the
            folders are presumably expected to contain only image files.
        sample_size: number of files drawn (without replacement) from each
            folder; ``random.sample`` raises ValueError if a folder has fewer.
        title: optional bold figure title.
    """
    fig = plt.figure()
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold')
    n_rows = len(data_folders)
    for row, folder in enumerate(data_folders):
        image_sample = random.sample(os.listdir(folder), sample_size)
        for col, image in enumerate(image_sample):
            image_file = os.path.join(folder, image)
            ax = fig.add_subplot(n_rows, sample_size, row * sample_size + col + 1)
            # Load the pixel data first: imshow() cannot render a bare file name.
            ax.imshow(mpimg.imread(image_file))
            ax.set_axis_off()
    plt.show()

# Preview 10 random images from each folder in train_folders, then test_folders.
plot_samples(train_folders, 10, 'Train Folders')
plot_samples(test_folders, 10, 'Test Folders')```

题目2：预处理后的可视化

def load_and_display_pickle(datasets, sample_size, title=None):
    """Unpickle each dataset file, show its first images, and count them.

    Args:
        datasets: list of pickle file paths. Each file is unpickled and its
            contents are iterated as a sequence of images (TODO confirm:
            presumably an array of shape (n_images, height, width)).
        sample_size: number of leading images displayed from each file
            (one figure row per file).
        title: optional bold figure title.

    Returns:
        list with ``len(data)`` for each file, in ``datasets`` order. Before
        returning, this list is passed to ``balance_check``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on pickle
    files you produced yourself.
    """
    import pickle

    fig = plt.figure()
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold')
    num_of_images = []
    for row, pickle_file in enumerate(datasets):
        with open(pickle_file, 'rb') as f:
            # The original never assigned `data`, which raised NameError.
            data = pickle.load(f)
        print('Total images in', pickle_file, ':', len(data))

        for index, image in enumerate(data):
            if index == sample_size:
                break
            ax = fig.add_subplot(len(datasets), sample_size, row * sample_size + index + 1)
            ax.imshow(image)
            ax.set_axis_off()

        num_of_images.append(len(data))

    balance_check(num_of_images)
    plt.show()
    return num_of_images

题目3：均衡校验

def generate_fake_label(sizes):
    """Build a label vector in which class ``k`` appears ``sizes[k]`` times in a row.

    Example: ``sizes=[2, 3]`` gives ``[0, 0, 1, 1, 1]``.

    Returns:
        1-D ``np.int32`` array of length ``sum(sizes)``.
    """
    class_ids = np.arange(len(sizes), dtype=np.int32)
    repeat_counts = np.asarray(sizes, dtype=np.intp)
    return np.repeat(class_ids, repeat_counts)

def _plot_label_histogram(ax, labels, title):
    """Draw a one-bar-per-class histogram of integer ``labels`` on ``ax``, ticked with letters."""
    # Unit-wide bins covering every class id from labels.min() to labels.max().
    bins = np.arange(labels.min(), labels.max() + 2)
    ax.hist(labels, bins=bins)
    # set_xticks() + set_xticklabels() works on all matplotlib versions. Before
    # 3.5, set_xticks' second positional argument was the `minor` flag, so
    # passing the letters there silently dropped them.
    ax.set_xticks((bins[:-1] + bins[1:]) / 2)
    # Derive one letter per bin (0 -> 'A', ...) so labels always match the bars.
    ax.set_xticklabels([chr(ord("A") + int(k)) for k in bins[:-1]])
    ax.set_title(title)


def plot_balance():
    """Plot side-by-side class histograms of the global train and test labels.

    Reads the module-level ``train_labels`` and ``test_labels`` arrays
    (integer class ids) and shows the figure.
    """
    _, axes = plt.subplots(1, 2)
    _plot_label_histogram(axes[0], train_labels, "Training data")
    _plot_label_histogram(axes[1], test_labels, "Test data")
    plt.show()

def mean(numbers):
    """Return the arithmetic mean of ``numbers`` as a float, or 0.0 for an empty input.

    The divisor is clamped to at least 1 so that an empty sequence yields
    0.0 instead of raising ZeroDivisionError.
    """
    count = len(numbers)
    total = float(sum(numbers))
    return total / (count if count else 1)

def balance_check(sizes):
    """Print the mean of ``sizes``, then one verdict line per entry.

    An entry further than 10% of the mean away from the mean is reported as
    unbalanced; otherwise it is printed as well balanced together with its value.
    """
    mean_val = mean(sizes)
    tolerance = 0.1 * mean_val
    print('mean of # images :', mean_val)
    for size in sizes:
        if abs(size - mean_val) > tolerance:
            print("Too much or less images")
            continue
        print("Well balanced", size)

# Count the images in each pickle (also previewing 10 of each), then build label
# arrays in which the k-th pickle's images all get class id k.
test_labels = generate_fake_label(load_and_display_pickle(test_datasets, 10, 'Test Datasets'))
train_labels = generate_fake_label(load_and_display_pickle(train_datasets, 10, 'Train Datasets'))

# Histogram of the labels just built (plot_balance reads these globals).
plot_balance()```

题目4：打乱

def plot_sample_dataset(dataset, labels, title):
    """Show 12 randomly chosen images from ``dataset`` in a 3x4 grid.

    Each subplot is titled with the letter for its label (0 -> 'A', 1 -> 'B', ...),
    and ``title`` is drawn in bold above the grid.
    """
    plt.suptitle(title, fontsize=16, fontweight='bold')
    picks = random.sample(range(len(labels)), 12)
    for slot, idx in enumerate(picks, start=1):
        plt.subplot(3, 4, slot)
        plt.axis('off')
        plt.title(chr(ord('A') + labels[idx]))
        plt.imshow(dataset[idx])
    plt.show()

# Show 12 random labelled samples from each split (the titles indicate the data
# has been shuffled at this point), then re-plot the label histograms.
# NOTE(review): 'suffled' in the titles below looks like a typo for 'shuffled'.
plot_sample_dataset(train_dataset, train_labels, 'train dataset suffled')
plot_sample_dataset(valid_dataset, valid_labels, 'valid dataset suffled')
plot_sample_dataset(test_dataset, test_labels, 'test dataset suffled')
plot_balance()```

题目5：重复样本

```import hashlib

def extract_overlap_hash_where(dataset_1, dataset_2):
    """Find items of ``dataset_1`` whose raw bytes also occur in ``dataset_2``.

    Items are compared by the SHA-256 digest of their buffer, so each element
    must support the buffer protocol (e.g. rows of a C-contiguous NumPy array).

    Returns:
        dict mapping each index ``i`` of ``dataset_1`` that has at least one
        match to a NumPy integer array of the matching indices in
        ``dataset_2``, in ascending order. Indices without matches are omitted.
    """
    # Index dataset_2 once (hash -> positions). The previous version compared
    # every dataset_1 hash against all dataset_2 hashes, which is O(n * m).
    positions_by_hash = {}
    for j, img in enumerate(dataset_2):
        positions_by_hash.setdefault(hashlib.sha256(img).hexdigest(), []).append(j)

    overlap = {}
    for i, img in enumerate(dataset_1):
        matches = positions_by_hash.get(hashlib.sha256(img).hexdigest())
        if matches:
            overlap[i] = np.array(matches, dtype=np.intp)
    return overlap

def display_overlap(overlap, source_dataset, target_dataset, min_duplicates=3):
    """Plot one randomly chosen heavily duplicated image next to its copies.

    Args:
        overlap: dict mapping an index into ``source_dataset`` to an array of
            matching indices into ``target_dataset`` (the format returned by
            ``extract_overlap_hash_where``).
        source_dataset: indexable image collection the keys refer to.
        target_dataset: image array indexable with an integer index array.
        min_duplicates: only entries with at least this many matches are
            candidates. Defaults to 3, the previously hard-coded value.

    The source image is drawn first, followed by up to 7 of its matches, in a
    2x4 grid titled with the source index. If no entry qualifies, a message is
    printed and nothing is plotted. Previously ``random.choice`` raised
    IndexError on the empty candidate list.
    """
    candidates = [k for k, v in overlap.items() if len(v) >= min_duplicates]
    if not candidates:
        print('No image has at least', min_duplicates, 'duplicates; nothing to display.')
        return
    item = random.choice(candidates)
    imgs = np.concatenate(([source_dataset[item]], target_dataset[overlap[item][0:7]]))
    plt.suptitle(item)
    for i, img in enumerate(imgs):
        plt.subplot(2, 4, i + 1)
        plt.axis('off')
        plt.imshow(img)
    plt.show()

def sanitize(dataset_1, dataset_2, labels_1):
    """Drop every item of ``dataset_1`` (and its label) whose raw bytes occur in ``dataset_2``.

    Items are compared by the SHA-256 digest of their buffer.

    Returns:
        ``(dataset, labels)``: new NumPy arrays holding the remaining items of
        ``dataset_1`` and ``labels_1``, with their original order preserved.
        ``labels_1`` is deleted with ``axis=None`` and is therefore flattened.
    """
    # Hash dataset_2 once into a set: O(n + m) membership tests instead of
    # scanning every dataset_2 hash for each dataset_1 item (O(n * m)).
    hashes_2 = {hashlib.sha256(img).hexdigest() for img in dataset_2}
    overlap = np.array(
        [i for i, img in enumerate(dataset_1) if hashlib.sha256(img).hexdigest() in hashes_2],
        dtype=np.intp,
    )
    return np.delete(dataset_1, overlap, 0), np.delete(labels_1, overlap, None)

# Count test images that are byte-identical to some training image and show one.
overlap_test_train = extract_overlap_hash_where(test_dataset, train_dataset)
print('Number of overlaps:', len(overlap_test_train))
display_overlap(overlap_test_train, test_dataset, train_dataset)

# Remove test/validation images that duplicate a training image.
test_dataset_sanit, test_labels_sanit = sanitize(test_dataset, train_dataset, test_labels)
print('Overlapping images removed from test_dataset: ', len(test_dataset) - len(test_dataset_sanit))
valid_dataset_sanit, valid_labels_sanit = sanitize(valid_dataset, train_dataset, valid_labels)
print('Overlapping images removed from valid_dataset: ', len(valid_dataset) - len(valid_dataset_sanit))
print('Training:', train_dataset.shape, train_labels.shape)
# Print the dataset shape first (the original printed the labels' shape twice).
print('Validation:', valid_dataset_sanit.shape, valid_labels_sanit.shape)
print('Testing:', test_dataset_sanit.shape, test_labels_sanit.shape)
pickle_file_sanit = 'notMNIST_sanit.pickle'

save = {
    'train_dataset': train_dataset,
    'train_labels': train_labels,
    'valid_dataset': valid_dataset_sanit,
    'valid_labels': valid_labels_sanit,
    'test_dataset': test_dataset_sanit,
    'test_labels': test_labels_sanit,
}
try:
    # `with` closes the file even if pickle.dump fails part-way.
    with open(pickle_file_sanit, 'wb') as f:
        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
except Exception as e:
    # Report the file actually being written (was the unrelated `pickle_file`).
    print('Unable to save data to', pickle_file_sanit, ':', e)
    raise

# Report the on-disk size in bytes of the sanitized pickle file.
# NOTE(review): pickle.dump writes it uncompressed, so "Compressed" in the message is a misnomer.
statinfo = os.stat(pickle_file_sanit)
print('Compressed pickle size:', statinfo.st_size)```

```Number of overlaps: 1324
Overlapping images removed from test_dataset:  1324
Overlapping images removed from valid_dataset:  1067
Training: (200000, 28, 28) (200000,)
Validation: (8933,) (8933,)
Testing: (8676, 28, 28) (8676,)```

题目6：基线模型

def disp_sample_dataset(dataset, labels, title=None):
    """Show 8 randomly chosen images from ``dataset`` in a 2x4 grid on a new figure.

    Each subplot is titled with the letter for its label (0 -> 'A', ...).
    If ``title`` is given, it becomes a bold figure title.
    """
    figure = plt.figure()
    if title:
        figure.suptitle(title, fontsize=16, fontweight='bold')
    chosen = random.sample(range(len(labels)), 8)
    for pos, idx in enumerate(chosen, start=1):
        plt.subplot(2, 4, pos)
        plt.axis('off')
        plt.title(chr(ord('A') + labels[idx]))
        plt.imshow(dataset[idx])
    plt.show()

def train_and_predict(sample_size):
    """Fit a LogisticRegression on the first ``sample_size`` training images and evaluate it.

    Trains on the module-level ``train_dataset``/``train_labels`` and evaluates
    on ``test_dataset``/``test_labels``. It prints the test accuracy and shows
    8 test images titled with their *predicted* letters.
    NOTE(review): the sanitized ``test_dataset_sanit`` may be the intended
    evaluation set; confirm.
    """
    regr = LogisticRegression()
    train_subset = train_dataset[:sample_size]
    # Flatten each image into one feature vector. Using the actual subset length
    # and the images' own size (instead of sample_size and a hard-coded 784)
    # avoids a reshape error when sample_size exceeds the number of images.
    n_features = int(np.prod(train_dataset.shape[1:]))
    X_train = train_subset.reshape(len(train_subset), n_features)
    y_train = train_labels[:sample_size]
    regr.fit(X_train, y_train)

    X_test = test_dataset.reshape(test_dataset.shape[0], -1)
    y_test = test_labels

    pred_labels = regr.predict(X_test)
    # Accuracy from the predictions already computed; regr.score() would predict again.
    accuracy = np.mean(pred_labels == y_test)

    print('Accuracy:', accuracy, 'when sample_size=', sample_size)
    disp_sample_dataset(test_dataset, pred_labels, 'sample_size=' + str(sample_size))

# Train on progressively larger prefixes of the training set; the last run
# uses all of train_dataset.
for sample_size in [50, 100, 1000, 5000, len(train_dataset)]:
train_and_predict(sample_size)```

```Accuracy: 0.509 when sample_size= 50
Accuracy: 0.6966 when sample_size= 100
Accuracy: 0.8333 when sample_size= 1000
Accuracy: 0.8511 when sample_size= 5000```

Reference

评论 3

1. #2

hankcs，问个问题

为什么在验证归一化图像的时候，将pickle文件可视化后会出现汉字还有一些奇奇怪怪的字母？

李彬3年前 (2017-04-10)回复
2. #1

楼主，代码中url地址能连上去吗？我跑代码的时候提示 Error 101 : Network is unreachable..可能是在内地连接不上谷歌的

Lam3年前 (2017-02-14)回复
• 我也run过，同样的错误，我的解决办法是将数据集下载下来即可

李彬3年前 (2017-04-10)回复