# Install the necessary dependencies
import os
import sys
!{sys.executable} -m pip install --quiet pandas scikit-learn numpy matplotlib jupyterlab_myst ipython
14. Denoising Diffusion Models#
This notebook shows how to train Denoising Diffusion Models.
The code has been adapted and curated from this tutorial by András Béres.
14.1. Hyperparams#
import numpy as np
import math
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import os
# data
diffusion_steps = 20
image_size = 32
# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95
# optimization
batch_size = 64
num_epochs = 10
learning_rate = 1e-3
weight_decay = 1e-4
ema = 0.999
14.2. Dataset#
def preprocess_image(data):
# center crop image
height = tf.shape(____)[0]
width = tf.shape(____)[1]
crop_size = tf.minimum(height, width)
image = tf.image.crop_to_bounding_box(
data["image"],
(height - crop_size) // 2,
(width - crop_size) // 2,
crop_size,
crop_size,
)
# resize and clip
# for image downsampling it is important to turn on antialiasing
image = tf.image.resize(image, size=[image_size, image_size], antialias=True)
return tf.clip_by_value(image / 255.0, 0.0, 1.0)
def prepare_dataset(split):
return (
tfds.load(____, split=split, shuffle_files=True)
.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
.cache()
.repeat(1)
.shuffle(10000)
.batch(batch_size, drop_remainder=True)
.prefetch(buffer_size=tf.data.AUTOTUNE)
)
# load dataset
train_dataset = prepare_dataset("train")
val_dataset = prepare_dataset("test")
👩💻 Hint
We need to retrieve the image tensor from the data dictionary and load the MNIST dataset.
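A possible completion of the blanks above (a sketch that follows the hint, assuming MNIST from TensorFlow Datasets):
# in preprocess_image: the raw image tensor is stored under the "image" key
height = tf.shape(data["image"])[0]
width = tf.shape(data["image"])[1]
# in prepare_dataset: load MNIST from TensorFlow Datasets
tfds.load("mnist", split=split, shuffle_files=True)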
14.3. Denoising Network#
We will use the Residual U-Net model.
14.4. Sinusoidal Embedding#
embedding_max_frequency = 1000.0
embedding_dims = 64
def sinusoidal_embedding(x):
embedding_min_frequency = 1.0
frequencies = tf.exp(
tf.linspace(
tf.math.log(____),
tf.math.log(____),
embedding_dims // 2,
)
)
angular_speeds = 2.0 * math.pi * frequencies
embeddings = tf.concat(
[tf.sin(____ * ____), tf.cos(____ * ____)], axis=3
)
return embeddings
👩💻 Hint
Create a range of logarithmically spaced values between the minimum and maximum frequencies, take the exponential of these values, and then concatenate the sine and cosine components to form the final embedding.
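A possible completion of the blanks in sinusoidal_embedding (a sketch; x is the function argument, of shape (batch, 1, 1, 1)):
# log-spaced frequencies between the minimum and maximum embedding frequency
frequencies = tf.exp(
    tf.linspace(
        tf.math.log(embedding_min_frequency),
        tf.math.log(embedding_max_frequency),
        embedding_dims // 2,
    )
)
angular_speeds = 2.0 * math.pi * frequencies
# sine and cosine features, concatenated along the channel axis
embeddings = tf.concat(
    [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=3
)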
14.5. Custom Residual Network#
def get_network_custom(image_size, block_depth=17, output_channels=1):
# use the correct number of channels
noisy_images = tf.keras.Input(shape=(image_size, image_size, output_channels))
noise_variances = tf.keras.Input(shape=(1, 1, 1))
e = tf.keras.layers.Lambda(____)(____)
e = tf.keras.layers.UpSampling2D(size=image_size, interpolation="nearest")(e)
x = tf.keras.layers.Conv2D(32, kernel_size=1)(____)
x = tf.keras.layers.Concatenate()([x, e])
x = tf.keras.layers.Conv2D(64, 3, padding='same', activation=tf.nn.relu)(x)
for layers in range(2, block_depth+1):
x = tf.keras.layers.BatchNormalization(center=False, scale=False)(x)
x = tf.keras.layers.Conv2D(
64, 3,
padding='same', name='conv%d' % layers,
activation=tf.keras.activations.swish,
use_bias=False
)(x)
x = tf.keras.layers.Conv2D(output_channels, kernel_size=1, kernel_initializer="zeros")(x)
return tf.keras.Model([noisy_images, noise_variances], x, name="simple-residual-net")
👩💻 Hint
Apply the sinusoidal embedding to the noise variances, and pass the noisy images through a convolutional layer.
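A possible completion of the blanks in get_network_custom (a sketch):
# embed the noise variances with the sinusoidal embedding defined above
e = tf.keras.layers.Lambda(sinusoidal_embedding)(noise_variances)
# lift the noisy images to 32 channels with a pointwise convolution
x = tf.keras.layers.Conv2D(32, kernel_size=1)(noisy_images)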
14.6. Residual U-Net#
widths = [32, 64, 96, 128]
block_depth = 2
def ResidualBlock(width):
def apply(x):
input_width = x.shape[3]
if input_width == width:
residual = x
else:
residual = tf.keras.layers.Conv2D(width, kernel_size=1)(x)
x = tf.keras.layers.BatchNormalization(center=False, scale=False)(x)
x = tf.keras.layers.Conv2D(
width, kernel_size=3, padding="same", activation=tf.keras.activations.swish
)(x)
x = tf.keras.layers.Conv2D(width, kernel_size=3, padding="same")(x)
x = tf.keras.layers.Add()([x, residual])
return x
return apply
def DownBlock(width, block_depth):
def apply(x):
x, skips = x
for _ in range(block_depth):
x = ResidualBlock(width)(x)
skips.append(x)
x = tf.keras.layers.AveragePooling2D(pool_size=2)(x)
return x
return apply
def UpBlock(width, block_depth):
def apply(x):
x, skips = x
x = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x)
for _ in range(block_depth):
x = tf.keras.layers.Concatenate()([x, skips.pop()])
x = ResidualBlock(width)(x)
return x
return apply
def get_network(image_size, widths, block_depth):
# use the correct number of channels
noisy_images = tf.keras.Input(shape=(image_size, image_size, 1))
noise_variances = tf.keras.Input(shape=(1, 1, 1))
e = tf.keras.layers.Lambda(sinusoidal_embedding)(noise_variances)
e = tf.keras.layers.UpSampling2D(size=image_size, interpolation="nearest")(e)
x = tf.keras.layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
x = tf.keras.layers.Concatenate()([____, ____])
skips = []
for width in widths[:-1]:
x = ____(width, block_depth)([x, skips])
for _ in range(____):
x = ResidualBlock(widths[-1])(x)
for width in reversed(widths[:-1]):
x = ____(width, block_depth)([x, skips])
x = tf.keras.layers.Conv2D(1, kernel_size=1, kernel_initializer="zeros")(x)
return _____([noisy_images, noise_variances], x, name="residual_unet")
👩💻 Hint
We need to construct downsampling and upsampling blocks, iterate over the depth of the residual blocks, and finally return a Keras model.
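A possible completion of the blanks in get_network (a sketch):
# merge the image and embedding branches
x = tf.keras.layers.Concatenate()([x, e])
skips = []
# downsampling path: all but the last width
for width in widths[:-1]:
    x = DownBlock(width, block_depth)([x, skips])
# bottleneck: block_depth residual blocks at the largest width
for _ in range(block_depth):
    x = ResidualBlock(widths[-1])(x)
# upsampling path, consuming the stored skip connections
for width in reversed(widths[:-1]):
    x = UpBlock(width, block_depth)([x, skips])
x = tf.keras.layers.Conv2D(1, kernel_size=1, kernel_initializer="zeros")(x)
return tf.keras.Model([noisy_images, noise_variances], x, name="residual_unet")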
14.7. Diffusion Model#
class DiffusionModel(tf.keras.Model):
def __init__(self, network):
super().__init__()
self.normalizer = tf.keras.layers.Normalization()
self.network = network
self.ema_network = tf.keras.models.clone_model(self.network)
def compile(self, **kwargs):
super().compile(**kwargs)
self.noise_loss_tracker = tf.keras.metrics.Mean(name="n_loss")
self.image_loss_tracker = tf.keras.metrics.Mean(name="i_loss")
@property
def metrics(self):
return [self.noise_loss_tracker, self.image_loss_tracker]
def denormalize(self, images):
images = self.normalizer.mean + images * self.normalizer.variance**0.5
return tf.clip_by_value(images, 0.0, 1.0)
def diffusion_schedule(self, diffusion_times):
# diffusion times -> angles
start_angle = tf.acos(max_signal_rate)
end_angle = tf.acos(min_signal_rate)
diffusion_angles = ____ + ____ * (____ - ____)
# angles -> signal and noise rates
signal_rates = tf.cos(diffusion_angles)
noise_rates = tf.sin(diffusion_angles)
# note that their squared sum is always: sin^2(x) + cos^2(x) = 1
return noise_rates, signal_rates
def denoise(self, noisy_images, noise_rates, signal_rates, training):
# the exponential moving average weights are used at evaluation
if training:
network = self.network
else:
network = self.ema_network
# predict noise component and calculate the image component using it
pred_noises = network([noisy_images, noise_rates**2], training=training)
pred_images = (____ - ____ * ____) / ____
return pred_noises, pred_images
def reverse_diffusion(self, initial_noise, steps):
# reverse diffusion = sampling
batch = initial_noise.shape[0]
step_size = 1.0 / steps
# important line:
# at the first sampling step, the "noisy image" is pure noise
# but its signal rate is assumed to be nonzero (min_signal_rate)
next_noisy_images = initial_noise
        for step in range(steps):
noisy_images = next_noisy_images
diffusion_times = tf.ones((batch, 1, 1, 1)) - step * step_size
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
pred_noises, pred_images = self.denoise(
noisy_images, noise_rates, signal_rates, training=False
)
# this new noisy image will be used in the next step
next_diffusion_times = diffusion_times - step_size
next_noise_rates, next_signal_rates = self.diffusion_schedule(
next_diffusion_times
)
next_noisy_images = (
____ * ____ + ____ * ____
)
return pred_images
def generate(self, num_images, steps):
# noise -> images -> denormalized images
initial_noise = tf.random.normal(shape=(num_images, image_size, image_size, 1))
generated_images = self.reverse_diffusion(initial_noise, steps)
generated_images = self.denormalize(generated_images)
return generated_images
def train_step(self, images):
# normalize images to have standard deviation of 1, like the noises
images = self.normalizer(images, training=True)
noises = tf.random.normal(shape=images.shape)
diffusion_times = tf.random.uniform(
shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
)
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
# mix the images with noises accordingly
noisy_images = signal_rates * images + noise_rates * noises
with tf.GradientTape() as tape:
# train the network to separate noisy images to their components
pred_noises, pred_images = self.denoise(
noisy_images, noise_rates, signal_rates, training=True
)
noise_loss = self.loss(noises, pred_noises) # used for training
image_loss = self.loss(images, pred_images) # only used as metric
gradients = tape.gradient(noise_loss, self.network.trainable_weights)
self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
self.noise_loss_tracker.update_state(noise_loss)
self.image_loss_tracker.update_state(image_loss)
# track the exponential moving averages of weights
for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
ema_weight.assign(ema * ema_weight + (1 - ema) * weight)
return {m.name: m.result() for m in self.metrics}
👩💻 Hint
Interpolate the diffusion angle between the start and end angles according to the diffusion time. Recover the predicted image by subtracting the noise-rate-scaled predicted noise from the noisy image and dividing by the signal rate. Form the next noisy image by remixing the predicted image and predicted noise with the next signal and noise rates.
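A possible completion of the three blanks in DiffusionModel (a sketch, in the order they appear):
# diffusion_schedule: interpolate the angle linearly in diffusion time
diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
# denoise: invert noisy = signal_rate * image + noise_rate * noise
pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
# reverse_diffusion: remix the predicted components at the next noise level
next_noisy_images = (
    next_signal_rates * pred_images + next_noise_rates * pred_noises
)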
14.8. Complete Model#
Choose one of the residual networks.
network = get_network_custom(image_size, block_depth=10)
# network = get_network(image_size, widths, block_depth)
network.summary()
Model: "simple-residual-net"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) [(None, 1, 1, 1)] 0 []
input_1 (InputLayer) [(None, 32, 32, 1)] 0 []
lambda (Lambda) (None, 1, 1, 64) 0 ['input_2[0][0]']
conv2d (Conv2D) (None, 32, 32, 32) 64 ['input_1[0][0]']
up_sampling2d (UpSampling2D) (None, 32, 32, 64) 0 ['lambda[0][0]']
concatenate (Concatenate) (None, 32, 32, 96) 0 ['conv2d[0][0]',
'up_sampling2d[0][0]']
conv2d_1 (Conv2D) (None, 32, 32, 64) 55360 ['concatenate[0][0]']
batch_normalization (BatchNorm (None, 32, 32, 64) 128 ['conv2d_1[0][0]']
alization)
conv2 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization[0][0]']
batch_normalization_1 (BatchNo (None, 32, 32, 64) 128 ['conv2[0][0]']
rmalization)
conv3 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization_1[0][0]']
batch_normalization_2 (BatchNo (None, 32, 32, 64) 128 ['conv3[0][0]']
rmalization)
conv4 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization_2[0][0]']
batch_normalization_3 (BatchNo (None, 32, 32, 64) 128 ['conv4[0][0]']
rmalization)
conv5 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization_3[0][0]']
batch_normalization_4 (BatchNo (None, 32, 32, 64) 128 ['conv5[0][0]']
rmalization)
conv6 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization_4[0][0]']
batch_normalization_5 (BatchNo (None, 32, 32, 64) 128 ['conv6[0][0]']
rmalization)
conv7 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization_5[0][0]']
batch_normalization_6 (BatchNo (None, 32, 32, 64) 128 ['conv7[0][0]']
rmalization)
conv8 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization_6[0][0]']
batch_normalization_7 (BatchNo (None, 32, 32, 64) 128 ['conv8[0][0]']
rmalization)
conv9 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization_7[0][0]']
batch_normalization_8 (BatchNo (None, 32, 32, 64) 128 ['conv9[0][0]']
rmalization)
conv10 (Conv2D) (None, 32, 32, 64) 36864 ['batch_normalization_8[0][0]']
conv2d_2 (Conv2D) (None, 32, 32, 1) 65 ['conv10[0][0]']
==================================================================================================
Total params: 388,417
Trainable params: 387,265
Non-trainable params: 1,152
__________________________________________________________________________________________________
model = DiffusionModel(network)
model.compile(
optimizer=tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay),
loss=tf.keras.losses.mean_absolute_error,
)
model.normalizer.adapt(train_dataset)
model.fit(
train_dataset,
epochs=num_epochs,
)
Epoch 1/10
937/937 [==============================] - 2197s 2s/step - n_loss: 0.1291 - i_loss: 0.3230
Epoch 2/10
937/937 [==============================] - 1940s 2s/step - n_loss: 0.0934 - i_loss: 0.2160
Epoch 3/10
937/937 [==============================] - 1929s 2s/step - n_loss: 0.0881 - i_loss: 0.2033
Epoch 4/10
937/937 [==============================] - 1960s 2s/step - n_loss: 0.0864 - i_loss: 0.1974
Epoch 5/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0841 - i_loss: 0.1926
Epoch 6/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0835 - i_loss: 0.1908
Epoch 7/10
937/937 [==============================] - 1926s 2s/step - n_loss: 0.0823 - i_loss: 0.1864
Epoch 8/10
937/937 [==============================] - 1948s 2s/step - n_loss: 0.0813 - i_loss: 0.1817
Epoch 9/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0813 - i_loss: 0.1825
Epoch 10/10
937/937 [==============================] - 1927s 2s/step - n_loss: 0.0801 - i_loss: 0.1787
<keras.callbacks.History at 0x25c6dfaf370>
14.9. Visualize#
num_rows = 2
num_cols = 3
generated_images = model.generate(
num_images=num_rows * num_cols,
steps=diffusion_steps,
)
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
for row in range(num_rows):
for col in range(num_cols):
index = row * num_cols + col
plt.subplot(num_rows, num_cols, index + 1)
        plt.imshow(generated_images[index, :, :, 0], cmap="gray")
plt.axis("off")
plt.tight_layout()
14.10. Acknowledgments#
Thanks to Maciej Skorski for creating the Denoising Diffusion Model notebook. It inspired the majority of the content in this chapter.