Skip to content

Conversation

@SpenserCai
Copy link
Contributor

@SpenserCai SpenserCai commented Jan 12, 2026

Summary

This PR adds support for BiRefNet (Bilateral Reference Network), a state-of-the-art model for high-resolution dichotomous image segmentation, commonly used for background removal tasks.

⚠️ Dependency Note

This PR depends on the deform_conv2d operation which is currently pending in PR #3292 (The development work of this PR has been completed, and the data verification has passed and is waiting to be merged).

The following dependencies in Cargo.toml are temporarily using a fork:

# Use fork with deform_conv2d support - change back to path after PR merged
candle = { git = "https://github.com/SpenserCai/candle", branch = "support_deform_conv2d", package = "candle-core" }
candle-kernels = { git = "https://github.com/SpenserCai/candle", branch = "support_deform_conv2d" }
candle-metal-kernels = { git = "https://github.com/SpenserCai/candle", branch = "support_deform_conv2d" }

Once PR #3292 is merged, these dependencies will be reverted to local paths.

Changes

New Files

candle-transformers/src/models/birefnet/

  • mod.rs - Module exports and documentation
  • config.rs - Model configuration with support for different backbone types
  • birefnet.rs - Main BiRefNet model implementation
  • swin.rs - Swin Transformer backbone (SwinTransformerBlock, BasicLayer, PatchEmbed, PatchMerging)
  • decoder.rs - Decoder with multi-scale feature fusion
  • aspp.rs - ASPP (Atrous Spatial Pyramid Pooling) with Deformable Convolution support
  • blocks.rs - Basic building blocks (BasicDecBlk, BasicLatBlk, SimpleConvs)

candle-examples/examples/birefnet/

  • main.rs - Example application for background removal
  • README.md - Usage documentation

Modified Files

  • candle-transformers/src/models/mod.rs - Added birefnet module export
  • Cargo.toml - Added temporary fork dependencies for deform_conv2d support

Features

  • Full BiRefNet architecture implementation
  • Swin Transformer backbone with configurable depths and window sizes
  • ASPPDeformable module with learnable deformable convolutions
  • Multi-scale decoder with lateral connections
  • Support for multiple model variants:

Example Usage

# Auto-download model from HuggingFace
cargo run --example birefnet --release --features cuda -- \
    --image input.jpg \
    --output output.png

# Use local model
cargo run --example birefnet --release --features cuda -- \
    --model ./model.safetensors \
    --image input.jpg \
    --output output.png

# Output mask only
cargo run --example birefnet --release --features cuda -- \
    --image input.jpg \
    --output mask.png \
    --mask-only

# Benchmark mode
cargo run --example birefnet --release --features cuda -- \
    --image input.jpg \
    --output output.png \
    --bench 10

Platform Support

  • ✅ CUDA (--features cuda)
  • ✅ Metal (--features metal)
  • ✅ CPU (--cpu flag)

Test Results

Input Image

image

Python Reference Output (Official BiRefNet)

image

Rust Candle Output

image

Cuda Result

    Running `target/release/examples/birefnet --model ../weights/BiRefNet/model.safetensors --image ../test_2_s12.png --output ../test_2_s12_birefnet.png`
BiRefNet Background Removal
===========================
Device: Cuda(CudaDevice(DeviceId(1)))
Dtype: F32

Loading model...
Model: "../weights/BiRefNet/model.safetensors"
Model loaded successfully

Processing image: ../test_2_s12.png
Original size: 768x1280, preprocessed in 41.93ms

Running inference...
Inference completed in 515.22ms

Saving output to: ../test_2_s12_birefnet.png
RGBA image saved successfully

Done!

Metal Result

     Running `/Users/spensercai/Dev/candle_dev/candle/target/release/examples/birefnet --model /Users/spensercai/Dev/candle_dev/weights/BiRefNet/model.safetensors --image /Users/spensercai/Dev/candle_dev/check_tests/test_datas/rmbg2.0/test_1.png --output /Users/spensercai/Dev/candle_dev/check_tests/test_datas/rmbg2.0/test_1_birefnet_rust.png`
BiRefNet Background Removal
===========================
Device: Metal(MetalDevice(DeviceId(1)))
Dtype: F32

Loading model...
Model: "/Users/spensercai/Dev/candle_dev/weights/BiRefNet/model.safetensors"
Model loaded successfully

Processing image: /Users/spensercai/Dev/candle_dev/check_tests/test_datas/rmbg2.0/test_1.png
Original size: 1258x1176, preprocessed in 24.99ms

Running inference...
Inference completed in 4.14s

Saving output to: /Users/spensercai/Dev/candle_dev/check_tests/test_datas/rmbg2.0/test_1_birefnet_rust.png
RGBA image saved successfully

Done!

References

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant