from keras.layers import Input, Conv2D, UpSampling2D,\
MaxPooling2D, Dense, Reshape, Flatten, Conv2DTranspose
from keras.models import Model
import tensorflow_probability as tfp
tfd = tfp.distributions
import tensorflow as tf
import keras as K
import pickle
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from abc import abstractmethod
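Not part of the original run, but worth doing at the top of a kernel like this: seeding the RNGs makes reruns roughly reproducible (GPU nondeterminism aside). A minimal sketch, assuming the TF 1.x API this kernel uses:
np.random.seed(42)      # seed NumPy's global RNG
tf.set_random_seed(42)  # TF 1.x graph-level seed (affects weight init)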
with open('../input/textures_42000_28px.pkl', 'rb') as f:
data = pickle.load(f)
BATCH_SIZE = 128
X_train = data['train_images']
X_test = data['test_images']
# Shuffle: train_test_split with test_size=0 is used purely as a shuffle;
# the empty "test" split is discarded.
X_train, _ = train_test_split(X_train, test_size=0, random_state=45)
X_test, _ = train_test_split(X_test, test_size=0, random_state=42)
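On recent scikit-learn versions test_size=0 raises an error; an equivalent shuffle without the dummy split (a sketch using sklearn.utils.shuffle, not what the original kernel ran):
from sklearn.utils import shuffle
X_train = shuffle(X_train, random_state=45)
X_test = shuffle(X_test, random_state=42)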
Using TensorFlow backend.
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
* https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.
class AutoEncoder:
def __init__(self, input_shape, latent_dim):
self.input_shape = input_shape
self.latent_dim = latent_dim
@abstractmethod
def _encoder(self, input_img):
pass
@abstractmethod
def _decoder(self, latent):
pass
@abstractmethod
def get_compiled_model(self, loss_fn=None):
pass
def _get_loss(self, loss_fn):
        if loss_fn is None:
            return self._bernoulli
        elif loss_fn == "binary":
            return self._binary
        elif loss_fn == "normal":
            return self._normal
        elif loss_fn == "normalDiag":
            return self._normalDiag
        raise ValueError("Unknown loss_fn: {}".format(loss_fn))
"""
For binarized input
"""
    def _binary(self, x_true, x_reco):
        # sigmoid_cross_entropy_with_logits is already a loss to minimize,
        # so it must not be negated; reduce to a scalar for Keras.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=x_true, logits=x_reco))
    def _bernoulli(self, x_true, x_reco):
        # Negative log-likelihood of a Bernoulli over the raw decoder logits.
        return -tf.reduce_mean(tfd.Bernoulli(logits=x_reco).log_prob(x_true))
"""
For non binarized input.
"""
    def _normal(self, x_true, x_reco):
        # Negative log-likelihood under an isotropic Normal with a tiny fixed
        # scale; the squared-error term dominates, hence the large loss values.
        return -tf.reduce_mean(
            tfd.Normal(loc=x_reco, scale=0.001).log_prob(x_true))
    def _normalDiag(self, x_true, x_reco):
        # Same idea with a learnable scalar multiplier on the identity scale
        # (note: the tf.Variable is created anew each time the loss is built).
        return -tf.reduce_mean(
            tfd.MultivariateNormalDiag(
                loc=x_reco,
                scale_identity_multiplier=tf.Variable(0.001)).log_prob(x_true))
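A note on the loss magnitudes reported later: with scale=0.001 the Normal negative log-likelihood is dominated by 0.5 * ((x - x_hat) / 0.001)^2 per pixel, so losses in the thousands correspond to small pixel errors. A quick closed-form check in NumPy (my own sanity check, not part of the kernel):
sigma = 1e-3
# Per-pixel Gaussian NLL: 0.5*((x - mu)/sigma)**2 + log(sigma) + 0.5*log(2*pi)
gauss_nll = lambda x, mu: (0.5 * ((x - mu) / sigma) ** 2
                           + np.log(sigma) + 0.5 * np.log(2 * np.pi))
print(gauss_nll(0.5, 0.5 - 0.127))  # ~8058, i.e. an RMS pixel error of ~0.13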
################################################################################
class ConvolutionalAutoEncoder(AutoEncoder):
def _encoder(self, input_img):
x = Conv2D(512, (2, 2), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(256, (2, 2), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(128, (2, 2), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
encoded = Flatten()(x)
return encoded
    def _decoder(self, latent):
        # Assumes latent_dim == 4 * 4 * 8 = 128 so this reshape is valid.
        x = Reshape((4, 4, 8))(latent)
x = Conv2DTranspose(128, (2, 2), strides=(2, 2), activation='relu', padding='same')(x)
x = Conv2D(256, (2, 2), activation='relu')(x)
x = Conv2DTranspose(512, (2, 2), strides=(2, 2), activation='relu', padding='same')(x)
reco = Conv2DTranspose(1, (2, 2), strides=(2, 2))(x)
return reco
def get_compiled_model(self, loss_fn=None):
self.loss_fn = super()._get_loss(loss_fn)
print(self.loss_fn)
input_img = Input(shape=self.input_shape)
encoded = self._encoder(input_img)
latent = Dense(self.latent_dim)(encoded)
reco = self._decoder(latent)
model = Model(input_img, reco)
model.compile(optimizer='adadelta', loss=self.loss_fn)
return model
################################################################################
################################################################################
class ConvolutionalAutoEncoderCIFAR(AutoEncoder):
def _encoder(self, input_img):
x = Conv2D(512, (2, 2), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(256, (2, 2), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(128, (2, 2), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
encoded = Flatten()(x)
return encoded
    def _decoder(self, latent):
        # Assumes latent_dim == 4 * 4 * 16 = 256 so this reshape is valid.
        x = Reshape((4, 4, 16))(latent)
x = Conv2DTranspose(256, (2, 2), strides=(2, 2), activation='relu', padding='same')(x)
x = Conv2DTranspose(512, (2, 2), strides=(2, 2), activation='relu', padding='same')(x)
x = Conv2DTranspose(1024, (2, 2), strides=(2, 2), activation='relu', padding='same')(x)
reco = Conv2D(3, (2, 2), padding='same')(x)
return reco
def get_compiled_model(self, loss_fn=None):
self.loss_fn = super()._get_loss(loss_fn)
input_img = Input(shape=self.input_shape)
encoded = self._encoder(input_img)
latent = Dense(self.latent_dim)(encoded)
reco = self._decoder(latent)
model = Model(input_img, reco)
model.compile(optimizer='adadelta', loss=self.loss_fn)
return model
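The CIFAR variant above is defined but never exercised in this kernel. A minimal usage sketch (my assumption: 32x32 RGB inputs, with latent_dim equal to 4*4*16 = 256 to match the decoder's Reshape):
# Sketch only: 32 -> 16 -> 8 -> 4 through the three poolings, then
# 4 -> 8 -> 16 -> 32 through the three stride-2 transposed convolutions.
cifar_ae = ConvolutionalAutoEncoderCIFAR(input_shape=(32, 32, 3),
                                         latent_dim=4 * 4 * 16)
cifar_model = cifar_ae.get_compiled_model("normal")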
################################################################################
################################################################################
class DenseAutoEncoder(AutoEncoder):
def _encoder(self, input_tensor):
x = Dense(512, activation='relu')(input_tensor)
x = Dense(256, activation='relu')(x)
return x
def _decoder(self, latent):
x = Dense(256, activation='relu')(latent)
x = Dense(512, activation='relu')(x)
reco = Dense(self.input_shape[1])(x)
return reco
def get_compiled_model(self, loss_fn=None):
self.loss_fn = super()._get_loss(loss_fn)
input_tensor = Input(shape=(self.input_shape[1],))
encoded = self._encoder(input_tensor)
latent = Dense(self.latent_dim)(encoded)
reco = self._decoder(latent)
model = Model(input_tensor, reco)
model.compile(optimizer='adam', loss=self.loss_fn)
return model
################################################################################
from keras.preprocessing.image import ImageDataGenerator
class DataGenerator:
def __init__(self, train, test, BATCH_SIZE=128, IMAGE_SHAPE=(28, 28, 1)):
self.DATAGEN = ImageDataGenerator()
self.IMAGE_SHAPE = IMAGE_SHAPE
self.BATCH_SIZE = BATCH_SIZE
self.train = train
self.test = test
        self._train = train.reshape(train.shape[0], *IMAGE_SHAPE)
        self._test = test.reshape(test.shape[0], *IMAGE_SHAPE)
def flow(self):
return self.DATAGEN.flow(self._train, self._train, batch_size=self.BATCH_SIZE)
def flatten_flow(self):
def train_generator(_it):
image_dim = self.IMAGE_SHAPE[0]*self.IMAGE_SHAPE[1]*self.IMAGE_SHAPE[2]
while True:
batch_x, batch_y = next(_it)
yield batch_x.reshape(batch_x.shape[0], image_dim), batch_y.reshape(batch_y.shape[0], image_dim)
return train_generator(self.flow())
    def validation_data(self):
        return self._test, self._test
    def flattened_validation_data(self):
        # Assumes the arrays passed to __init__ are already flat (N, H*W*C).
        return self.test, self.test
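A quick smoke test of the generator contract on random arrays (hypothetical data, just to show the shapes produced by flow() versus flatten_flow()):
_dg = DataGenerator(np.random.rand(256, 784).astype('float32'),
                    np.random.rand(64, 784).astype('float32'))
bx, _ = next(_dg.flow())           # image batch: (128, 28, 28, 1)
fx, _ = next(_dg.flatten_flow())   # flat batch:  (128, 784)
print(bx.shape, fx.shape)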
textures = DataGenerator(X_train, X_test)
flow = textures.flow()
flatten_flow = textures.flatten_flow()
validation_data = textures.validation_data()
flattened_validation_data = textures.flattened_validation_data()
autoencoder = ConvolutionalAutoEncoder(input_shape=(28, 28, 1), latent_dim=4*4*8)
model = autoencoder.get_compiled_model("normal")
model.summary()
<bound method AutoEncoder._normal of <__main__.ConvolutionalAutoEncoder object at 0x7f70d93d16a0>>
WARNING:tensorflow:From /opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 28, 28, 1) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 28, 28, 512) 2560
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 14, 14, 256) 524544
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 7, 7, 256) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 7, 7, 128) 131200
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 4, 4, 128) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 2048) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 262272
_________________________________________________________________
reshape_1 (Reshape) (None, 4, 4, 8) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 8, 8, 128) 4224
_________________________________________________________________
conv2d_4 (Conv2D) (None, 7, 7, 256) 131328
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 14, 14, 512) 524800
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 28, 28, 1) 2049
=================================================================
Total params: 1,582,977
Trainable params: 1,582,977
Non-trainable params: 0
_________________________________________________________________
history = model.fit_generator(
flow,
steps_per_epoch=250, verbose=1, epochs=120, validation_data=validation_data)
Epoch 1/120
250/250 [==============================] - 11s 45ms/step - loss: 23619.4608 - val_loss: 13064.3834
Epoch 2/120
250/250 [==============================] - 9s 34ms/step - loss: 12936.3868 - val_loss: 12984.5468
Epoch 3/120
250/250 [==============================] - 8s 34ms/step - loss: 12494.9925 - val_loss: 12053.7161
Epoch 4/120
250/250 [==============================] - 8s 34ms/step - loss: 11815.8750 - val_loss: 11714.9668
Epoch 5/120
250/250 [==============================] - 9s 34ms/step - loss: 11446.3867 - val_loss: 11660.5768
Epoch 6/120
250/250 [==============================] - 8s 34ms/step - loss: 11086.7619 - val_loss: 11206.2935
Epoch 7/120
250/250 [==============================] - 8s 34ms/step - loss: 10698.1091 - val_loss: 10769.3128
Epoch 8/120
250/250 [==============================] - 8s 34ms/step - loss: 10439.2475 - val_loss: 10465.4964
Epoch 9/120
250/250 [==============================] - 8s 34ms/step - loss: 10146.3715 - val_loss: 10258.0282
Epoch 10/120
250/250 [==============================] - 8s 34ms/step - loss: 10054.5833 - val_loss: 9868.2628
Epoch 11/120
250/250 [==============================] - 8s 34ms/step - loss: 9772.5510 - val_loss: 9884.0684
Epoch 12/120
250/250 [==============================] - 8s 34ms/step - loss: 9650.4927 - val_loss: 9731.5621
Epoch 13/120
250/250 [==============================] - 8s 34ms/step - loss: 9595.5500 - val_loss: 9677.3107
Epoch 14/120
250/250 [==============================] - 9s 34ms/step - loss: 9493.5600 - val_loss: 9483.5515
Epoch 15/120
250/250 [==============================] - 9s 34ms/step - loss: 9375.6692 - val_loss: 9489.3944
Epoch 16/120
250/250 [==============================] - 8s 34ms/step - loss: 9261.4709 - val_loss: 9552.0440
Epoch 17/120
250/250 [==============================] - 8s 34ms/step - loss: 9237.6051 - val_loss: 9599.7306
Epoch 18/120
250/250 [==============================] - 8s 34ms/step - loss: 9124.3825 - val_loss: 9267.0974
Epoch 19/120
250/250 [==============================] - 8s 34ms/step - loss: 9079.5191 - val_loss: 9225.4408
Epoch 20/120
250/250 [==============================] - 9s 34ms/step - loss: 8894.8275 - val_loss: 9194.9212
Epoch 21/120
250/250 [==============================] - 9s 34ms/step - loss: 8922.2388 - val_loss: 8982.7503
Epoch 22/120
250/250 [==============================] - 8s 34ms/step - loss: 8761.3968 - val_loss: 8985.6998
Epoch 23/120
250/250 [==============================] - 8s 34ms/step - loss: 8692.8831 - val_loss: 8927.6003
Epoch 24/120
250/250 [==============================] - 9s 34ms/step - loss: 8759.2916 - val_loss: 8826.8335
Epoch 25/120
250/250 [==============================] - 8s 34ms/step - loss: 8559.3473 - val_loss: 8922.2421
Epoch 26/120
250/250 [==============================] - 8s 34ms/step - loss: 8602.5747 - val_loss: 8893.8486
Epoch 27/120
250/250 [==============================] - 8s 34ms/step - loss: 8562.6724 - val_loss: 8792.2157
Epoch 28/120
250/250 [==============================] - 8s 34ms/step - loss: 8473.0045 - val_loss: 8770.5060
Epoch 29/120
250/250 [==============================] - 8s 34ms/step - loss: 8493.5419 - val_loss: 8907.6355
Epoch 30/120
250/250 [==============================] - 8s 34ms/step - loss: 8388.3239 - val_loss: 8865.3505
Epoch 31/120
250/250 [==============================] - 8s 34ms/step - loss: 8408.6520 - val_loss: 8824.8138
Epoch 32/120
250/250 [==============================] - 8s 34ms/step - loss: 8378.0836 - val_loss: 8846.5811
Epoch 33/120
250/250 [==============================] - 9s 34ms/step - loss: 8372.1065 - val_loss: 8744.0001
Epoch 34/120
250/250 [==============================] - 9s 34ms/step - loss: 8229.4310 - val_loss: 8696.3063
Epoch 35/120
250/250 [==============================] - 8s 34ms/step - loss: 8342.9508 - val_loss: 8964.1496
Epoch 36/120
250/250 [==============================] - 8s 34ms/step - loss: 8245.0309 - val_loss: 8712.2397
Epoch 37/120
250/250 [==============================] - 8s 34ms/step - loss: 8210.9562 - val_loss: 8708.8152
Epoch 38/120
250/250 [==============================] - 8s 34ms/step - loss: 8244.6897 - val_loss: 8769.0445
Epoch 39/120
250/250 [==============================] - 8s 34ms/step - loss: 8173.0771 - val_loss: 8714.7328
Epoch 40/120
250/250 [==============================] - 8s 34ms/step - loss: 8174.5877 - val_loss: 8679.3584
Epoch 41/120
250/250 [==============================] - 8s 34ms/step - loss: 8101.5290 - val_loss: 8665.1504
Epoch 42/120
250/250 [==============================] - 8s 34ms/step - loss: 8149.5159 - val_loss: 8740.8374
Epoch 43/120
250/250 [==============================] - 9s 34ms/step - loss: 8096.0216 - val_loss: 8755.7126
Epoch 44/120
250/250 [==============================] - 8s 34ms/step - loss: 8076.7973 - val_loss: 8738.7180
Epoch 45/120
250/250 [==============================] - 8s 34ms/step - loss: 8065.7741 - val_loss: 8895.6037
Epoch 46/120
250/250 [==============================] - 8s 34ms/step - loss: 8037.6332 - val_loss: 8693.5458
Epoch 47/120
250/250 [==============================] - 8s 34ms/step - loss: 7969.5045 - val_loss: 8657.9726
Epoch 48/120
142/250 [================>.............] - ETA: 3s - loss: 8082.0893
# "Accuracy"
plt.title("Convolutional model")
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
reco = model.predict(X_train[:BATCH_SIZE].reshape(BATCH_SIZE, 28, 28, 1))
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(7, 7))
for ind, ax in enumerate(axes.flatten()):
if ind % 2 == 0:
ax.set_title('Reconstructed')
ax.imshow(reco[ind].reshape(28, 28))
else:
ax.set_title(' <- Original')
ax.imshow(X_train[ind - 1].reshape(28, 28))
fig.tight_layout()
plt.show()
autoencoder = DenseAutoEncoder(input_shape=(None, 28*28*1), latent_dim=128)
model = autoencoder.get_compiled_model("normal")
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) (None, 784) 0
_________________________________________________________________
dense_2 (Dense) (None, 512) 401920
_________________________________________________________________
dense_3 (Dense) (None, 256) 131328
_________________________________________________________________
dense_4 (Dense) (None, 128) 32896
_________________________________________________________________
dense_5 (Dense) (None, 256) 33024
_________________________________________________________________
dense_6 (Dense) (None, 512) 131584
_________________________________________________________________
dense_7 (Dense) (None, 784) 402192
=================================================================
Total params: 1,132,944
Trainable params: 1,132,944
Non-trainable params: 0
_________________________________________________________________
history = model.fit_generator(
flatten_flow,
steps_per_epoch=450, verbose=1, epochs=120, validation_data=flattened_validation_data)
Epoch 1/120
450/450 [==============================] - 3s 6ms/step - loss: 13660.7373 - val_loss: 12234.5759
Epoch 2/120
450/450 [==============================] - 2s 5ms/step - loss: 11909.9524 - val_loss: 11495.2333
Epoch 3/120
450/450 [==============================] - 2s 5ms/step - loss: 11203.8561 - val_loss: 11164.1901
Epoch 4/120
450/450 [==============================] - 2s 5ms/step - loss: 10724.8050 - val_loss: 10600.2170
Epoch 5/120
450/450 [==============================] - 2s 5ms/step - loss: 10409.8693 - val_loss: 10335.3044
Epoch 6/120
450/450 [==============================] - 2s 5ms/step - loss: 10218.1451 - val_loss: 10181.2365
Epoch 7/120
450/450 [==============================] - 2s 5ms/step - loss: 10010.8329 - val_loss: 10118.8608
Epoch 8/120
450/450 [==============================] - 2s 5ms/step - loss: 9843.1701 - val_loss: 9902.0219
Epoch 9/120
450/450 [==============================] - 2s 5ms/step - loss: 9792.8706 - val_loss: 9827.7078
Epoch 10/120
450/450 [==============================] - 2s 5ms/step - loss: 9650.9894 - val_loss: 9882.1016
Epoch 11/120
450/450 [==============================] - 2s 5ms/step - loss: 9616.6779 - val_loss: 9713.2675
Epoch 12/120
450/450 [==============================] - 2s 5ms/step - loss: 9569.2232 - val_loss: 9704.1599
Epoch 13/120
450/450 [==============================] - 2s 5ms/step - loss: 9505.5969 - val_loss: 9658.7047
Epoch 14/120
450/450 [==============================] - 2s 5ms/step - loss: 9442.9852 - val_loss: 9652.5928
Epoch 15/120
450/450 [==============================] - 2s 5ms/step - loss: 9460.0144 - val_loss: 9623.0287
Epoch 16/120
450/450 [==============================] - 2s 5ms/step - loss: 9386.0842 - val_loss: 9632.4656
Epoch 17/120
450/450 [==============================] - 2s 5ms/step - loss: 9350.1733 - val_loss: 9601.1257
Epoch 18/120
450/450 [==============================] - 2s 5ms/step - loss: 9316.0165 - val_loss: 9606.2534
Epoch 19/120
450/450 [==============================] - 2s 5ms/step - loss: 9281.1093 - val_loss: 9607.0830
Epoch 20/120
450/450 [==============================] - 2s 5ms/step - loss: 9275.0234 - val_loss: 9620.9025
Epoch 21/120
450/450 [==============================] - 2s 5ms/step - loss: 9215.7923 - val_loss: 9602.3895
Epoch 22/120
450/450 [==============================] - 2s 5ms/step - loss: 9179.6025 - val_loss: 9603.2651
Epoch 23/120
450/450 [==============================] - 2s 5ms/step - loss: 9203.7196 - val_loss: 9597.2188
Epoch 24/120
450/450 [==============================] - 2s 5ms/step - loss: 9156.7819 - val_loss: 9614.7970
Epoch 25/120
450/450 [==============================] - 2s 5ms/step - loss: 9161.8855 - val_loss: 9633.5151
Epoch 26/120
450/450 [==============================] - 2s 5ms/step - loss: 9101.4849 - val_loss: 9617.0719
Epoch 27/120
450/450 [==============================] - 2s 5ms/step - loss: 9126.8162 - val_loss: 9605.8102
Epoch 28/120
450/450 [==============================] - 2s 5ms/step - loss: 9074.2612 - val_loss: 9615.1227
Epoch 29/120
450/450 [==============================] - 2s 5ms/step - loss: 9088.8181 - val_loss: 9577.7517
Epoch 30/120
450/450 [==============================] - 2s 5ms/step - loss: 9053.9263 - val_loss: 9617.8864
Epoch 31/120
450/450 [==============================] - 2s 5ms/step - loss: 9068.4025 - val_loss: 9612.6874
Epoch 32/120
450/450 [==============================] - 2s 5ms/step - loss: 9027.4757 - val_loss: 9606.7799
Epoch 33/120
450/450 [==============================] - 2s 5ms/step - loss: 9061.8093 - val_loss: 9614.7357
Epoch 34/120
450/450 [==============================] - 2s 6ms/step - loss: 9000.1418 - val_loss: 9612.5115
Epoch 35/120
450/450 [==============================] - 3s 6ms/step - loss: 9029.9241 - val_loss: 9627.1765
Epoch 36/120
450/450 [==============================] - 2s 6ms/step - loss: 8992.8389 - val_loss: 9627.1121
Epoch 37/120
450/450 [==============================] - 2s 5ms/step - loss: 8985.4371 - val_loss: 9584.7681
Epoch 38/120
450/450 [==============================] - 2s 5ms/step - loss: 8977.5217 - val_loss: 9579.5192
Epoch 39/120
450/450 [==============================] - 2s 5ms/step - loss: 8982.8762 - val_loss: 9629.7528
Epoch 40/120
450/450 [==============================] - 2s 5ms/step - loss: 9008.0115 - val_loss: 9640.0305
Epoch 41/120
450/450 [==============================] - 2s 5ms/step - loss: 8938.5614 - val_loss: 9612.6247
Epoch 42/120
450/450 [==============================] - 2s 5ms/step - loss: 8982.7584 - val_loss: 9643.7563
Epoch 43/120
450/450 [==============================] - 2s 5ms/step - loss: 8925.7008 - val_loss: 9601.1343
Epoch 44/120
450/450 [==============================] - 2s 5ms/step - loss: 8942.8681 - val_loss: 9607.0021
Epoch 45/120
450/450 [==============================] - 2s 5ms/step - loss: 8915.8623 - val_loss: 9637.9321
Epoch 46/120
450/450 [==============================] - 2s 5ms/step - loss: 8973.4503 - val_loss: 9626.5003
Epoch 47/120
450/450 [==============================] - 2s 5ms/step - loss: 8896.2719 - val_loss: 9636.9351
Epoch 48/120
450/450 [==============================] - 2s 5ms/step - loss: 8928.6129 - val_loss: 9616.5152
Epoch 49/120
450/450 [==============================] - 2s 5ms/step - loss: 8912.7885 - val_loss: 9631.0610
Epoch 50/120
450/450 [==============================] - 2s 5ms/step - loss: 8911.1050 - val_loss: 9696.2987
Epoch 51/120
450/450 [==============================] - 2s 5ms/step - loss: 8882.1004 - val_loss: 9652.0001
Epoch 52/120
450/450 [==============================] - 2s 5ms/step - loss: 8923.3460 - val_loss: 9628.9417
Epoch 53/120
450/450 [==============================] - 2s 5ms/step - loss: 8888.6763 - val_loss: 9650.4985
Epoch 54/120
450/450 [==============================] - 2s 5ms/step - loss: 8890.2076 - val_loss: 9662.9747
Epoch 55/120
450/450 [==============================] - 2s 5ms/step - loss: 8876.1870 - val_loss: 9628.4428
Epoch 56/120
450/450 [==============================] - 2s 5ms/step - loss: 8863.6003 - val_loss: 9612.9577
Epoch 57/120
450/450 [==============================] - 2s 5ms/step - loss: 8847.7704 - val_loss: 9639.4333
Epoch 58/120
450/450 [==============================] - 2s 5ms/step - loss: 8875.4535 - val_loss: 9635.2878
Epoch 59/120
450/450 [==============================] - 2s 5ms/step - loss: 8840.6928 - val_loss: 9636.2642
Epoch 60/120
450/450 [==============================] - 2s 5ms/step - loss: 8819.1659 - val_loss: 9612.3059
Epoch 61/120
450/450 [==============================] - 2s 5ms/step - loss: 8836.8522 - val_loss: 9639.3892
Epoch 62/120
450/450 [==============================] - 2s 5ms/step - loss: 8844.5712 - val_loss: 9622.2806
Epoch 63/120
450/450 [==============================] - 2s 5ms/step - loss: 8821.7494 - val_loss: 9621.6694
Epoch 64/120
450/450 [==============================] - 2s 5ms/step - loss: 8815.1585 - val_loss: 9595.4954
Epoch 65/120
450/450 [==============================] - 2s 5ms/step - loss: 8833.7130 - val_loss: 9624.7817
Epoch 66/120
450/450 [==============================] - 2s 5ms/step - loss: 8812.2914 - val_loss: 9629.3812
Epoch 67/120
450/450 [==============================] - 2s 5ms/step - loss: 8813.4848 - val_loss: 9646.8900
Epoch 68/120
450/450 [==============================] - 2s 5ms/step - loss: 8771.5191 - val_loss: 9629.7945
Epoch 69/120
450/450 [==============================] - 2s 5ms/step - loss: 8823.6780 - val_loss: 9692.3803
Epoch 70/120
450/450 [==============================] - 2s 5ms/step - loss: 8798.5820 - val_loss: 9655.7915
Epoch 71/120
450/450 [==============================] - 2s 5ms/step - loss: 8827.3114 - val_loss: 9655.3335
Epoch 72/120
450/450 [==============================] - 2s 5ms/step - loss: 8766.0854 - val_loss: 9720.3519
Epoch 73/120
450/450 [==============================] - 2s 5ms/step - loss: 8823.4842 - val_loss: 9676.8129
Epoch 74/120
450/450 [==============================] - 2s 5ms/step - loss: 8763.5335 - val_loss: 9654.2177
Epoch 75/120
450/450 [==============================] - 2s 5ms/step - loss: 8811.3815 - val_loss: 9641.6519
Epoch 76/120
450/450 [==============================] - 2s 5ms/step - loss: 8768.8431 - val_loss: 9670.0294
Epoch 77/120
450/450 [==============================] - 2s 5ms/step - loss: 8804.1887 - val_loss: 9617.8878
Epoch 78/120
450/450 [==============================] - 2s 5ms/step - loss: 8773.5281 - val_loss: 9618.0616
Epoch 79/120
450/450 [==============================] - 2s 5ms/step - loss: 8753.8751 - val_loss: 9637.1529
Epoch 80/120
450/450 [==============================] - 2s 5ms/step - loss: 8796.6136 - val_loss: 9649.2580
Epoch 81/120
450/450 [==============================] - 2s 5ms/step - loss: 8753.0756 - val_loss: 9658.9972
Epoch 82/120
450/450 [==============================] - 2s 5ms/step - loss: 8778.9247 - val_loss: 9659.6327
Epoch 83/120
450/450 [==============================] - 2s 5ms/step - loss: 8751.4601 - val_loss: 9659.8732
Epoch 84/120
450/450 [==============================] - 2s 5ms/step - loss: 8783.7770 - val_loss: 9658.2439
Epoch 85/120
450/450 [==============================] - 2s 5ms/step - loss: 8753.4256 - val_loss: 9659.1800
Epoch 86/120
450/450 [==============================] - 2s 5ms/step - loss: 8735.5612 - val_loss: 9658.8878
Epoch 87/120
450/450 [==============================] - 2s 5ms/step - loss: 8764.9705 - val_loss: 9668.0582
Epoch 88/120
450/450 [==============================] - 2s 5ms/step - loss: 8771.7495 - val_loss: 9678.0606
Epoch 89/120
450/450 [==============================] - 2s 5ms/step - loss: 8749.6746 - val_loss: 9653.9383
Epoch 90/120
450/450 [==============================] - 2s 5ms/step - loss: 8753.7380 - val_loss: 9652.8087
Epoch 91/120
450/450 [==============================] - 2s 5ms/step - loss: 8722.6513 - val_loss: 9674.6652
Epoch 92/120
450/450 [==============================] - 2s 5ms/step - loss: 8749.0611 - val_loss: 9707.6248
Epoch 93/120
450/450 [==============================] - 2s 5ms/step - loss: 8756.7330 - val_loss: 9723.0908
Epoch 94/120
450/450 [==============================] - 2s 5ms/step - loss: 8736.2706 - val_loss: 9719.4223
Epoch 95/120
450/450 [==============================] - 2s 5ms/step - loss: 8733.7426 - val_loss: 9708.7910
Epoch 96/120
450/450 [==============================] - 2s 5ms/step - loss: 8758.1080 - val_loss: 9667.8734
Epoch 97/120
450/450 [==============================] - 2s 5ms/step - loss: 8713.1392 - val_loss: 9691.3214
Epoch 98/120
450/450 [==============================] - 2s 5ms/step - loss: 8762.0678 - val_loss: 9652.0038
Epoch 99/120
450/450 [==============================] - 2s 5ms/step - loss: 8709.4888 - val_loss: 9710.7233
Epoch 100/120
450/450 [==============================] - 2s 5ms/step - loss: 8755.1775 - val_loss: 9726.5732
Epoch 101/120
450/450 [==============================] - 2s 5ms/step - loss: 8722.7630 - val_loss: 9712.9317
Epoch 102/120
450/450 [==============================] - 2s 5ms/step - loss: 8737.0626 - val_loss: 9730.7912
Epoch 103/120
450/450 [==============================] - 2s 5ms/step - loss: 8708.2567 - val_loss: 9671.1690
Epoch 104/120
450/450 [==============================] - 2s 5ms/step - loss: 8743.8494 - val_loss: 9690.0946
Epoch 105/120
450/450 [==============================] - 2s 5ms/step - loss: 8730.4965 - val_loss: 9691.3483
Epoch 106/120
450/450 [==============================] - 2s 5ms/step - loss: 8714.1750 - val_loss: 9680.5736
Epoch 107/120
450/450 [==============================] - 2s 5ms/step - loss: 8715.7814 - val_loss: 9730.1071
Epoch 108/120
450/450 [==============================] - 2s 5ms/step - loss: 8731.5582 - val_loss: 9698.8861
Epoch 109/120
450/450 [==============================] - 2s 5ms/step - loss: 8706.9562 - val_loss: 9718.0676
Epoch 110/120
450/450 [==============================] - 2s 5ms/step - loss: 8696.3995 - val_loss: 9720.3038
Epoch 111/120
450/450 [==============================] - 2s 5ms/step - loss: 8712.9437 - val_loss: 9742.4841
Epoch 112/120
450/450 [==============================] - 2s 5ms/step - loss: 8700.8108 - val_loss: 9753.1207
Epoch 113/120
450/450 [==============================] - 2s 5ms/step - loss: 8751.4395 - val_loss: 9713.1487
Epoch 114/120
450/450 [==============================] - 2s 5ms/step - loss: 8714.1742 - val_loss: 9681.2843
Epoch 115/120
450/450 [==============================] - 2s 5ms/step - loss: 8711.6666 - val_loss: 9724.4919
Epoch 116/120
450/450 [==============================] - 2s 5ms/step - loss: 8693.2260 - val_loss: 9727.7688
Epoch 117/120
450/450 [==============================] - 2s 5ms/step - loss: 8735.7899 - val_loss: 9714.4226
Epoch 118/120
450/450 [==============================] - 2s 5ms/step - loss: 8678.0076 - val_loss: 9734.9756
Epoch 119/120
450/450 [==============================] - 2s 5ms/step - loss: 8731.5442 - val_loss: 9700.1280
Epoch 120/120
450/450 [==============================] - 2s 5ms/step - loss: 8662.2679 - val_loss: 9717.2447
# "Accuracy"
plt.title("Dense model")
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
reco = model.predict(X_train[:BATCH_SIZE])
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(7, 7))
for ind, ax in enumerate(axes.flatten()):
if ind % 2 == 0:
ax.set_title('Reconstructed')
ax.imshow(reco[ind].reshape(28, 28))
else:
ax.set_title(' <- Original')
ax.imshow(X_train[ind - 1].reshape(28, 28))
fig.tight_layout()
plt.show()
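For a number directly comparable to the val_loss column above, the dense texture model can also be scored on the held-out set (a sketch; evaluate returns the same "normal" NLL the model was compiled with):
val_x, val_y = textures.flattened_validation_data()
print(model.evaluate(val_x, val_y, verbose=0))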
##############################################################################
##############################################################################
##############################################################################
###########################################################################
######################## MNIST ###########################
###########################################################################
from keras.datasets.mnist import load_data
(X_train, _), (X_test, _) = load_data()
X_train = X_train.reshape(X_train.shape[0], 784)
X_test = X_test.reshape(X_test.shape[0], 784)
X_train = X_train / 255.
X_test = X_test / 255.
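A quick sanity check (my addition, not in the original run): the Normal reconstruction loss with its tiny fixed scale implicitly assumes targets scaled to [0, 1]:
print(X_train.shape, X_train.dtype, X_train.min(), X_train.max())
assert X_train.min() >= 0.0 and X_train.max() <= 1.0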
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
11493376/11490434 [==============================] - 1s 0us/step
mnist = DataGenerator(X_train, X_test)
autoencoder = DenseAutoEncoder(input_shape=(None, 28*28), latent_dim=16)
model = autoencoder.get_compiled_model("normal")
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) (None, 784) 0
_________________________________________________________________
dense_8 (Dense) (None, 512) 401920
_________________________________________________________________
dense_9 (Dense) (None, 256) 131328
_________________________________________________________________
dense_10 (Dense) (None, 16) 4112
_________________________________________________________________
dense_11 (Dense) (None, 256) 4352
_________________________________________________________________
dense_12 (Dense) (None, 512) 131584
_________________________________________________________________
dense_13 (Dense) (None, 784) 402192
=================================================================
Total params: 1,075,488
Trainable params: 1,075,488
Non-trainable params: 0
_________________________________________________________________
history = model.fit_generator(
mnist.flatten_flow(),
steps_per_epoch=350, verbose=1, epochs=150, validation_data=mnist.flattened_validation_data())
Epoch 1/150
350/350 [==============================] - 2s 7ms/step - loss: 15142.9858 - val_loss: 10154.8432
Epoch 2/150
350/350 [==============================] - 2s 5ms/step - loss: 9418.3889 - val_loss: 8631.6228
Epoch 3/150
350/350 [==============================] - 2s 5ms/step - loss: 8379.0186 - val_loss: 7896.3081
Epoch 4/150
350/350 [==============================] - 2s 5ms/step - loss: 7771.9378 - val_loss: 7472.9431
Epoch 5/150
350/350 [==============================] - 2s 5ms/step - loss: 7339.4300 - val_loss: 7160.7017
Epoch 6/150
350/350 [==============================] - 2s 5ms/step - loss: 7016.3167 - val_loss: 6832.9712
Epoch 7/150
350/350 [==============================] - 2s 5ms/step - loss: 6774.9122 - val_loss: 6645.7257
Epoch 8/150
350/350 [==============================] - 2s 5ms/step - loss: 6564.5668 - val_loss: 6477.7400
Epoch 9/150
350/350 [==============================] - 2s 5ms/step - loss: 6374.8392 - val_loss: 6324.1827
Epoch 10/150
350/350 [==============================] - 2s 5ms/step - loss: 6208.7290 - val_loss: 6219.8383
Epoch 11/150
350/350 [==============================] - 2s 5ms/step - loss: 6085.0264 - val_loss: 6103.7758
Epoch 12/150
350/350 [==============================] - 2s 5ms/step - loss: 5995.3391 - val_loss: 5966.6138
Epoch 13/150
350/350 [==============================] - 2s 5ms/step - loss: 5856.9506 - val_loss: 5919.0582
Epoch 14/150
350/350 [==============================] - 2s 5ms/step - loss: 5776.9623 - val_loss: 5816.7632
Epoch 15/150
350/350 [==============================] - 2s 5ms/step - loss: 5690.6315 - val_loss: 5713.8578
Epoch 16/150
350/350 [==============================] - 2s 6ms/step - loss: 5617.9508 - val_loss: 5678.5259
Epoch 17/150
350/350 [==============================] - 2s 6ms/step - loss: 5517.5940 - val_loss: 5611.0078
Epoch 18/150
350/350 [==============================] - 2s 6ms/step - loss: 5476.8313 - val_loss: 5534.2184
Epoch 19/150
350/350 [==============================] - 2s 6ms/step - loss: 5397.5761 - val_loss: 5510.1132
Epoch 20/150
350/350 [==============================] - 2s 6ms/step - loss: 5364.1072 - val_loss: 5432.7259
Epoch 21/150
350/350 [==============================] - 2s 5ms/step - loss: 5287.6597 - val_loss: 5440.8132
Epoch 22/150
350/350 [==============================] - 2s 5ms/step - loss: 5239.5073 - val_loss: 5337.5478
Epoch 23/150
350/350 [==============================] - 2s 5ms/step - loss: 5203.0992 - val_loss: 5309.7039
Epoch 24/150
350/350 [==============================] - 2s 5ms/step - loss: 5145.2460 - val_loss: 5318.1425
Epoch 25/150
350/350 [==============================] - 2s 5ms/step - loss: 5109.3819 - val_loss: 5241.3973
Epoch 26/150
350/350 [==============================] - 2s 5ms/step - loss: 5054.7724 - val_loss: 5217.9484
Epoch 27/150
350/350 [==============================] - 2s 5ms/step - loss: 5055.7161 - val_loss: 5208.1617
Epoch 28/150
350/350 [==============================] - 2s 5ms/step - loss: 5022.2865 - val_loss: 5151.4846
Epoch 29/150
350/350 [==============================] - 2s 5ms/step - loss: 4955.6559 - val_loss: 5139.0714
Epoch 30/150
350/350 [==============================] - 2s 5ms/step - loss: 4919.5205 - val_loss: 5124.4518
Epoch 31/150
350/350 [==============================] - 2s 5ms/step - loss: 4908.8088 - val_loss: 5091.0135
Epoch 32/150
350/350 [==============================] - 2s 5ms/step - loss: 4889.2170 - val_loss: 5064.4834
Epoch 33/150
350/350 [==============================] - 2s 5ms/step - loss: 4830.6258 - val_loss: 5024.9037
Epoch 34/150
350/350 [==============================] - 2s 5ms/step - loss: 4813.9645 - val_loss: 4994.3811
Epoch 35/150
350/350 [==============================] - 2s 5ms/step - loss: 4793.0399 - val_loss: 4990.7652
Epoch 36/150
350/350 [==============================] - 2s 5ms/step - loss: 4773.4605 - val_loss: 4959.6751
Epoch 37/150
350/350 [==============================] - 2s 5ms/step - loss: 4730.6830 - val_loss: 4974.6059
Epoch 38/150
350/350 [==============================] - 2s 5ms/step - loss: 4722.5003 - val_loss: 4931.1226
Epoch 39/150
350/350 [==============================] - 2s 5ms/step - loss: 4709.6245 - val_loss: 4897.5205
Epoch 40/150
350/350 [==============================] - 2s 5ms/step - loss: 4674.4467 - val_loss: 4937.0693
Epoch 41/150
350/350 [==============================] - 2s 6ms/step - loss: 4653.7611 - val_loss: 4897.7901
Epoch 42/150
350/350 [==============================] - 2s 5ms/step - loss: 4624.2723 - val_loss: 4858.9102
Epoch 43/150
350/350 [==============================] - 2s 6ms/step - loss: 4640.7219 - val_loss: 4874.8077
Epoch 44/150
350/350 [==============================] - 2s 5ms/step - loss: 4609.1193 - val_loss: 4872.0475
Epoch 45/150
350/350 [==============================] - 2s 5ms/step - loss: 4555.4525 - val_loss: 4822.7313
Epoch 46/150
350/350 [==============================] - 2s 5ms/step - loss: 4571.8711 - val_loss: 4806.2080
Epoch 47/150
350/350 [==============================] - 2s 5ms/step - loss: 4550.2844 - val_loss: 4810.4823
Epoch 48/150
350/350 [==============================] - 2s 5ms/step - loss: 4534.4649 - val_loss: 4764.9597
Epoch 49/150
350/350 [==============================] - 2s 5ms/step - loss: 4499.4446 - val_loss: 4760.8366
Epoch 50/150
350/350 [==============================] - 2s 5ms/step - loss: 4498.6708 - val_loss: 4768.1551
Epoch 51/150
350/350 [==============================] - 2s 5ms/step - loss: 4487.0666 - val_loss: 4718.1325
Epoch 52/150
350/350 [==============================] - 2s 5ms/step - loss: 4456.1989 - val_loss: 4717.7671
Epoch 53/150
350/350 [==============================] - 2s 5ms/step - loss: 4462.3192 - val_loss: 4730.3048
Epoch 54/150
350/350 [==============================] - 2s 5ms/step - loss: 4428.9357 - val_loss: 4714.4234
Epoch 55/150
350/350 [==============================] - 2s 5ms/step - loss: 4447.7095 - val_loss: 4696.1759
Epoch 56/150
350/350 [==============================] - 2s 5ms/step - loss: 4402.7315 - val_loss: 4692.2478
Epoch 57/150
350/350 [==============================] - 2s 5ms/step - loss: 4409.0183 - val_loss: 4686.8035
Epoch 58/150
350/350 [==============================] - 2s 5ms/step - loss: 4384.0240 - val_loss: 4674.4149
Epoch 59/150
350/350 [==============================] - 2s 5ms/step - loss: 4396.2847 - val_loss: 4664.9112
Epoch 60/150
350/350 [==============================] - 2s 5ms/step - loss: 4355.7867 - val_loss: 4659.7444
Epoch 61/150
350/350 [==============================] - 2s 6ms/step - loss: 4348.8880 - val_loss: 4654.9212
Epoch 62/150
350/350 [==============================] - 2s 6ms/step - loss: 4331.1273 - val_loss: 4639.5037
Epoch 63/150
350/350 [==============================] - 2s 6ms/step - loss: 4356.6035 - val_loss: 4662.4944
Epoch 64/150
350/350 [==============================] - 2s 6ms/step - loss: 4293.6156 - val_loss: 4608.3568
Epoch 65/150
350/350 [==============================] - 2s 6ms/step - loss: 4331.4394 - val_loss: 4601.5748
Epoch 66/150
350/350 [==============================] - 2s 6ms/step - loss: 4305.1876 - val_loss: 4598.4597
Epoch 67/150
350/350 [==============================] - 2s 6ms/step - loss: 4294.3468 - val_loss: 4607.4704
Epoch 68/150
350/350 [==============================] - 2s 5ms/step - loss: 4274.7371 - val_loss: 4583.5191
Epoch 69/150
350/350 [==============================] - 2s 5ms/step - loss: 4273.1792 - val_loss: 4582.7774
Epoch 70/150
350/350 [==============================] - 2s 5ms/step - loss: 4254.7774 - val_loss: 4588.7107
Epoch 71/150
350/350 [==============================] - 2s 5ms/step - loss: 4268.2919 - val_loss: 4575.0549
Epoch 72/150
350/350 [==============================] - 2s 5ms/step - loss: 4227.4732 - val_loss: 4588.1124
Epoch 73/150
350/350 [==============================] - 2s 5ms/step - loss: 4247.2972 - val_loss: 4543.1085
Epoch 74/150
350/350 [==============================] - 2s 5ms/step - loss: 4224.8493 - val_loss: 4540.4183
Epoch 75/150
350/350 [==============================] - 2s 5ms/step - loss: 4233.6569 - val_loss: 4541.2076
Epoch 76/150
350/350 [==============================] - 2s 5ms/step - loss: 4199.9723 - val_loss: 4544.7116
Epoch 77/150
350/350 [==============================] - 2s 5ms/step - loss: 4208.2691 - val_loss: 4521.7573
Epoch 78/150
350/350 [==============================] - 2s 5ms/step - loss: 4200.9037 - val_loss: 4521.2768
Epoch 79/150
350/350 [==============================] - 2s 5ms/step - loss: 4197.6159 - val_loss: 4511.7038
Epoch 80/150
350/350 [==============================] - 2s 5ms/step - loss: 4163.2197 - val_loss: 4507.3937
Epoch 81/150
350/350 [==============================] - 2s 5ms/step - loss: 4182.1838 - val_loss: 4538.1486
Epoch 82/150
350/350 [==============================] - 2s 5ms/step - loss: 4178.9978 - val_loss: 4483.4095
Epoch 83/150
350/350 [==============================] - 2s 5ms/step - loss: 4169.9846 - val_loss: 4505.4287
Epoch 84/150
350/350 [==============================] - 2s 5ms/step - loss: 4129.6050 - val_loss: 4528.9829
Epoch 85/150
350/350 [==============================] - 2s 5ms/step - loss: 4155.2138 - val_loss: 4502.2921
Epoch 86/150
350/350 [==============================] - 2s 5ms/step - loss: 4158.6156 - val_loss: 4481.7994
Epoch 87/150
350/350 [==============================] - 2s 5ms/step - loss: 4130.7427 - val_loss: 4472.5305
Epoch 88/150
350/350 [==============================] - 2s 5ms/step - loss: 4124.1155 - val_loss: 4494.6808
Epoch 89/150
350/350 [==============================] - 2s 5ms/step - loss: 4133.3006 - val_loss: 4500.1980
Epoch 90/150
350/350 [==============================] - 2s 5ms/step - loss: 4109.9784 - val_loss: 4461.0232
Epoch 91/150
350/350 [==============================] - 2s 5ms/step - loss: 4114.9366 - val_loss: 4472.8007
Epoch 92/150
350/350 [==============================] - 2s 5ms/step - loss: 4102.4089 - val_loss: 4482.9677
Epoch 93/150
350/350 [==============================] - 2s 5ms/step - loss: 4086.8066 - val_loss: 4466.9449
Epoch 94/150
350/350 [==============================] - 2s 5ms/step - loss: 4095.2319 - val_loss: 4438.8469
Epoch 95/150
350/350 [==============================] - 2s 5ms/step - loss: 4094.4648 - val_loss: 4469.2971
Epoch 96/150
350/350 [==============================] - 2s 5ms/step - loss: 4068.9962 - val_loss: 4469.5180
Epoch 97/150
350/350 [==============================] - 2s 5ms/step - loss: 4076.9649 - val_loss: 4453.5678
Epoch 98/150
350/350 [==============================] - 2s 5ms/step - loss: 4081.3846 - val_loss: 4446.1051
Epoch 99/150
350/350 [==============================] - 2s 5ms/step - loss: 4071.5067 - val_loss: 4453.7051
Epoch 100/150
350/350 [==============================] - 2s 5ms/step - loss: 4064.1762 - val_loss: 4421.6713
Epoch 101/150
350/350 [==============================] - 2s 5ms/step - loss: 4045.6038 - val_loss: 4456.4311
Epoch 102/150
350/350 [==============================] - 2s 5ms/step - loss: 4064.4845 - val_loss: 4423.8930
Epoch 103/150
350/350 [==============================] - 2s 5ms/step - loss: 4038.6196 - val_loss: 4400.5361
Epoch 104/150
350/350 [==============================] - 2s 5ms/step - loss: 4027.4091 - val_loss: 4414.9927
Epoch 105/150
350/350 [==============================] - 2s 6ms/step - loss: 4034.2553 - val_loss: 4407.3641
Epoch 106/150
350/350 [==============================] - 2s 6ms/step - loss: 4039.9970 - val_loss: 4412.9459
Epoch 107/150
350/350 [==============================] - 2s 6ms/step - loss: 4033.0521 - val_loss: 4426.9133
Epoch 108/150
350/350 [==============================] - 2s 6ms/step - loss: 4023.4365 - val_loss: 4416.0162
Epoch 109/150
350/350 [==============================] - 2s 6ms/step - loss: 3993.4855 - val_loss: 4406.3521
Epoch 110/150
350/350 [==============================] - 2s 5ms/step - loss: 4021.4775 - val_loss: 4422.9886
Epoch 111/150
350/350 [==============================] - 2s 5ms/step - loss: 3996.0973 - val_loss: 4427.6794
Epoch 112/150
350/350 [==============================] - 2s 5ms/step - loss: 3997.3847 - val_loss: 4432.0120
Epoch 113/150
350/350 [==============================] - 2s 5ms/step - loss: 3990.9648 - val_loss: 4378.3308
Epoch 114/150
350/350 [==============================] - 2s 6ms/step - loss: 4005.8485 - val_loss: 4418.9368
Epoch 115/150
350/350 [==============================] - 2s 6ms/step - loss: 3997.0347 - val_loss: 4412.3469
Epoch 116/150
350/350 [==============================] - 2s 5ms/step - loss: 3969.9309 - val_loss: 4380.4555
Epoch 117/150
350/350 [==============================] - 2s 5ms/step - loss: 3969.1734 - val_loss: 4365.0506
Epoch 118/150
350/350 [==============================] - 2s 5ms/step - loss: 3996.6887 - val_loss: 4384.1859
Epoch 119/150
350/350 [==============================] - 2s 5ms/step - loss: 3967.5521 - val_loss: 4384.8100
Epoch 120/150
350/350 [==============================] - 2s 5ms/step - loss: 3956.5853 - val_loss: 4375.0360
Epoch 121/150
350/350 [==============================] - 2s 5ms/step - loss: 3969.0719 - val_loss: 4389.7087
Epoch 122/150
350/350 [==============================] - 2s 5ms/step - loss: 3958.9129 - val_loss: 4332.3287
Epoch 123/150
350/350 [==============================] - 2s 5ms/step - loss: 3952.7724 - val_loss: 4380.3590
Epoch 124/150
350/350 [==============================] - 2s 5ms/step - loss: 3952.8879 - val_loss: 4385.4176
Epoch 125/150
350/350 [==============================] - 2s 5ms/step - loss: 3935.7318 - val_loss: 4365.5322
Epoch 126/150
350/350 [==============================] - 2s 5ms/step - loss: 3959.4461 - val_loss: 4350.4541
Epoch 127/150
350/350 [==============================] - 2s 5ms/step - loss: 3925.5687 - val_loss: 4363.5778
Epoch 128/150
350/350 [==============================] - 2s 5ms/step - loss: 3924.8037 - val_loss: 4344.4835
Epoch 129/150
350/350 [==============================] - 2s 5ms/step - loss: 3943.3177 - val_loss: 4351.4751
Epoch 130/150
350/350 [==============================] - 2s 5ms/step - loss: 3946.9814 - val_loss: 4365.1673
Epoch 131/150
350/350 [==============================] - 2s 5ms/step - loss: 3931.2108 - val_loss: 4370.5150
Epoch 132/150
350/350 [==============================] - 2s 5ms/step - loss: 3912.0809 - val_loss: 4356.3229
Epoch 133/150
350/350 [==============================] - 2s 5ms/step - loss: 3922.0330 - val_loss: 4355.8407
Epoch 134/150
350/350 [==============================] - 2s 5ms/step - loss: 3929.2697 - val_loss: 4356.1826
Epoch 135/150
350/350 [==============================] - 2s 5ms/step - loss: 3906.8788 - val_loss: 4345.7961
Epoch 136/150
350/350 [==============================] - 2s 5ms/step - loss: 3907.6197 - val_loss: 4341.4962
Epoch 137/150
350/350 [==============================] - 2s 5ms/step - loss: 3903.8915 - val_loss: 4338.8913
Epoch 138/150
350/350 [==============================] - 2s 5ms/step - loss: 3927.9478 - val_loss: 4330.6186
Epoch 139/150
350/350 [==============================] - 2s 5ms/step - loss: 3895.5445 - val_loss: 4316.3565
Epoch 140/150
350/350 [==============================] - 2s 5ms/step - loss: 3896.4994 - val_loss: 4334.4785
Epoch 141/150
350/350 [==============================] - 2s 5ms/step - loss: 3900.6723 - val_loss: 4344.8252
Epoch 142/150
350/350 [==============================] - 2s 5ms/step - loss: 3891.5926 - val_loss: 4315.2764
Epoch 143/150
350/350 [==============================] - 2s 5ms/step - loss: 3866.5077 - val_loss: 4344.7123
Epoch 144/150
350/350 [==============================] - 2s 5ms/step - loss: 3895.1134 - val_loss: 4319.5282
Epoch 145/150
350/350 [==============================] - 2s 5ms/step - loss: 3886.9918 - val_loss: 4314.1042
Epoch 146/150
350/350 [==============================] - 2s 5ms/step - loss: 3881.7049 - val_loss: 4363.0188
Epoch 147/150
350/350 [==============================] - 2s 5ms/step - loss: 3875.8412 - val_loss: 4338.9938
Epoch 148/150
350/350 [==============================] - 2s 5ms/step - loss: 3884.6811 - val_loss: 4330.3281
Epoch 149/150
350/350 [==============================] - 2s 5ms/step - loss: 3862.4324 - val_loss: 4313.8704
Epoch 150/150
350/350 [==============================] - 2s 6ms/step - loss: 3883.1469 - val_loss: 4325.9544
print(history.history.keys())
# "Accuracy"
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
dict_keys(['val_loss', 'loss'])
reco = model.predict(X_train[:BATCH_SIZE].reshape(BATCH_SIZE, 28*28))
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(7, 7))
for ind, ax in enumerate(axes.flatten()):
if ind % 2 == 0:
ax.set_title('Reconstructed')
ax.imshow(reco[ind].reshape(28, 28))
else:
ax.set_title(' <- Original')
ax.imshow(X_train[ind - 1].reshape(28, 28))
fig.tight_layout()
plt.show()
autoencoder2 = ConvolutionalAutoEncoder(input_shape=(28, 28, 1), latent_dim=128)
model2 = autoencoder2.get_compiled_model("normal")
model2.summary()
<bound method AutoEncoder._normal of <__main__.ConvolutionalAutoEncoder object at 0x7f70787187f0>>
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_4 (InputLayer) (None, 28, 28, 1) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 28, 28, 512) 2560
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 14, 14, 256) 524544
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 7, 7, 256) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 7, 7, 128) 131200
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 4, 4, 128) 0
_________________________________________________________________
flatten_2 (Flatten) (None, 2048) 0
_________________________________________________________________
dense_14 (Dense) (None, 128) 262272
_________________________________________________________________
reshape_2 (Reshape) (None, 4, 4, 8) 0
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 8, 8, 128) 4224
_________________________________________________________________
conv2d_8 (Conv2D) (None, 7, 7, 256) 131328
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 14, 14, 512) 524800
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 28, 28, 1) 2049
=================================================================
Total params: 1,582,977
Trainable params: 1,582,977
Non-trainable params: 0
_________________________________________________________________
history = model2.fit_generator(
mnist.flow(),
steps_per_epoch=300, verbose=1,
epochs=100, validation_data=mnist.validation_data())
Epoch 1/100
300/300 [==============================] - 12s 39ms/step - loss: 26276.0108 - val_loss: 15903.1574
Epoch 2/100
300/300 [==============================] - 11s 36ms/step - loss: 10832.0214 - val_loss: 7923.0541
Epoch 3/100
300/300 [==============================] - 11s 35ms/step - loss: 7165.7129 - val_loss: 5585.1446
Epoch 4/100
300/300 [==============================] - 11s 35ms/step - loss: 5581.9106 - val_loss: 4696.2799
Epoch 5/100
300/300 [==============================] - 11s 35ms/step - loss: 4777.7561 - val_loss: 5042.8660
Epoch 6/100
300/300 [==============================] - 11s 35ms/step - loss: 4225.6829 - val_loss: 3743.4631
Epoch 7/100
300/300 [==============================] - 11s 36ms/step - loss: 3844.4287 - val_loss: 3395.6976
Epoch 8/100
300/300 [==============================] - 11s 36ms/step - loss: 3575.6794 - val_loss: 3629.4973
Epoch 9/100
300/300 [==============================] - 11s 35ms/step - loss: 3316.2805 - val_loss: 3957.8419
Epoch 10/100
300/300 [==============================] - 11s 35ms/step - loss: 3160.1344 - val_loss: 3055.5530
Epoch 11/100
300/300 [==============================] - 11s 35ms/step - loss: 3013.6510 - val_loss: 3046.7351
Epoch 12/100
300/300 [==============================] - 11s 35ms/step - loss: 2870.5135 - val_loss: 3248.7812
Epoch 13/100
300/300 [==============================] - 11s 35ms/step - loss: 2773.4790 - val_loss: 2625.8067
Epoch 14/100
300/300 [==============================] - 11s 36ms/step - loss: 2662.2649 - val_loss: 2914.2163
Epoch 15/100
300/300 [==============================] - 11s 36ms/step - loss: 2565.0835 - val_loss: 2519.3543
Epoch 16/100
300/300 [==============================] - 11s 35ms/step - loss: 2489.7219 - val_loss: 2660.9476
Epoch 17/100
300/300 [==============================] - 11s 35ms/step - loss: 2436.0144 - val_loss: 2483.4026
Epoch 18/100
300/300 [==============================] - 11s 35ms/step - loss: 2362.3690 - val_loss: 2507.7855
Epoch 19/100
300/300 [==============================] - 11s 35ms/step - loss: 2311.3814 - val_loss: 2252.0987
Epoch 20/100
300/300 [==============================] - 11s 35ms/step - loss: 2271.6946 - val_loss: 2498.2432
Epoch 21/100
300/300 [==============================] - 11s 36ms/step - loss: 2208.3675 - val_loss: 2523.1442
Epoch 22/100
300/300 [==============================] - 11s 36ms/step - loss: 2182.5378 - val_loss: 2021.3894
... [epochs 23-80 omitted: training loss decreases steadily from ~2115 to ~1407; validation loss fluctuates between ~1529 and ~2447 while trending down] ...
Epoch 81/100
300/300 [==============================] - 11s 35ms/step - loss: 1382.4032 - val_loss: 1738.3417
Epoch 82/100
229/300 [=====================>........] - ETA: 2s - loss: 1380.0090
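"""
The run above was interrupted by hand partway through epoch 82. The same
cut-off can be automated with standard Keras callbacks; a sketch (the patience
value and checkpoint filename are illustrative, and restore_best_weights
requires Keras >= 2.2.3):
"""
from keras.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # Stop once val_loss has not improved for 10 consecutive epochs,
    # rolling back to the best weights seen so far.
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    # Also keep the best model on disk as a fallback.
    ModelCheckpoint('conv_ae_best.h5', monitor='val_loss', save_best_only=True),
]
# history = model2.fit_generator(..., callbacks=callbacks)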
plt.title("convolutional model")
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
"""
Train set reconstruction.
"""
reco = model2.predict(X_train[:BATCH_SIZE].reshape(BATCH_SIZE, 28, 28, 1))
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(7, 7))
for ind, ax in enumerate(axes.flatten()):
    if ind % 2 == 0:
        # Even panels: the reconstruction of sample `ind`...
        ax.set_title('Reconstructed')
        ax.imshow(reco[ind].reshape(28, 28))
    else:
        # ...odd panels: the corresponding original, directly to the right.
        ax.set_title(' <- Original')
        ax.imshow(X_train[ind - 1].reshape(28, 28))
fig.tight_layout()
plt.show()
"""
Test set reconstruction.
"""
reco = model2.predict(X_test[:BATCH_SIZE].reshape(BATCH_SIZE, 28, 28, 1))
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(7, 7))
for ind, ax in enumerate(axes.flatten()):
    if ind % 2 == 0:
        # Even panels: the reconstruction of sample `ind`...
        ax.set_title('Reconstructed')
        ax.imshow(reco[ind].reshape(28, 28))
    else:
        # ...odd panels: the corresponding original, directly to the right.
        ax.set_title(' <- Original')
        ax.imshow(X_test[ind - 1].reshape(28, 28))
fig.tight_layout()
plt.show()
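"""
The side-by-side panels are only a visual check. A quick quantitative
complement (a minimal sketch, reusing model2, X_test, and BATCH_SIZE from
above) is the mean per-pixel squared error over the whole test set:
"""
reco_test = model2.predict(X_test.reshape(-1, 28, 28, 1), batch_size=BATCH_SIZE)
# Average squared difference per pixel, over all test images.
mse = np.mean(np.square(reco_test.reshape(len(X_test), -1) - X_test.reshape(len(X_test), -1)))
print('Mean per-pixel squared reconstruction error on the test set: %.5f' % mse)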
##########################
"""
Trying with cifar-10
"""
##########################
from keras.datasets.cifar10 import load_data
(X_train, _), (X_test, _) = load_data()
X_train = X_train.reshape(X_train.shape[0], 32*32*3)
X_test = X_test.reshape(X_test.shape[0], 32*32*3)
X_train = X_train / 255.
X_test = X_test / 255.
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 66s 0us/step
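"""
Quick sanity check on the preprocessing: images are flattened to 3072-vectors
(32 * 32 * 3) and rescaled to [0, 1]; the DataGenerator (constructed below
with IMAGE_SHAPE=(32, 32, 3)) reshapes them back when batching.
"""
print(X_train.shape, X_test.shape)    # (50000, 3072) (10000, 3072)
print(X_train.min(), X_train.max())   # 0.0 1.0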
cifar = DataGenerator(X_train, X_test, BATCH_SIZE=128, IMAGE_SHAPE=(32, 32, 3))
autoencoder3 = ConvolutionalAutoEncoderCIFAR(input_shape=(32, 32, 3), latent_dim=256)
model3 = autoencoder3.get_compiled_model("normal")
model3.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) (None, 32, 32, 3) 0
_________________________________________________________________
conv2d_9 (Conv2D) (None, 32, 32, 512) 6656
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 16, 16, 512) 0
_________________________________________________________________
conv2d_10 (Conv2D) (None, 16, 16, 256) 524544
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 8, 8, 256) 0
_________________________________________________________________
conv2d_11 (Conv2D) (None, 8, 8, 128) 131200
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 4, 4, 128) 0
_________________________________________________________________
flatten_3 (Flatten) (None, 2048) 0
_________________________________________________________________
dense_15 (Dense) (None, 256) 524544
_________________________________________________________________
reshape_3 (Reshape) (None, 4, 4, 16) 0
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 8, 8, 256) 16640
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 16, 16, 512) 524800
_________________________________________________________________
conv2d_transpose_9 (Conv2DTr (None, 32, 32, 1024) 2098176
_________________________________________________________________
conv2d_12 (Conv2D) (None, 32, 32, 3) 12291
=================================================================
Total params: 3,838,851
Trainable params: 3,838,851
Non-trainable params: 0
_________________________________________________________________
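"""
Note how the decoder consumes the 256-dimensional latent vector: it is
reshaped into a 4x4x16 feature map (4 * 4 * 16 = 256) that the three
transposed convolutions then upsample back to 32x32. A one-line consistency
check:
"""
assert 4 * 4 * 16 == autoencoder3.latent_dim  # Reshape target must match latent_dim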
history = model3.fit_generator(
cifar.flow(),
steps_per_epoch=450, verbose=1,
epochs=60, validation_data=cifar.validation_data())
Epoch 1/60
450/450 [==============================] - 68s 151ms/step - loss: 27946.8238 - val_loss: 14479.2754
Epoch 2/60
450/450 [==============================] - 66s 146ms/step - loss: 13018.1585 - val_loss: 12141.1006
... [epochs 3-25 omitted: training loss decreases steadily from ~9846 to ~3610; validation loss from ~8486 to ~3901] ...
Epoch 26/60
450/450 [==============================] - 66s 146ms/step - loss: 3529.4795 - val_loss: 3661.5463
Epoch 27/60
450/450 [==============================] - 66s 146ms/step - loss: 3538.4969 - val_loss: 3507.9565
Epoch 28/60
284/450 [=================>............] - ETA: 22s - loss: 3480.6532
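"""
This run was also stopped early, during epoch 28. Note that with a batch size
of 128, one full pass over CIFAR-10's 50,000 training images takes
ceil(50000 / 128) = 391 batches, so steps_per_epoch=450 revisits part of the
data within each "epoch". A sketch of deriving the value from the data instead
(assuming the generator simply cycles through X_train):
"""
steps = int(np.ceil(len(X_train) / 128.))  # 391 for CIFAR-10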
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Convolutional model - Normal loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
"""
Train set reconstruction.
"""
reco = model3.predict(X_train[:BATCH_SIZE].reshape(BATCH_SIZE, 32, 32, 3))
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(7, 7))
for ind, ax in enumerate(axes.flatten()):
    if ind % 2 == 0:
        # Even panels: the reconstruction of sample `ind`...
        ax.set_title('Reconstructed')
        ax.imshow(reco[ind].reshape(32, 32, 3))
    else:
        # ...odd panels: the corresponding original, directly to the right.
        ax.set_title(' <- Original')
        ax.imshow(X_train[ind - 1].reshape(32, 32, 3))
fig.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
"""
Test set reconstruction.
"""
reco = model3.predict(X_test[:BATCH_SIZE].reshape(BATCH_SIZE, 32, 32, 3))
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(7, 7))
for ind, ax in enumerate(axes.flatten()):
    if ind % 2 == 0:
        # Even panels: the reconstruction of sample `ind`...
        ax.set_title('Reconstructed')
        ax.imshow(reco[ind].reshape(32, 32, 3))
    else:
        # ...odd panels: the corresponding original, directly to the right.
        ax.set_title(' <- Original')
        ax.imshow(X_test[ind - 1].reshape(32, 32, 3))
fig.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
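"""
The clipping warnings come from imshow: the decoder's final convolution has a
linear activation, so reconstructed pixel values can fall slightly outside
[0, 1]. Clamping the predictions before plotting (a minimal numpy sketch)
silences the warning without visibly changing the images:
"""
reco = np.clip(reco, 0., 1.)  # clamp reconstructions into imshow's valid float range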