Diffusion Models with Images#

Acknowledgement: There is a lot of good information about diffusion models. This assignment relied heavily on … These may be useful to skim, but you cannot copy code from these places (and it won't be that helpful anyway, because the details of how things are set up matter here).

Our goal in this part of the assignment is to take the code and ideas from page 1, where we worked with a single one-dimensional point, and apply them to images. This works because we can treat an image as a single point in a high-dimensional space.

What we will do is load in some images and then learn to do the following in analogy to your work on page 1.

  • diffuse the image to random Gaussian noise (in a fast way)

  • undiffuse the image (using some guess for the random noise used in diffusing it)

  • train a neural network to successfully make this guess for the random noise.

Our code will look largely the same as on page 1, although it might have to be a little more complicated because we will eventually need to work in batches (and ideally use the GPU).

Let’s start by pulling in the data so we can get our hands on the images. To do this, we want to use

!pip install datasets;
!pip install diffusers
import torchvision 
from torchvision.transforms  import Compose
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pylab as plt
import numpy as np
import torch
import copy
from diffusers import DDPMScheduler, UNet2DModel
import pickle

# ToTensor gives pixels in [0,1]; the Lambda rescales them to [-1,1]
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=Compose([
              torchvision.transforms.ToTensor(),
              torchvision.transforms.Lambda(lambda t: (t * 2) - 1)]))
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

This will give you access to the data.

At this point, we can go ahead and look at the images in the following way:

def PlotImages(x):
  # undo the [-1,1] scaling and lay the batch out as a grid, 16 images per row
  plt.imshow(torchvision.utils.make_grid((x+1)/2, nrow=16)[0], cmap='Greys')
  plt.show()

x=next(iter(train_dataloader))[0]
PlotImages(x)

This will show us an 8x16 grid of our 128 images (which in this case are handwritten digits).

We can also go ahead and visualize just one of those digits (the one at index 3) by doing plt.matshow(x[3,0,:,:])

Fast Diffusion#

Our first goal is to generate a fast diffusion of these digits. Copy your ForwardFastDiffusion from page 1. You can continue using the same set of \(\beta_t\) you were using previously - i.e.

beta_max = 0.02
timesteps = 200
T=timesteps
beta = torch.tensor(np.linspace(0.0001, beta_max, T,dtype=np.float32))
alpha=1-beta
alpha_bar = torch.cumprod(alpha, dim=0)   # cumulative product of the alphas

Your previous function needs to be modified so that the random Gaussian noise is not just a single number but an entire image’s worth of noise. This can be done using torch.randn_like(x) instead of np.random.randn(). Once you make this change, your function should work for one image - i.e. x[3,0,:,:]. Go ahead and verify this.
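As a point of reference, here is a minimal sketch of what that function might look like. It assumes the beta/alpha_bar schedule defined above and the standard closed-form \(x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,z\); your page 1 version may be organized differently:

def ForwardFastDiffusion(x0, t):
    # one Gaussian noise value per pixel instead of a single number
    z = torch.randn_like(x0)
    # jump straight from x_0 to x_t in one step
    x_t = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * z
    return x_t, z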

Now, you need to make one more change. We would like to be able to send it the entire batch of 128 images, which is stored in an array that is 128x1x28x28. Instead of sending a single T, you now want to be able to send a series of 128 T’s. Fix your code to deal with this. I did it with something like this

batch_size = 128
Ts = torch.tensor([t for i in range(0, batch_size)]).long()   # 128 copies of the chosen time step t
Ts = Ts.reshape(batch_size, 1, 1, 1)    # shaped so it broadcasts against the 128x1x28x28 batch

Then I was able to call ForwardFastDiffusion with these times. While we aren’t using it here, it is important that the forward diffusion can also take a list of different times - i.e. Ts=[3,1,57,...]. You should test this.
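The function sketched above happens to work for the batched case as long as the indexing broadcasts: alpha_bar[Ts] has shape 128x1x1x1, which broadcasts against the 128x1x28x28 batch. A quick (hypothetical) test with a different time per image:

Ts = torch.randint(0, T, (batch_size,)).reshape(batch_size, 1, 1, 1)  # a different time per image
x_t, z = ForwardFastDiffusion(x, Ts)   # x is the 128x1x28x28 batch from above
PlotImages(x_t)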

Grading

Call PlotImages for the batch of 128 images diffused at time steps \(t\in \{0,1,50,100,150,199\}\)

Undiffusing Images#

Our next goal is to get undiffusing to work for these images. Copy over your code from page 1. Start with a model guess which is “exact” for a single batch of images.

You want to add a @torch.no_grad() decorator at the beginning of the function - i.e.

@torch.no_grad()
def Undiffuse(x_t,Tstart,Tend):

I had to make two changes to my page 1 code.

First, I had to have the times t and t-1 not as a single scalar but as arrays of shape 128x1x1x1 (so they broadcast against the 128x1x28x28 batch) - i.e.

tsm1 = torch.from_numpy(np.array([t-1 for i in range(0, batch_size)]))   # a batch worth of (t-1)'s
tsm1 = tsm1.reshape(batch_size, 1, 1, 1)

I also had to change np.random.randn() to torch.randn_like(x_t). Make these changes to your code.
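Putting those changes together, the structure of my function looks roughly like the sketch below. It assumes GuessZ(x_t, ts) returns your page 1 “exact” guess for the noise, and it uses \(\sqrt{\beta_t}\) for the added noise; if your page 1 version used the posterior variance (which involves \(\bar\alpha_{t-1}\)), keep that instead:

@torch.no_grad()
def Undiffuse(x_t, Tstart, Tend):
    batch_size = x_t.shape[0]
    for t in range(Tstart, Tend, -1):
        # the time step as a (batch,1,1,1) array so it broadcasts over the images
        ts = torch.full((batch_size,), t).reshape(batch_size, 1, 1, 1)
        z_guess = GuessZ(x_t, ts)                       # guess for the noise used at step t
        s_t = (1 - alpha[ts]) / torch.sqrt(1 - alpha_bar[ts])
        x_t = (x_t - s_t * z_guess) / torch.sqrt(alpha[ts])
        if t > 1:                                       # no fresh noise on the very last step
            x_t = x_t + torch.sqrt(beta[ts]) * torch.randn_like(x_t)
    return x_t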

Grading

Get your undiffusion to work for a single batch of images. Undiffuse from time step \(T=199\) to \(T=0\), showing snapshots of what it looks like at \(T \in \{0,1,50,100,150,199\}\)

Undiffusing with a network#

At this point you should have undiffusing working with your exact guess. The next step is to replace that guess for the random noise with a neural network that learns to make the guess.

In the single-point case, we used a standard feed-forward neural network. That is not powerful enough here. Instead we are going to use a U-Net. It’s essentially just a fancier neural network with more parameters:

# Create the network
net = UNet2DModel(
    sample_size=28,  # the target image resolution
    in_channels=1,  # the number of input channels (1 for grayscale MNIST; 3 for RGB images)
    out_channels=1,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
    down_block_types=( 
        "DownBlock2D",  # a regular ResNet downsampling block
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ), 
    up_block_types=(
        "AttnUpBlock2D", 
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",   # a regular ResNet upsampling block
      ),
) 

This network was taken from this blog post explaining diffusion.

The equivalent of calling GuessZ(x_t,t) in this case is to call net(x_t,ts.squeeze()).sample.
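In other words, the only line of Undiffuse that really changes is the noise guess; roughly (with ts being the same (batch,1,1,1) array of time steps as before):

# before: z_guess = GuessZ(x_t, ts)
z_guess = net(x_t, ts.squeeze()).sample   # the UNet wants a 1-D vector of time steps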

This network is just initialized with random weights. In a few minutes you’re going to train your own weights, but before then we’d like to debug your sampling (and get it onto the GPU). Therefore, we are going to supply you with a network that we’ve already partially trained (in the same way you’re going to train yours).

Copy network.pickle onto your Google Colab and run the following code to load it:

with open('network.pickle', 'rb') as f:
    # The protocol version used is detected automatically, so we do not
    # have to specify it.
    net = pickle.load(f)

Go ahead and modify your undiffusing code to use this network. The primary change you’re going to have to make is to replace GuessZ with the call to the network. Check that your code runs when you do this and see if you can get it to give you an undiffused sample. This is going to be very slow, so you are going to want to get it working on the GPU. This only requires a few changes in your code, but it is a bit of a painful process. Here’s what I had to do:

  • In Colab you have to change the runtime to GPU (this will cause you to have to rerun everything in your notebook)

  • Add a line device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  • Add a line net.to(device). (This sends the network to the device)

  • In your Undiffuse function, you’re going to have to send many arrays to the device. For example, I needed to do things like:

    • alpha_ts=alpha[ts].to(device)

    • s_t=(1-alpha_ts)/torch.sqrt(1-alpha_bar_ts).to(device)

  • When plotting, I had to pull the data back to the CPU with PlotImages(x_t.cpu())

When I do all this, my undiffusing function takes approximately 15 seconds to go from T=199 to T=0.
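Putting the device bookkeeping together, the calling code ends up looking roughly like this sketch (exactly which tensors you have to move depends on how your Undiffuse is written):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

x_T = torch.randn(128, 1, 28, 28).to(device)   # start from pure Gaussian noise on the GPU
x_0 = Undiffuse(x_T, 199, 0)

PlotImages(x_0.cpu())                           # pull back to the CPU for plotting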

Note: Google Colab doesn’t let you run on the GPU forever. If you’re not actively using the GPU, you should switch back to a non-GPU runtime.

Grading

Get your undiffusion to work with the U-Net and on the GPU. Show two different undiffusion samples (to show you get different samples) at times \(T \in \{0,1,50,100,150,199\}\) by calling Undiffuse twice.

In addition, just like on page 1, get your code to report the “guess” for the actual image at those times (for one of the samples).

Training with images#

Our next trick is to get training to work. Again you should be able to copy over your code from page 1.

You are going to have to make the following changes:

  • Instead of calling SamplePInit(), you now just need a fresh batch of images; you can do this with x0=next(iter(train_dataloader))[0]

  • We have to get everything to the GPU:

    • x0=x0.to(device)

    • alpha_bar_t=alpha_bar[ts].to(device) (you may have to send alpha_bar_t to your ForwardFastDiffusion instead of the time)

    • etc.

(Remember to make sure this version of your forward diffusion returns the noisy data and the noise).

Go ahead and make these modifications. Make sure you don’t start from our pretrained network (i.e. construct a fresh UNet2DModel rather than loading our pickled version). Every 500 steps, you can call Undiffuse and plot the result to see how your training is coming along.
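For reference, the skeleton of a training loop might look like the sketch below. The optimizer, learning rate, and number of steps here are my own choices rather than requirements, and, as noted above, this version passes alpha_bar_t into the forward diffusion rather than the time itself (so its signature differs slightly from the earlier sketch):

net.to(device)                                             # the freshly constructed UNet2DModel
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

for step in range(2000):
    x0 = next(iter(train_dataloader))[0].to(device)        # a fresh batch of images
    ts = torch.randint(0, T, (x0.shape[0], 1, 1, 1))       # a random time step for each image
    x_t, z = ForwardFastDiffusion(x0, alpha_bar[ts].to(device))
    z_guess = net(x_t, ts.squeeze().to(device)).sample     # the network's guess for that noise
    loss = F.mse_loss(z_guess, z)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
    if step % 500 == 0:                                    # occasionally check how sampling looks
        PlotImages(Undiffuse(torch.randn_like(x0), 199, 0).cpu())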

Grading

Train your network. Plot your loss as a function of training steps (you may want to show a zoomed-in version as well). In addition, plot a final undiffused sample from your trained network.

My network produces reasonable digits (like a seven-year-old might write) in about 5 minutes of training.

Prompt Images#

Finally, we would like to get prompts working with images. In this case, the prompt is simply a digit (0 through 9). Again, as on page 1, we are going to need to give the network the prompt both during training and during undiffusing.

We will embed the digits in the following way: we represent the number \(i\) by a vector \(v\) of length 10 where \(v_i=1\) and all other entries are zero. This is called a one-hot encoding. Write some code which takes a vector of \(k\) labels - i.e. tensor([6, 5, 1, 4 … ]) - and returns a \(k \times 1 \times 10\) tensor which embeds the labels.
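One way to write this (the name EmbedLabels is just my choice; F.one_hot does the heavy lifting):

def EmbedLabels(labels):
    # labels: a length-k tensor of digits, e.g. tensor([6, 5, 1, 4, ...])
    # returns: a k x 1 x 10 float tensor of one-hot rows
    return F.one_hot(labels, num_classes=10).float().reshape(-1, 1, 10)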

Now that we have our embedded labels, we will use the following network for our prompting:

# Create the network (this one needs UNet2DConditionModel, which we have not imported yet)
from diffusers import UNet2DConditionModel

net = UNet2DConditionModel(
    sample_size=28,  # the target image resolution
    in_channels=1,  # the number of input channels (1 for grayscale MNIST)
    out_channels=1,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
    cross_attention_dim=10,  # the size of our one-hot prompt vectors
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",  # a ResNet upsampling block with cross-attention over the prompt
        "UpBlock2D",   # a regular ResNet upsampling block
      ),
)

The key change from our previous network is that we are using a UNet2DConditionModel instead of a UNet2DModel; we have added a cross_attention_dim, which we set to 10 (for the ten digits). Also, in the up blocks, we are using CrossAttnUpBlock2D blocks.

Now when we call our network, we need to use net(x_t, t, encoder_hidden_states=embedVec).sample (remember you should take your list of labels and embed it with the one-hot encoding).

During training, you can get the labels (the digit MNIST says each image is) by doing x, labels=next(iter(train_dataloader))
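So the label handling inside the training loop might look roughly like this (a sketch, using the EmbedLabels helper from above and the same forward-diffusion convention as the earlier training sketch):

x0, labels = next(iter(train_dataloader))
x0 = x0.to(device)
embedVec = EmbedLabels(labels).to(device)                 # k x 1 x 10 one-hot prompts
ts = torch.randint(0, T, (x0.shape[0], 1, 1, 1))
x_t, z = ForwardFastDiffusion(x0, alpha_bar[ts].to(device))
z_guess = net(x_t, ts.squeeze().to(device), encoder_hidden_states=embedVec).sample
loss = F.mse_loss(z_guess, z)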

Once your diffusion model is trained, you should then be able to send it a prompt (some digit) during undiffusing, and you should end up primarily with that type of digit.

Grading

Train your network. Plot a final batch of samples from your trained network, prompting with each of the ten digits.