CNN Filtrelerini Görselleştirme

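Amacımız, rastgele (random) üretilen bir resmi daha önce MNIST verisiyle eğitilmiş bir modele girdi olarak vererek, çıktı olarak istediğimiz rakamı üretmek. Bunun için girdi resmini, hedef rakamın çıktısını artıracak yönde gradyan yükselmesi (gradient ascent) ile adım adım güncelleyeceğiz. Sonunda da modelin öğrendiği özellikleri görselleştireceğiz.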

   def Mymodel(x=None, y=None, validation_data=None, epochs=12, batch_size=128):
       """Build, compile and train the small MNIST CNN used in this post.

       All arguments are optional; when omitted, the module-level X_train,
       Y_train, X_test and Y_test arrays are used, so the original zero-argument
       call ``Mymodel()`` keeps working.

       Args:
           x: training images, shaped (N, 28, 28, 1) to match ``input_shape``.
           y: integer class labels 0-9 (not one-hot), because the loss is
               ``sparse_categorical_crossentropy``.
           validation_data: ``(images, labels)`` tuple passed to ``model.fit``.
           epochs: number of training epochs (default 12, as before).
           batch_size: mini-batch size (default 128, as before).

       Returns:
           The trained ``Sequential`` model. Layer order matters to callers:
           ``model.layers[9]`` is the ``Dense(10)`` layer whose pre-softmax
           output ``GenerateImage`` maximizes.
       """
       # Fall back to the module-level arrays prepared by the script.
       if x is None:
           x = X_train
       if y is None:
           y = Y_train
       if validation_data is None:
           validation_data = (X_test, Y_test)

       model = Sequential()
       # The first convolution needs its own ReLU; otherwise two linear
       # convolutions are stacked back to back. Passing it as a kwarg (rather
       # than a separate Activation layer) keeps the layer indices unchanged.
       model.add(Convolution2D(32, 3, activation='relu', input_shape=(28, 28, 1)))
       model.add(Convolution2D(64, 3))
       model.add(Activation('relu'))
       model.add(MaxPooling2D(pool_size=(2, 2)))
       model.add(Dropout(0.25))

       model.add(Flatten())

       model.add(Dense(128))
       model.add(Activation('relu'))
       model.add(Dropout(0.5))

       model.add(Dense(10))  # layers[9]: class scores before softmax
       model.add(Activation(tf.nn.softmax))

       optimizer = keras.optimizers.Adadelta()
       model.compile(loss='sparse_categorical_crossentropy',
                     optimizer=optimizer, metrics=['accuracy'])
       model.fit(x, y, epochs=epochs, batch_size=batch_size,
                 validation_data=validation_data)

       return model
def GenerateImage(model, y_true, img, lr, iterate, layer_index=9):
    """Gradient-ascend an input image so the model's output for ``y_true`` grows.

    Args:
        model: trained Keras model; ``model.layers[layer_index].output`` must be
            indexable by class on its last axis.
        y_true: index of the output unit (digit) whose activation is maximized.
        img: starting image batch, presumably shaped like the model input
            (e.g. (1, 28, 28, 1)) -- TODO confirm for other models. It is not
            modified.
        lr: step size applied to the normalized gradient at every iteration.
        iterate: number of gradient-ascent steps.
        layer_index: layer whose output is maximized. The default of 9 keeps
            the original behaviour (the Dense(10) pre-softmax layer in Mymodel).

    Returns:
        (result, imgAll): ``result[i]`` is ``model.predict`` on the image after
        step ``i + 1``, and ``imgAll[i]`` is that image (a distinct array per
        step).
    """
    output = model.layers[layer_index].output
    # Objective: mean activation of the target unit over the batch.
    loss = K.mean(output[:, y_true])
    grads = K.gradients(loss, model.input)[0]
    # Normalize by the gradient's RMS so lr alone controls the step size;
    # 1e-5 guards against division by zero.
    grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-5)
    train = K.function([model.input], [loss, grads])
    # NOTE(review): these run after K.function is built and probably target the
    # same backend twice; kept to preserve the original call sequence.
    K.set_learning_phase(False)
    keras.backend.set_learning_phase(0)

    result = []
    imgAll = []
    for _ in range(iterate):
        _, grads_value = train([img])
        # Rebind instead of updating in place: ``img += ...`` would mutate the
        # caller's array and make every entry of imgAll alias the final image.
        img = img + grads_value * lr
        result.append(model.predict(img))
        imgAll.append(img)

    return result, imgAll

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

# Add the trailing channel axis expected by input_shape=(28, 28, 1) and scale
# the pixel values by 1/255 as float32.
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255
X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255

# One-hot copies of the labels. The visible training code uses the integer
# labels (Y_train / Y_test) with sparse_categorical_crossentropy instead.
y_train = keras.utils.to_categorical(Y_train, 10)
y_test = keras.utils.to_categorical(Y_test, 10)

# Train on a small subset to keep the run short.
# NOTE(review): slicing starts at index 1, so sample 0 is skipped
# (2999 / 1999 samples) -- probably [:3000] / [:2000] was intended.
X_train, Y_train = X_train[1:3000], Y_train[1:3000]
X_test, Y_test = X_test[1:2000], Y_test[1:2000]

model = Mymodel()

Train on 2999 samples, validate on 1999 samples
Epoch 1/12
2999/2999 [==============================] - 5s - loss: 1.2724 - acc: 0.6255 - val_loss: 1.1250 - val_acc: 0.5718
Epoch 2/12
2999/2999 [==============================] - 4s - loss: 0.4874 - acc: 0.8453 - val_loss: 0.4232 - val_acc: 0.8674
Epoch 3/12
2999/2999 [==============================] - 4s - loss: 0.2596 - acc: 0.9240 - val_loss: 0.3513 - val_acc: 0.8869
Epoch 4/12
2999/2999 [==============================] - 4s - loss: 0.2157 - acc: 0.9373 - val_loss: 0.3305 - val_acc: 0.8979
Epoch 5/12
2999/2999 [==============================] - 4s - loss: 0.1557 - acc: 0.9590 - val_loss: 0.2923 - val_acc: 0.9070
Epoch 6/12
2999/2999 [==============================] - 4s - loss: 0.1248 - acc: 0.9643 - val_loss: 0.2511 - val_acc: 0.9210
Epoch 7/12
2999/2999 [==============================] - 4s - loss: 0.0971 - acc: 0.9737 - val_loss: 0.2351 - val_acc: 0.9275
Epoch 8/12
2999/2999 [==============================] - 4s - loss: 0.0783 - acc: 0.9797 - val_loss: 0.2009 - val_acc: 0.9375
Epoch 9/12
2999/2999 [==============================] - 4s - loss: 0.0604 - acc: 0.9840 - val_loss: 0.2175 - val_acc: 0.9295
Epoch 10/12
2999/2999 [==============================] - 4s - loss: 0.0467 - acc: 0.9880 - val_loss: 0.1793 - val_acc: 0.9465
Epoch 11/12
2999/2999 [==============================] - 5s - loss: 0.0331 - acc: 0.9927 - val_loss: 0.1752 - val_acc: 0.9480
Epoch 12/12
2999/2999 [==============================] - 4s - loss: 0.0288 - acc: 0.9943 - val_loss: 0.1800 - val_acc: 0.9435

# Random starting image, displayed before optimization.
img = np.random.random((1, 28, 28, 1))
fig = plt.figure()
plt.imshow(img.reshape((28, 28)), interpolation='nearest')
plt.title("Üretilen Random Resim")

# Optimize the image shown above toward digit 2. (The original drew a second,
# different random image here, so the displayed start was not the one used.)
# A copy is passed so the displayed array can never be altered.
lr = 0.1
iterate = 20
y_true = 2
result, generate_image = GenerateImage(model, y_true, img.copy(), lr, iterate)

fig = plt.figure()
plt.imshow(generate_image[9].reshape((28, 28)), interpolation='nearest')
plt.title("iki")

# Steps 10..18 of the ascent in a 3x3 grid, on a figure of their own:
# plt.subplot on the previous figure would remove/overlap the "iki" image.
fig = plt.figure()
for i, step_image in enumerate(generate_image[9:18]):
    plt.subplot(3, 3, i + 1)
    plt.imshow(step_image.reshape(28, 28), interpolation='nearest')
plt.show()

Bir cevap yazın

E-posta hesabınız yayımlanmayacak. Gerekli alanlar * ile işaretlenmişlerdir