# Install the necessary dependencies

import os
import sys
!{sys.executable} -m pip install --quiet pandas scikit-learn numpy matplotlib jupyterlab_myst ipython

14. Denoising Diffusion Models#

This notebook shows how to train denoising diffusion models.

The code has been adapted and curated from this tutorial by Andras Beres.

14.1. Hyperparams#

import numpy as np
import math
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import os

# --- data ---
image_size = 32  # every image is center-cropped and resized to image_size x image_size
diffusion_steps = 20  # number of denoising steps used when sampling new images

# --- sampling schedule ---
# signal rate at diffusion time 0 (max) and at diffusion time 1 (min)
max_signal_rate = 0.95
min_signal_rate = 0.02

# --- optimization ---
batch_size = 64
num_epochs = 10
learning_rate = 0.001
weight_decay = 0.0001
ema = 0.999  # decay of the exponential moving average applied to the network weights

14.2. Dataset#

def preprocess_image(data):
    """Center-crop, resize and rescale one TFDS example to a float image in [0, 1].

    Args:
        data: a TFDS feature dict; only its ``"image"`` entry is read
            (assumed to be an (H, W, C) uint8 tensor — TODO confirm for datasets
            other than MNIST).

    Returns:
        A float tensor of shape (image_size, image_size, C) clipped to [0, 1].
    """
    image = data["image"]
    # center crop to the largest square that fits inside the image
    height = tf.shape(image)[0]
    width = tf.shape(image)[1]
    crop_size = tf.minimum(height, width)
    image = tf.image.crop_to_bounding_box(
        image,
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )
    # resize and clip
    # for image downsampling it is important to turn on antialiasing
    image = tf.image.resize(image, size=[image_size, image_size], antialias=True)
    return tf.clip_by_value(image / 255.0, 0.0, 1.0)

def prepare_dataset(split, dataset_name="mnist"):
    """Build a shuffled, batched ``tf.data`` pipeline for one TFDS split.

    Args:
        split: TFDS split name, e.g. ``"train"`` or ``"test"``.
        dataset_name: TFDS dataset to load; defaults to MNIST.

    Returns:
        A ``tf.data.Dataset`` of preprocessed image batches of exactly
        ``batch_size`` images (an incomplete final batch is dropped).
    """
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        # cache the preprocessed images, then shuffle *after* the cache so the
        # order is re-drawn on every pass over the data
        .cache()
        .shuffle(10000)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

# build the training and evaluation pipelines (train split first, then test)
train_dataset, val_dataset = (prepare_dataset(split) for split in ("train", "test"))
👩‍💻 Hint

We need to retrieve image data and load the MNIST dataset.

14.3. Denoising Network#

We will use the Residual U-Net model.

14.4. TODO: can we use something simpler?#

embedding_max_frequency = 1000.0
embedding_dims = 64


def sinusoidal_embedding(x):
    """Map noise variances to sinusoidal feature vectors.

    Frequencies are spaced geometrically (evenly in log space) between 1.0 and
    ``embedding_max_frequency``; the result concatenates the sine features and
    the cosine features along the last axis.

    Args:
        x: noise-variance tensor whose last dimension has size 1; the networks
            in this notebook pass shape (batch, 1, 1, 1).

    Returns:
        Tensor with the same leading dimensions as ``x`` and a last dimension
        of size ``embedding_dims`` (e.g. (batch, 1, 1, 64)).
    """
    embedding_min_frequency = 1.0
    frequencies = tf.exp(
        tf.linspace(
            tf.math.log(embedding_min_frequency),
            tf.math.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = 2.0 * math.pi * frequencies
    # broadcasting multiplies x's size-1 last axis by every angular speed;
    # concatenating on the last axis (axis 3 for the 4-D network inputs)
    embeddings = tf.concat(
        [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=-1
    )
    return embeddings
👩‍💻 Hint

Create logarithmically spaced frequencies between the minimum and maximum frequency: space values evenly between the logarithms of the two frequencies, then take the exponential of these values. Concatenate the sine and cosine components to form the final embedding.

14.5. Custom Residual Network#

def get_network_custom(image_size, block_depth=17, output_channels=1):
    """Build a plain convolutional denoiser (no U-Net down/up-sampling).

    Despite the model name, the stack below contains no skip/residual
    additions: it is a straight sequence of BatchNorm -> 3x3 Conv layers.

    Args:
        image_size: spatial size of the square input images.
        block_depth: index of the last ``conv%d`` layer; ``block_depth - 1``
            BatchNorm + Conv pairs follow the first 3x3 convolution.
        output_channels: number of image channels (of both the noisy input
            and the predicted noise).

    Returns:
        A ``tf.keras.Model`` mapping ``[noisy_images, noise_variances]`` to a
        predicted-noise tensor with the same shape as ``noisy_images``.
    """
    # use the correct number of channels
    noisy_images = tf.keras.Input(shape=(image_size, image_size, output_channels))
    noise_variances = tf.keras.Input(shape=(1, 1, 1))

    # embed the noise level and broadcast it to every pixel
    e = tf.keras.layers.Lambda(sinusoidal_embedding)(noise_variances)
    e = tf.keras.layers.UpSampling2D(size=image_size, interpolation="nearest")(e)
    # lift the image to 32 channels and stack it with the noise embedding
    x = tf.keras.layers.Conv2D(32, kernel_size=1)(noisy_images)
    x = tf.keras.layers.Concatenate()([x, e])

    x = tf.keras.layers.Conv2D(64, 3, padding="same", activation=tf.nn.relu)(x)
    for depth in range(2, block_depth + 1):
        x = tf.keras.layers.BatchNormalization(center=False, scale=False)(x)
        x = tf.keras.layers.Conv2D(
            64,
            3,
            padding="same",
            name=f"conv{depth}",
            activation=tf.keras.activations.swish,
            use_bias=False,
        )(x)

    # zero-initialised head: the untrained network predicts zero noise
    x = tf.keras.layers.Conv2D(output_channels, kernel_size=1, kernel_initializer="zeros")(x)
    return tf.keras.Model([noisy_images, noise_variances], x, name="simple-residual-net")
👩‍💻 Hint

We perform sinusoidal embedding on the noise variance data and use a convolutional layer to convolve the noisy image data.

14.6. Residual U-Net#

# U-Net channel widths: widths[:-1] are the down/up levels, widths[-1] the bottleneck
widths = [32 * level for level in (1, 2, 3, 4)]
block_depth = 2  # residual blocks per level (and in the bottleneck)

def ResidualBlock(width):
    """Return a closure that applies a residual block with ``width`` output channels.

    The block is BatchNorm -> 3x3 Conv (swish) -> 3x3 Conv, added to a
    shortcut. The shortcut is the input itself when it already has ``width``
    channels, otherwise a 1x1 convolution projecting it to ``width`` channels.
    """

    def apply(x):
        # the shortcut must match the output channel count before the Add()
        needs_projection = x.shape[3] != width
        shortcut = tf.keras.layers.Conv2D(width, kernel_size=1)(x) if needs_projection else x

        h = tf.keras.layers.BatchNormalization(center=False, scale=False)(x)
        h = tf.keras.layers.Conv2D(
            width, kernel_size=3, padding="same", activation=tf.keras.activations.swish
        )(h)
        h = tf.keras.layers.Conv2D(width, kernel_size=3, padding="same")(h)
        return tf.keras.layers.Add()([h, shortcut])

    return apply


def DownBlock(width, block_depth):
    """Return a closure: ``block_depth`` residual blocks followed by 2x average pooling.

    The closure takes ``[x, skips]``. Every residual-block output is appended
    to the caller's ``skips`` list (mutated in place) so that the matching
    UpBlock can pop it back off.
    """

    def apply(inputs):
        x, skips = inputs
        for _ in range(block_depth):
            x = ResidualBlock(width)(x)
            skips.append(x)
        return tf.keras.layers.AveragePooling2D(pool_size=2)(x)

    return apply


def UpBlock(width, block_depth):
    """Return a closure: bilinear 2x upsampling, then ``block_depth`` residual blocks.

    The closure takes ``[x, skips]``. Before each residual block, the most
    recently pushed entry is popped off ``skips`` (last-in, first-out) and
    concatenated onto the features along the channel axis.
    """

    def apply(inputs):
        x, skips = inputs
        x = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x)
        for _ in range(block_depth):
            skip = skips.pop()
            x = ResidualBlock(width)(tf.keras.layers.Concatenate()([x, skip]))
        return x

    return apply


def get_network(image_size, widths, block_depth, output_channels=1):
    """Build the residual U-Net denoiser.

    Args:
        image_size: spatial size of the square input images; it should be
            divisible by ``2 ** (len(widths) - 1)`` so that every pooling step
            is undone exactly by the matching upsampling step.
        widths: channel widths per level; ``widths[:-1]`` are the down/up
            levels and ``widths[-1]`` is the bottleneck width.
        block_depth: number of residual blocks per level and in the bottleneck.
        output_channels: number of image channels (of both the noisy input
            and the predicted noise); defaults to 1 for grayscale images.

    Returns:
        A ``tf.keras.Model`` mapping ``[noisy_images, noise_variances]`` to a
        predicted-noise tensor with the same shape as ``noisy_images``.
    """
    # use the correct number of channels
    noisy_images = tf.keras.Input(shape=(image_size, image_size, output_channels))
    noise_variances = tf.keras.Input(shape=(1, 1, 1))

    # embed the noise level and broadcast it to every pixel
    e = tf.keras.layers.Lambda(sinusoidal_embedding)(noise_variances)
    e = tf.keras.layers.UpSampling2D(size=image_size, interpolation="nearest")(e)

    x = tf.keras.layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
    x = tf.keras.layers.Concatenate()([x, e])

    # encoder: each DownBlock pushes its residual outputs onto `skips`
    skips = []
    for width in widths[:-1]:
        x = DownBlock(width, block_depth)([x, skips])

    # bottleneck at the lowest resolution
    for _ in range(block_depth):
        x = ResidualBlock(widths[-1])(x)

    # decoder: each UpBlock pops the matching skip tensors back off `skips`
    for width in reversed(widths[:-1]):
        x = UpBlock(width, block_depth)([x, skips])

    # zero-initialised head: the untrained network predicts zero noise
    x = tf.keras.layers.Conv2D(output_channels, kernel_size=1, kernel_initializer="zeros")(x)

    return tf.keras.Model([noisy_images, noise_variances], x, name="residual_unet")
👩‍💻 Hint

We need to construct downsampling and upsampling blocks, iterate over the depth of the residual blocks, and finally return a Keras model.

14.7. Diffusion Model#

class DiffusionModel(tf.keras.Model):
    """Denoising diffusion model trained to predict the noise added to images.

    Wraps a denoising ``network`` (called as
    ``network([noisy_images, noise_variances])``) together with an
    exponential-moving-average (EMA) copy of it that is used for sampling.
    """

    def __init__(self, network):
        super().__init__()
        # must be adapted to the training images before fit(); see denormalize()
        self.normalizer = tf.keras.layers.Normalization()
        self.network = network
        # clone_model re-initialises the weights, so start the EMA copy from the
        # network's own weights instead of from an unrelated random init
        self.ema_network = tf.keras.models.clone_model(self.network)
        self.ema_network.set_weights(self.network.get_weights())

    def compile(self, **kwargs):
        super().compile(**kwargs)
        # "n_loss": noise loss (optimised); "i_loss": image loss (reported only)
        self.noise_loss_tracker = tf.keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = tf.keras.metrics.Mean(name="i_loss")

    @property
    def metrics(self):
        # listing the trackers here lets Keras reset them at the start of each epoch
        return [self.noise_loss_tracker, self.image_loss_tracker]

    def denormalize(self, images):
        """Invert the normalizer (x * std + mean) and clip to the [0, 1] pixel range."""
        images = self.normalizer.mean + images * self.normalizer.variance**0.5
        return tf.clip_by_value(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        """Map diffusion times in [0, 1] to ``(noise_rates, signal_rates)``.

        The angle is interpolated linearly from ``acos(max_signal_rate)`` at
        t=0 to ``acos(min_signal_rate)`` at t=1; the signal and noise rates are
        its cosine and sine, so ``signal_rates**2 + noise_rates**2 == 1``.
        """
        # diffusion times -> angles
        start_angle = tf.acos(max_signal_rate)
        end_angle = tf.acos(min_signal_rate)
        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
        # angles -> signal and noise rates
        signal_rates = tf.cos(diffusion_angles)
        noise_rates = tf.sin(diffusion_angles)
        # note that their squared sum is always: sin^2(x) + cos^2(x) = 1
        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        """Predict the noise in ``noisy_images`` and the clean images it implies.

        Returns:
            ``(pred_noises, pred_images)``; the images are recovered by inverting
            ``noisy = signal_rate * image + noise_rate * noise``.
        """
        # the exponential moving average weights are used at evaluation
        network = self.network if training else self.ema_network
        # the network is conditioned on the noise variance, i.e. noise_rates ** 2
        pred_noises = network([noisy_images, noise_rates**2], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, steps):
        """Sample images by running ``steps`` deterministic denoising steps.

        Args:
            initial_noise: standard-normal noise of shape (num_images, H, W, C).
            steps: number of denoising steps; must be at least 1.

        Returns:
            The final predicted images (still in normalized space).

        Raises:
            ValueError: if ``steps`` is smaller than 1.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        batch = initial_noise.shape[0]
        step_size = 1.0 / steps

        # important line:
        # at the first sampling step, the "noisy image" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_images = initial_noise
        # iterate over the requested `steps`, not the global `diffusion_steps`,
        # so that step_size and the number of iterations always agree
        for step in range(steps):
            noisy_images = next_noisy_images
            diffusion_times = tf.ones((batch, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=False
            )

            # remix the predicted components at the next (lower) noise level;
            # this new noisy image will be used in the next step
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_images = (
                next_signal_rates * pred_images + next_noise_rates * pred_noises
            )
        return pred_images

    def generate(self, num_images, steps):
        """Generate ``num_images`` images with pixel values in [0, 1]."""
        # take height/width/channels from the network's image input so the
        # noise always matches whichever network was built
        _, height, width, channels = self.network.inputs[0].shape
        # noise -> images -> denormalized images
        initial_noise = tf.random.normal(shape=(num_images, height, width, channels))
        generated_images = self.reverse_diffusion(initial_noise, steps)
        return self.denormalize(generated_images)

    def train_step(self, images):
        """Run one optimisation step on a batch of images.

        The network is trained on the noise-prediction loss (``n_loss``); the
        reconstruction loss of the implied clean images (``i_loss``) is only tracked.
        """
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=True)
        # use the dynamic batch shape so batches of any size work
        noises = tf.random.normal(shape=tf.shape(images))
        diffusion_times = tf.random.uniform(
            shape=(tf.shape(images)[0], 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        with tf.GradientTape() as tape:
            # train the network to separate noisy images to their components
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )

            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)

        # track the exponential moving averages of weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        return {m.name: m.result() for m in self.metrics}
👩‍💻 Hint

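Calculate the diffusion angle by interpolating between the start and end angles according to the diffusion time. Compute the denoised image by subtracting the predicted noise (scaled by the noise rate) from the noisy image and dividing by the signal rate. Generate the next noisy image by mixing the predicted image and the predicted noise with the next signal rate and noise rate.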

14.8. Complete Model#

Choose one of the residual networks.

# Pick one of the two denoising networks defined above.
network = get_network_custom(image_size, block_depth=10)
# network = get_network(image_size, widths, block_depth)
# summary() prints the table itself and returns None, so don't wrap it in print()
network.summary()
Model: "simple-residual-net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_2 (InputLayer)           [(None, 1, 1, 1)]    0           []                               
                                                                                                  
 input_1 (InputLayer)           [(None, 32, 32, 1)]  0           []                               
                                                                                                  
 lambda (Lambda)                (None, 1, 1, 64)     0           ['input_2[0][0]']                
                                                                                                  
 conv2d (Conv2D)                (None, 32, 32, 32)   64          ['input_1[0][0]']                
                                                                                                  
 up_sampling2d (UpSampling2D)   (None, 32, 32, 64)   0           ['lambda[0][0]']                 
                                                                                                  
 concatenate (Concatenate)      (None, 32, 32, 96)   0           ['conv2d[0][0]',                 
                                                                  'up_sampling2d[0][0]']          
                                                                                                  
 conv2d_1 (Conv2D)              (None, 32, 32, 64)   55360       ['concatenate[0][0]']            
                                                                                                  
 batch_normalization (BatchNorm  (None, 32, 32, 64)  128         ['conv2d_1[0][0]']               
 alization)                                                                                       
                                                                                                  
 conv2 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization[0][0]']    
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 32, 32, 64)  128         ['conv2[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv3 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_1[0][0]']  
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 32, 32, 64)  128         ['conv3[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv4 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_2[0][0]']  
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 32, 32, 64)  128         ['conv4[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv5 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_3[0][0]']  
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 32, 32, 64)  128         ['conv5[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv6 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_4[0][0]']  
                                                                                                  
 batch_normalization_5 (BatchNo  (None, 32, 32, 64)  128         ['conv6[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv7 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_5[0][0]']  
                                                                                                  
 batch_normalization_6 (BatchNo  (None, 32, 32, 64)  128         ['conv7[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv8 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_6[0][0]']  
                                                                                                  
 batch_normalization_7 (BatchNo  (None, 32, 32, 64)  128         ['conv8[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv9 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_7[0][0]']  
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 32, 32, 64)  128         ['conv9[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv10 (Conv2D)                (None, 32, 32, 64)   36864       ['batch_normalization_8[0][0]']  
                                                                                                  
 conv2d_2 (Conv2D)              (None, 32, 32, 1)    65          ['conv10[0][0]']                 
                                                                                                  
==================================================================================================
Total params: 388,417
Trainable params: 387,265
Non-trainable params: 1,152
__________________________________________________________________________________________________
None
model = DiffusionModel(network)
# AdamW (Adam with weight decay) from tensorflow_addons; L1 (mean absolute error) loss
optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
model.compile(optimizer=optimizer, loss=tf.keras.losses.mean_absolute_error)
# estimate the normalizer's mean/variance from the training images before fitting
model.normalizer.adapt(train_dataset)
model.fit(train_dataset, epochs=num_epochs)
Epoch 1/10
937/937 [==============================] - 2197s 2s/step - n_loss: 0.1291 - i_loss: 0.3230
Epoch 2/10
937/937 [==============================] - 1940s 2s/step - n_loss: 0.0934 - i_loss: 0.2160
Epoch 3/10
937/937 [==============================] - 1929s 2s/step - n_loss: 0.0881 - i_loss: 0.2033
Epoch 4/10
937/937 [==============================] - 1960s 2s/step - n_loss: 0.0864 - i_loss: 0.1974
Epoch 5/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0841 - i_loss: 0.1926
Epoch 6/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0835 - i_loss: 0.1908
Epoch 7/10
937/937 [==============================] - 1926s 2s/step - n_loss: 0.0823 - i_loss: 0.1864
Epoch 8/10
937/937 [==============================] - 1948s 2s/step - n_loss: 0.0813 - i_loss: 0.1817
Epoch 9/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0813 - i_loss: 0.1825
Epoch 10/10
937/937 [==============================] - 1927s 2s/step - n_loss: 0.0801 - i_loss: 0.1787
<keras.callbacks.History at 0x25c6dfaf370>

14.9. Visualize#

num_rows = 2
num_cols = 3

generated_images = model.generate(
    num_images=num_rows * num_cols,
    steps=diffusion_steps,
)

# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
fig, axes = plt.subplots(
    num_rows, num_cols, figsize=(num_cols * 2.0, num_rows * 2.0), squeeze=False
)
for ax, image in zip(axes.flat, np.asarray(generated_images)):
    # single-channel images: drop the channel axis and draw them in grayscale
    # on a fixed [0, 1] scale instead of the default per-image colormap scaling
    if image.shape[-1] == 1:
        image = image[..., 0]
    ax.imshow(image, cmap="gray", vmin=0.0, vmax=1.0)
    ax.axis("off")

plt.tight_layout()
../../_images/1cfb1bc4713e744e717759bc3902afcccd2a15bc2ee957d6c21f553ac1ee8823.png

14.10. Acknowledgments#

Thanks to Maciej Skorski for creating Denoising Diffusion Model. It inspires the majority of the content in this chapter.