Retrieval-Augmented #Diffusion (RDM) models: smaller diffusion models can produce high-quality generations by accessing an external memory that guides the generation. Inspired by DeepMind's RETRO.

A 🧶

Paper: https://arxiv.org/abs/2204.11824

Day 10 #30daysofDiffusion #MachineLearning

Semi-Parametric Neural Image Synthesis

Novel architectures have recently improved generative image synthesis leading to excellent visual quality in various tasks. Much of this success is due to the scalability of these architectures and hence caused by a dramatic increase in model complexity and in the computational resources invested in training these models. Our work questions the underlying paradigm of compressing large training data into ever growing parametric representations. We rather present an orthogonal, semi-parametric approach. We complement comparably small diffusion or autoregressive models with a separate image database and a retrieval strategy. During training we retrieve a set of nearest neighbors from this external database for each training instance and condition the generative model on these informative samples. While the retrieval approach is providing the (local) content, the model is focusing on learning the composition of scenes based on this content. As demonstrated by our experiments, simply swapping the database for one with different contents transfers a trained model post-hoc to a novel domain. The evaluation shows competitive performance on tasks which the generative model has not been trained on, such as class-conditional synthesis, zero-shot stylization or text-to-image synthesis without requiring paired text-image data. With negligible memory and computational overhead for the external database and retrieval we can significantly reduce the parameter count of the generative model and still outperform the state-of-the-art.

If the model can always rely on this external memory, it only has to learn the important parts of the image generation process, such as the composition of scenes, rather than, for example, remembering what different dogs look like.
Setting: X is the training set and D is a *disjoint* image set which is used for retrieval. θ denotes the parameters of the diffusion model. ξ is the retrieval function which takes in an image and selects "k" images from D. φ is a pretrained image encoder.
Both ξ and φ are pretrained, "fixed" functions and we do not modify them during training or inference. Only θ is learned/optimized during training.
So the generative model formulation in this case boils down to learning a diffusion (or autoregressive) model conditioned on images that look similar to a given training image x.
In this paper, the authors use a CLIP image encoder for retrieval, with cosine similarity to choose the top "k" most similar images for a given image. They also chose φ to be the CLIP image encoder.
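A minimal sketch of this retrieval step (my illustration, not the paper's code; CLIP embeddings are assumed precomputed as plain arrays):

```python
import numpy as np

# Hypothetical sketch of the retrieval function xi: compare a query image's
# CLIP embedding against precomputed database embeddings and pick the k
# nearest neighbors by cosine similarity.
def retrieve_neighbors(query_emb, db_embs, k=4):
    q = query_emb / np.linalg.norm(query_emb)
    db = db_embs / np.linalg.norm(db_embs, axis=1, keepdims=True)
    sims = db @ q                      # cosine similarity to every image in D
    topk = np.argsort(-sims)[:k]       # indices of the k most similar images
    return topk, sims[topk]
```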
The training objective of the diffusion model is shown below. Think of Stable Diffusion, except the text representation is replaced by multiple image representations (of images similar to the one we are diffusing) cross-attending into the U-Net encoder.
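The screenshot of the objective isn't reproduced here; under the notation above, an LDM-style denoising objective with retrieved conditioning would look roughly like this (my reconstruction, not copied from the paper):

```latex
\mathcal{L}_{\mathrm{RDM}} =
\mathbb{E}_{x \sim p(x),\; \epsilon \sim \mathcal{N}(0, I),\; t}
\left[ \left\| \epsilon - \epsilon_\theta\!\left(z_t,\, t,\,
\{\phi(y) : y \in \xi_k(x, \mathcal{D})\}\right) \right\|_2^2 \right]
```

where z_t is the noised latent of x, and the set of neighbor embeddings {φ(y)} enters the U-Net via cross-attention.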
The authors discuss 3 possible inference scenarios. Input is (1) an image (2) text (3) no condition. The first case is easy since we trained the model conditioned on the image.
If we want to generate an image from text, the authors use the CLIP text encoder to find the closest matches in database D (dot-product similarity between clip_text(input text) and clip_image(D); pick the top "k" highest-similarity images).
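A hedged sketch of this cross-modal lookup (again my illustration, with embeddings as plain arrays): CLIP maps text and images into a shared space, so a prompt embedding can be compared directly against image embeddings.

```python
import numpy as np

# Text-to-image retrieval sketch: embeddings are assumed L2-normalized, so the
# dot product clip_text(prompt) . clip_image(D) acts as a cosine similarity.
def retrieve_from_text(text_emb, db_image_embs, k=4):
    sims = db_image_embs @ text_emb
    return np.argsort(-sims)[:k]
```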
Unconditional generation is interesting. The simplest way is to just sample a random image uniformly from the memory bank D. However, in reality the images in D are not equally likely, since some of them are more similar to the training data than others. So how to get around this?
The authors compute an MLE-ish score for each image in D based on how often it appears among the top "k" neighbors of the training images in X.
Then, using this p(x̃) distribution from the above tweet, sample an image from D, use it to retrieve the top-k similar images from D, and use those as guidance to generate the final image.
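The two tweets above can be sketched as follows (my hypothetical implementation of the counting idea, not the paper's code):

```python
import numpy as np

# "MLE-ish" distribution over D: count how often each database image lands in
# the top-k neighbor sets of training images, then normalize. Images rarely
# retrieved during training get low probability at inference time.
def neighbor_usage_distribution(train_topk_indices, db_size):
    counts = np.bincount(np.ravel(train_topk_indices), minlength=db_size)
    return counts / counts.sum()

def sample_pseudo_query(p, rng):
    # draw a pseudo-query index from D according to p; its own top-k neighbors
    # are then retrieved and used as conditioning for the "unconditional" sample
    return rng.choice(len(p), p=p)
```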
Most of the architectural components are the same as in the Latent Diffusion Model. The authors evaluated model performance with different datasets as "D"; the RDM-OpenImages model performed best across various metrics.
A larger "k" leads to higher recall, which means sample diversity decreases. In this paper, the authors used k=4 in most of the experiments. However, a larger k during training led to better generalization capabilities.
When you swap the retrieval dataset at inference time, some zero-shot style-transfer abilities emerge.
This one is slightly confusing to me: the authors say that for the text-to-image synthesis case, if we take the CLIP text embedding of the prompt, use it to retrieve "k" images from "D", and condition on both the text embedding and the NN images, the generations are not good. (contd)
And conditioning on just the NNs leads to even worse generations. It is slightly strange that the model does not do well on the very thing it was trained for...
Official code (same repo as LDM) - https://github.com/CompVis/latent-diffusion