プログラミング備忘録

初級プログラマ。python、DL勉強中

python+kerasを使って、MNISTをautoencoderで学習・推論

今更ながらautoencoderを実装してみた。
dataはMINISTを使用

ソース

import keras
from keras.models import load_model
from keras.models import Model
from keras.datasets import mnist
from keras.layers import Input, Dense
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

# MNIST
(x_train, _), (x_test, _) = mnist.load_data()
x_train, x_valid = train_test_split(x_train, test_size=0.175)
x_train = x_train.astype('float32')/255.
x_valid = x_valid.astype('float32')/255.
x_test = x_test.astype('float32')/255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_valid = x_valid.reshape((len(x_valid), np.prod(x_valid.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

#入力
input_img = Input(shape=(784, ))

#encoder
encoding_dim = 32
encoded = Dense(encoding_dim, activation='relu')(input_img)

#decoder
decoded = Dense(784, activation='sigmoid')(encoded)
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')


# train
autoencoder.fit(x_train, x_train,
                epochs=50,
                batch_size=256,
                shuffle=True,
                validation_data=(x_valid, x_valid))

# predict
out_img=autoencoder.predict(x_test)

# output
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):

    ax = plt.subplot(2, n, i+1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    ax = plt.subplot(2, n, i+1+n)
    plt.imshow(out_img[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

学習は入力と出力に同じデータを設定する
中間層は入力、出力層より絞る必要があり、今回は32、16、8でトライ

結果

  • 中間層32 f:id:programdl:20181110184607p:plain

  • 中間層16 f:id:programdl:20181110184621p:plain

  • 中間層8 f:id:programdl:20181110184631p:plain

中間層を少なくすると、出力画像がかなり粗くなる
中間層を絞ることで特徴量を抽出するのがautoencoderのポイントだが、
絞りすぎると返って特徴の抽出が厳しくなるように見える

参考

Building Autoencoders in Keras

KerasでAutoEncoder - Qiita