
DiffuseSeg: Synthetic Data and Segmentation from a Single Diffusion Model

DiffuseSeg demonstrates how a single, unconditionally trained Denoising Diffusion Probabilistic Model (DDPM) can serve as a powerful backbone for both high-fidelity synthetic image generation and label-efficient semantic segmentation.

A summary of DiffuseSeg

The core idea is to repurpose the rich, multi-scale features learned by the U-Net decoder of a DDPM. By extracting these features, we can train a lightweight, pixel-level segmentation head with very few labeled examples, effectively turning the generative model into a labeled data factory.
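The feature-to-pixel mapping described above can be sketched as follows. This is a minimal NumPy illustration, not the project's actual code; the shapes, function names, and toy feature maps are all hypothetical. Multi-scale decoder features are upsampled to image resolution and concatenated channel-wise, yielding one feature vector per pixel for the segmentation head to classify.

```python
import numpy as np

def upsample_nearest(feat, size):
    """Nearest-neighbor upsample a (C, h, w) feature map to (C, size, size)."""
    c, h, w = feat.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return feat[:, rows][:, :, cols]

def pixel_features(feature_maps, im_size):
    """Upsample every decoder feature map to image resolution and
    concatenate along channels, giving one feature vector per pixel."""
    ups = [upsample_nearest(f, im_size) for f in feature_maps]
    stacked = np.concatenate(ups, axis=0)           # (C_total, H, W)
    return stacked.reshape(stacked.shape[0], -1).T  # (H*W, C_total)

# Toy multi-scale features, as a U-Net decoder might produce at two depths
maps = [np.random.rand(8, 4, 4), np.random.rand(4, 8, 8)]
X = pixel_features(maps, 16)
print(X.shape)  # (256, 12): 16*16 pixels, 8+4 channels each
```

Each row of `X` is then a training example for the per-pixel MLP head, paired with the label of the corresponding pixel in the annotated mask.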

This project was inspired by the paper cited at the bottom of this README.

Key Features

How It Works: The Two-Stage Pipeline

The project is implemented in two main stages:

Stage 1: Train a Denoising Diffusion Model (DDPM)

DDPM sampling GIF
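Stage 1 follows the standard DDPM recipe: the U-Net is trained to predict the noise added to an image by the forward diffusion process. The sketch below illustrates the forward (noising) step with NumPy; the schedule values are the common defaults from the original DDPM paper, not necessarily the ones used in this repository.

```python
import numpy as np

# Linear beta schedule (DDPM paper defaults; may differ from this repo's config)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)  # cumulative products of alphas

def q_sample(x0, t, eps):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    return np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps

x0 = np.random.randn(3, 64, 64)   # a normalized training image
eps = np.random.randn(*x0.shape)  # Gaussian noise
xt = q_sample(x0, t=500, eps=eps)
# The U-Net eps_theta(x_t, t) is then trained with the simple objective
#   L = E[ || eps - eps_theta(x_t, t) ||^2 ]
```

Sampling runs this process in reverse, iteratively denoising pure Gaussian noise into an image, which is what the GIF above shows.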

Stage 2: Train a Segmentation Head

With the DDPM U-Net frozen, we use it as a feature extractor.

This approach allows the model to generate a segmentation mask for any image, whether real or synthetically generated by the DDPM.
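Capturing intermediate decoder activations from a frozen network is typically done with forward hooks. The sketch below mimics that pattern in pure Python; the `Block` class is a made-up stand-in for a U-Net decoder block, not this project's actual module.

```python
import numpy as np

class Block:
    """Hypothetical stand-in for a U-Net decoder block; _hooks mimics
    the forward-hook mechanism found in deep learning frameworks."""
    def __init__(self, out_channels):
        self.out_channels = out_channels
        self._hooks = []

    def register_forward_hook(self, fn):
        self._hooks.append(fn)

    def __call__(self, x):
        # Fake activation with the block's channel count; a real block
        # would compute convolutions here.
        out = np.random.rand(self.out_channels, *x.shape[1:])
        for fn in self._hooks:
            fn(self, x, out)
        return out

decoder_blocks = [Block(64), Block(32)]
features = []
for blk in decoder_blocks:
    blk.register_forward_hook(lambda m, inp, out: features.append(out))

x = np.random.rand(3, 16, 16)
for blk in decoder_blocks:
    x = blk(x)

print(len(features))  # 2: one captured feature map per hooked block
```

With the backbone frozen, only the hooks run during feature extraction; the captured maps are then upsampled and concatenated into per-pixel features for the segmentation head.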

Results

The segmentation head achieves strong performance on the CelebA-HQ validation set, demonstrating the quality of the features extracted from the trained DDPM.

Example Predictions

Here are some end-to-end results, showing a synthetic image generated by the DDPM and the corresponding segmentation map produced by the MLP head.

e2e_1 e2e_2

Here are some validation results, showing an image and its ground-truth mask from the CelebA-HQ dataset alongside the corresponding segmentation map produced by DiffuseSeg.

Val_1 Val_2

Trained weights and demo:

Setup and Installation

  1. Clone the repository:
    git clone https://github.com/your-username/DiffuseSeg.git
    cd DiffuseSeg
  2. Create a virtual environment and install dependencies:
    conda create -n diffuseg_env python=3.9
    conda activate diffuseg_env
    pip install -r requirements.txt

How to Run

  1. Training the DDPM

    To train the diffusion model from scratch, use the DDPM-train.py script. Make sure the dataset path and training parameters are set correctly in utils/config.yaml. Dataset-specific fields (im_size, im_channels) and architectural choices (the number of down/mid/up blocks) are also configured in the same file.

    python utils/DDPM-train.py
  2. Feature Extraction
    python utils/Feature_extractor.py
  3. Training the Segmentation Head
    python utils/train_MLPs.py
  4. Inference
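The dataset and architecture fields mentioned in step 1 might look like the fragment below. This is a hypothetical sketch: only im_size, im_channels, and the down/mid/up block counts are named in this README, and the exact key names and values in utils/config.yaml may differ.

```yaml
# Hypothetical utils/config.yaml fragment; key names are illustrative only
dataset:
  im_size: 128       # edge length of training images
  im_channels: 3     # RGB
model:
  num_down_blocks: 3
  num_mid_blocks: 2
  num_up_blocks: 3
```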

Citation

Find below the original paper that inspired this approach:

@inproceedings{baranchuk2022label,
  title={Label-Efficient Semantic Segmentation with Diffusion Models},
  author={Dmitry Baranchuk and Ivan Rubachev and Andrey Voynov and Valentin Khrulkov and Artem Babenko},
  booktitle={International Conference on Learning Representations},
  year={2022}
}