python+kerasを使って、UNETを学習・推論(作業中)
autoencoder,caeと試してきたので、次はunetを触ってみた
イメージ
ソース
"""Train a U-Net as a convolutional autoencoder on CIFAR-10 with Keras.

The network is trained to reconstruct its own input (targets == inputs),
using a Dice-coefficient loss, with early stopping and best-only
checkpointing. After training, test-set reconstructions are visualized
next to the originals together with their intensity histograms.
"""
import keras
from keras.models import load_model, Model
from keras.datasets import cifar10
from keras.layers import (Input, Dense, Conv2D, MaxPooling2D, UpSampling2D,
                          Conv2DTranspose, concatenate)
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam
from keras import backend as K
import os
import pickle
import numpy as np
import cv2

batch_size = 50
num_classes = 10
epochs = 100

# Checkpoints are written here; create the directory on first run.
saveDir = "./model/"
if not os.path.isdir(saveDir):
    os.makedirs(saveDir)

# CIFAR-10: 50k train / 10k test RGB images (32x32), rescaled to [0, 1].
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# Carve the 10k official test images into 7000 validation / 3000 test.
x_val = x_test[:7000]
x_test = x_test[7000:]


def showOrigDec(orig, dec, num=10):
    """Plot `num` originals, reconstructions, and their histograms.

    Four rows: original images, reconstructions, a grayscale histogram
    of each original, and a per-channel histogram of each original.

    Args:
        orig: array of original images in [0, 1], shape (N, 32, 32, 3).
        dec: array of reconstructed images, same shape as `orig`.
        num: number of columns (images) to display.
    """
    import matplotlib.pyplot as plt
    n = num
    plt.figure(figsize=(20, 8))
    for i in range(n):
        # Row 1: original image.
        ax = plt.subplot(4, n, i + 1)
        plt.imshow(orig[i])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Row 2: reconstruction.
        ax = plt.subplot(4, n, i + 1 + n)
        plt.imshow(dec[i])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Row 3: flattened intensity histogram of the original (0-255).
        ax = plt.subplot(4, n, i + 1 + 2 * n)
        plt.hist((orig[i] * 255).ravel(), 256, [0, 256])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Row 4: per-channel histogram via OpenCV.
        # NOTE(review): the input stays float32 here; cv2.calcHist accepts
        # CV_32F, but converting to uint8 would be the more common usage.
        ax = plt.subplot(4, n, i + 1 + 3 * n)
        color = ('b', 'g', 'r')
        for j, col in enumerate(color):
            histr = cv2.calcHist([(orig[i] * 255)], [j], None, [256], [0, 256])
            plt.plot(histr, color=col)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()


def dice_coef(y_true, y_pred, smooth=1):
    """Dice coefficient between two tensors, flattened over all axes.

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor.
        smooth: additive smoothing term; keeps the ratio defined (and
            gradient finite) when both tensors are all zeros.

    Returns:
        Scalar tensor: (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth).
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """Negated Dice coefficient, so that minimizing it maximizes overlap."""
    return -dice_coef(y_true, y_pred)


def UNET():
    """Build a U-Net for 32x32x3 inputs producing a same-shape sigmoid output.

    Encoder: four 2x downsampling stages (32→64→128→256 filters) plus a
    512-filter bottleneck. Decoder: four Conv2DTranspose upsampling stages,
    each concatenated with the matching encoder feature map (skip
    connections) before two 3x3 convolutions.

    Returns:
        An uncompiled keras `Model`.
    """
    inputs = Input((32, 32, 3))

    # --- Encoder ---
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # --- Bottleneck ---
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    # --- Decoder with skip connections (concatenate on the channel axis) ---
    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2),
                                       padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2),
                                       padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2),
                                       padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2),
                                       padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    # 3-channel sigmoid head so the output matches the [0, 1] RGB input.
    conv10 = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])
    return model


model = UNET()
model.compile(optimizer='adam', loss=dice_coef_loss, metrics=[dice_coef])

# Stop when validation loss stops improving; keep only the best weights.
es_cb = EarlyStopping(monitor='val_loss', patience=4, verbose=1, mode='auto')
chkpt = saveDir + 'AE_UNET.{epoch:02d}-{loss:.2f}-{val_loss:.2f}.hdf5'
cp_cb = ModelCheckpoint(filepath=chkpt, monitor='val_loss',
                        verbose=1, save_best_only=True, mode='auto')

# Autoencoder-style training: inputs are also the targets.
history = model.fit(x_train, x_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_val, x_val),
                    callbacks=[es_cb, cp_cb],
                    shuffle=True)

c10test = model.predict(x_test)
c10val = model.predict(x_val)
showOrigDec(x_test, c10test)
参考
A 2017 Guide to Semantic Segmentation with Deep Learning
【Python】 KerasでU-Net構造ネットワークによるセグメンテーションをする - 旅行好きなソフトエンジニアの備忘録
[1505.04597] U-Net: Convolutional Networks for Biomedical Image Segmentation