Programming Notes

A beginner programmer, studying Python and deep learning.

Training and inference on CIFAR-10 with a convolutional autoencoder in Keras

Last time I trained and ran inference on MNIST with a simple autoencoder;
this time I tried training and inference on CIFAR-10 with a convolutional autoencoder.

programdl.hatenablog.com

Code

from keras.datasets import cifar10
from keras.layers import Input, Conv2D, Activation
from keras.layers import MaxPooling2D, UpSampling2D, BatchNormalization
from keras.models import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint
import os
import numpy as np
import matplotlib.pyplot as plt

batch_size = 32
epochs = 10
saveDir = "./model/"

if not os.path.isdir(saveDir):
    os.makedirs(saveDir)

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

x_val = x_test[:9000]
x_test = x_test[9000:]

def showPic(orig, dec, num=5):
    n = num
    plt.figure(figsize=(10, 4))

    for i in range(n):
        # display original
        ax = plt.subplot(2, n, i+1)
        plt.imshow(orig[i])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display reconstruction
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(dec[i])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

def encoder(input_img):
    x = Conv2D(64, (3, 3), padding='same')(input_img)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = Conv2D(32, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = Conv2D(16, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    encoded = MaxPooling2D((2, 2), padding='same')(x)
    return encoded

def decoder(input_img, encoded):
    x = Conv2D(16, (3, 3), padding='same')(encoded)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(32, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(64, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(3, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    decoded = Activation('sigmoid')(x)
    model = Model(input_img, decoded)
    return model

def autoencoder():
    input_img = Input(shape=(32, 32, 3))
    encoded = encoder(input_img)
    autoencoded = decoder(input_img, encoded)
    return autoencoded

# train
model = autoencoder()
model.compile(optimizer='adam', loss='binary_crossentropy')
es_cb = EarlyStopping(monitor='val_loss', patience=2, verbose=1, mode='auto')
chkpt = saveDir + 'AE_Cifar10.{epoch:02d}-{loss:.2f}-{val_loss:.2f}.hdf5'
cp_cb = ModelCheckpoint(filepath=chkpt, monitor='val_loss',
                        verbose=1, save_best_only=True, mode='auto')
history = model.fit(x_train, x_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_val, x_val),
                    callbacks=[es_cb, cp_cb],
                    shuffle=True)

# predict
c10test = model.predict(x_test)
showPic(x_test, c10test) 

Apparently it is better to define the encoder and decoder as separate functions,
but since I wasn't sure what the right split was, I implemented it somewhat roughly.
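One common way to split things cleanly (a sketch of my own, not from the original code) is to build the full autoencoder once, give the bottleneck layer a name, and wrap a second `Model` around that layer's output. The standalone encoder then reuses the autoencoder's layers, so it automatically shares the trained weights. The layer sizes below are trimmed down for brevity; the `bottleneck` name is an arbitrary label.

```python
from keras.layers import Input, Conv2D, Activation, MaxPooling2D, UpSampling2D
from keras.models import Model

def build_autoencoder():
    input_img = Input(shape=(32, 32, 3))
    # encoder half, with a named bottleneck layer
    x = Conv2D(16, (3, 3), padding='same')(input_img)
    x = Activation('relu')(x)
    encoded = MaxPooling2D((2, 2), padding='same', name='bottleneck')(x)
    # decoder half
    x = Conv2D(16, (3, 3), padding='same')(encoded)
    x = Activation('relu')(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(3, (3, 3), padding='same')(x)
    decoded = Activation('sigmoid')(x)
    autoencoder = Model(input_img, decoded)
    # the standalone encoder shares the autoencoder's layers and weights
    encoder = Model(input_img, autoencoder.get_layer('bottleneck').output)
    return autoencoder, encoder

autoencoder, encoder = build_autoencoder()
print(encoder.output_shape)  # (None, 16, 16, 16)
```

After training `autoencoder`, calling `encoder.predict(...)` gives the latent representations without any extra weight copying.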

Results

  • 10 epochs f:id:programdl:20181110194911p:plain

  • 24 epochs (early stopping) f:id:programdl:20181110194927p:plain

The 24-epoch result looks slightly sharper around the outlines. Whether this quality is good enough to apply to tasks such as anomaly detection still needs to be verified.
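For anomaly detection, the usual idea is to score each image by how badly the autoencoder reconstructs it: a model trained only on normal data should reconstruct anomalies poorly. A minimal NumPy sketch (my own illustration, using toy arrays in place of `x_test` and `model.predict(x_test)`):

```python
import numpy as np

def reconstruction_scores(x, x_rec):
    # per-sample mean squared error: average over height, width and channels
    # so each image gets a single anomaly score
    return np.mean((x - x_rec) ** 2, axis=(1, 2, 3))

# toy data standing in for x_test and the autoencoder's reconstructions
x = np.zeros((4, 32, 32, 3), dtype='float32')
x_rec = x.copy()
x_rec[0] += 0.5  # pretend sample 0 reconstructs badly

scores = reconstruction_scores(x, x_rec)
print(scores.argmax())  # -> 0: the worst-reconstructed image
```

In practice one would compute these scores on held-out normal data first and pick a threshold (e.g. a high percentile) above which an image is flagged as anomalous.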

References

Image reconstruction, denoising, and segmentation with convolutional autoencoders (畳み込みオートエンコーダによる画像の再現、ノイズ除去、セグメンテーション) - Qiita