Stable Diffusion from Beginner to Master (3): VAE Deep Dive
Starting from this 3rd tutorial of the Stable Diffusion series, we'll dive into the details of pipelines and see how each component works.
In this tutorial we'll briefly look at what components make up a Pipeline, then take a deeper dive into one of them: the Variational Auto Encoder (VAE).
!pip install -Uqq diffusers transformers ftfy accelerate bitsandbytes
#!pip install -Uqq triton xformers
import torch
from diffusers import StableDiffusionPipeline
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-base', revision="fp16", torch_dtype=torch.float16).to(device)
Let's take a look at what is inside a `StableDiffusionPipeline`:
pipe
There are 5 main components in a `StableDiffusionPipeline`:
- `VAE`: a downsampler + upsampler combination that works at the beginning and the end of the pipeline. It can "compress" an image from 512x512 down to 64x64 (a factor of 8 per side, which is why image sizes need to be divisible by 8) or "decompress" it back.
- `Text Encoder` and `Tokenizer`: these two components work together to convert text prompts into embeddings.
- `UNet`: the main module for diffusion, which takes a noisy image and a text embedding as input and predicts the noise.
- `Scheduler`: determines the details of the diffusion process, e.g. how many steps to run, how much noise to remove at each step, etc.
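As a quick sanity check, here is a small sketch that prints the class of each component, using the standard diffusers attribute names (`pipe.vae`, `pipe.text_encoder`, `pipe.tokenizer`, `pipe.unet`, `pipe.scheduler`); the exact class names printed depend on the checkpoint you loaded:
for name in ['vae', 'text_encoder', 'tokenizer', 'unet', 'scheduler']:
    # e.g. vae -> AutoencoderKL, unet -> UNet2DConditionModel
    print(name, '->', type(getattr(pipe, name)).__name__)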
Encode and Decode with VAE
The VAE is an encoder-decoder combo, which can downsample an image into a smaller representation (called `latents`) and then convert it back.
Why is this important? A 512x512 RGB image is worth 786,432 numbers, and pushing that many values through convolutions (large matrix multiplications) uses a lot of GPU memory and compute. Downsampling the image into a much smaller latent, doing the image generation there, and then upsampling back to a large image saves a lot of memory and computing power. While not strictly necessary from a mathematical standpoint, the VAE is the key part that makes it possible to run Stable Diffusion on low-end GPUs, even personal computers.
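To put rough numbers on this (assuming the 512x512 image and 64x64, 4-channel latent described above):
# Rough element-count comparison between pixel space and latent space
pixel_values  = 512 * 512 * 3   # 786,432 values per RGB image
latent_values = 64 * 64 * 4     # 16,384 values per latent
print(pixel_values, latent_values, pixel_values / latent_values)  # 48x fewer values in latent space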
import math
import PIL
import torch
import numpy as np
from matplotlib import pyplot as plt
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
    # PIL version compatibility: Image.Resampling was introduced in Pillow 9
    try:
        resample = PIL.Image.Resampling.LANCZOS
    except AttributeError:
        resample = PIL.Image.LANCZOS
image = image.resize((w, h), resample=resample)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [PIL.Image.fromarray(image) for image in images]
return pil_images
@torch.no_grad()
def encode_image(pipe, image, dtype=torch.float16):
if isinstance(image, PIL.Image.Image):
image = preprocess(image)
image = image.to(device=device, dtype=dtype)
latent_dist = pipe.vae.encode(image).latent_dist
latents = latent_dist.sample(generator=None)
    latents = 0.18215 * latents  # 0.18215 is the scaling factor used by the Stable Diffusion VAE
return latents
@torch.no_grad()
def decode_latents(pipe, latents):
    latents = 1 / 0.18215 * latents  # undo the scaling factor applied in encode_image
image = pipe.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = numpy_to_pil(image)
if len(image) == 1:
# single image
return image[0]
return image
def image_grid(imgs, rows=None, cols=None) -> PIL.Image.Image:
n_images = len(imgs)
if not rows and not cols:
cols = math.ceil(math.sqrt(n_images))
if not rows:
rows = math.ceil(n_images / cols)
if not cols:
cols = math.ceil(n_images / rows)
w, h = imgs[0].size
grid = PIL.Image.new('RGB', size=(cols*w, rows*h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
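image_grid is just a small helper for laying out several PIL images side by side; a quick usage sketch (reusing the pikachu image path from below):
imgs = [PIL.Image.open('images/pikachu.png').convert('RGB')] * 4
display(image_grid(imgs, rows=2, cols=2))  # a 2x2 grid of the same image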
Now let's try out the VAE (the quick check right after this list prints the shapes at each step):
- The original image is (516, 484).
- `preprocess` resizes it down to a multiple of 32 (the VAE itself only needs multiples of 8), which gives (512, 480).
- `encode_image` converts the image into latents with 4 channels, each of shape (60, 64), i.e. 1/8 of the original size. Note that width and height appear swapped only because image size is (width, height) while array shape is (height, width).
- `decode_latents` converts the latents back into an image of size (512, 480).
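A minimal shape check, assuming the same pikachu example image used in the demo below:
img = PIL.Image.open('images/pikachu.png').convert('RGB')
print(img.size)                         # (516, 484)
print(preprocess(img).shape)            # torch.Size([1, 3, 480, 512])
print(encode_image(pipe, img).shape)    # torch.Size([1, 4, 60, 64])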
def demo_vae(image, latents, decoded_image):
fig, axes = plt.subplots(figsize=(8,4), nrows=1, ncols=2)
plt.tight_layout()
axes[0].imshow(image)
# axes[0].set_axis_off()
axes[0].set_title(f'Original {image.size}')
axes[1].imshow(decoded_image)
# axes[1].set_axis_off()
axes[1].set_title(f'Decoded {decoded_image.size}')
if len(latents.shape) == 4:
latents = latents[0]
latents = latents.float().cpu().numpy()
n_channels = latents.shape[0]
fig, axs = plt.subplots(1, n_channels, figsize=(n_channels*4, 4))
fig.suptitle(f'Latent channels {latents.shape}')
for c in range(n_channels):
axs[c].imshow(latents[c], cmap='Greys')
image = PIL.Image.open('images/pikachu.png').convert('RGB')
latents = encode_image(pipe, image)
decoded_image = decode_latents(pipe, latents)
demo_vae(image, latents, decoded_image)
def add_rand_noise(latents, noise_level):
    # add uniform noise in [-0.5, 0.5), scaled by the mean absolute latent value times noise_level
    scale = torch.mean(torch.abs(latents))
    return latents + (torch.rand_like(latents) - 0.5) * scale * noise_level
image = PIL.Image.open('images/pikachu.png').convert('RGB')
latents = encode_image(pipe, image)
fig, axs = plt.subplots(1, 1, figsize=(4, 4))
axs.imshow(image)
axs.set_title('Original')
# axs.set_axis_off()
count = 6
noise_levels = np.linspace(0.1,1,count)
fig, axs = plt.subplots(1, count, figsize=(4*count, 4))
for ax, noise in zip(axs, noise_levels):
latents1 = add_rand_noise(latents, noise)
decoded_image = decode_latents(pipe, latents1)
ax.imshow(decoded_image)
    ax.set_title(f'Noised {noise:.2f}')
# ax.set_axis_off()
As you can see above, the VAE is quite resilient to noise. Even when we add noise on the order of 50% of the mean absolute latent value, it still does a good job of restoring the image.
Playing with VAE code
Now we'll dive into the VAE code to see what is going on. To do this in Colab, we can add some print statements to diffusers/models/vae.py.
To reflect our changes without having to restart the Colab runtime, we can run the following reload code:
from diffusers.models import vae
import importlib
importlib.reload(vae)
pipe.vae.__class__ = vae.AutoencoderKL
The `encode` function maps an image tensor of shape [1,3,h,w] to a distribution over latents of shape [1,4,h/8,w/8]:
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
    # x: [1, 3, h, w]
    # h: [1, 8, h/8, w/8]
    h = self.encoder(x)
    # moments: [1, 8, h/8, w/8]
    moments = self.quant_conv(h)
    # posterior.mean: [1, 4, h/8, w/8]
    # posterior.std:  [1, 4, h/8, w/8]
    posterior = DiagonalGaussianDistribution(moments)
    if not return_dict:
        return (posterior,)
    return AutoencoderKLOutput(latent_dist=posterior)
- The `encoder` is a convolutional encoder (a stack of downsampling ResNet blocks) that converts the input image `x` of shape [1,3,h,w] into a tensor `h` of shape [1,8,h/8,w/8].
- `quant_conv` is a `Conv2d` block with kernel size 1, which converts `h` into `moments`, preserving the shape.
- The `moments` tensor of shape [1,8,h/8,w/8] is split into the mean and (log) variance of a diagonal Gaussian distribution, from which the latents are sampled (see the sketch after the code below).
from diffusers.models.vae import DiagonalGaussianDistribution
image = PIL.Image.open('images/pikachu.png').convert('RGB')
print('image:', image.size)
x = preprocess(image).to(device=device, dtype=torch.float16)
print('image tensor:', x.shape)
h = pipe.vae.encoder(x)
print('h:', h.shape)
moments = pipe.vae.quant_conv(h)
print('moments:', moments.shape)
latent_dist = DiagonalGaussianDistribution(moments)
latents = latent_dist.sample(generator=None)
latents = 0.18215 * latents
print('latents:', latents.shape)
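For reference, here is a paraphrased sketch of what `DiagonalGaussianDistribution` does with `moments` internally (the real implementation in diffusers/models/vae.py also clamps the log-variance; this just shows the idea):
mean, logvar = torch.chunk(moments, 2, dim=1)   # each [1, 4, h/8, w/8]
std = torch.exp(0.5 * logvar)
sample = mean + std * torch.randn_like(std)     # reparameterization trick
print(sample.shape)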
Now the `decode` function:
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
- `post_quant_conv` is a `Conv2d` with input and output channels both equal to the number of latent channels (4) and kernel size 1. It converts `latents` into `z` with the same shape.
- `decoder` is a convolutional decoder (a stack of upsampling ResNet blocks) that converts the latent of shape [1,4,h/8,w/8] back into an image tensor of shape [1,3,h,w].
latents = 1 / 0.18215 * latents
print('latents:', latents.shape)
z = pipe.vae.post_quant_conv(latents)
print('z:', z.shape)
dec = pipe.vae.decoder(z)
print('dec:', dec.shape)
image = (dec.detach() / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = numpy_to_pil(image)[0]
display(image)
We won't dive deeper into the Encoder and Decoder here. Both are convolutional networks built from the same kind of ResNet blocks as a UNet: the Encoder progressively scales the feature maps down, and the Decoder scales them back up. But you can definitely have a look at the source code to dig into more details.
pipe.vae.encoder
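Similarly, you can inspect the decoder:
pipe.vae.decoder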