@SpenserCai (Contributor)

Summary

This PR adds support for Deformable Convolution v2 (DCNv2) to Candle, implementing the operation across CPU, Metal, and CUDA backends. The implementation follows the torchvision reference and maintains numerical consistency with PyTorch.

Motivation

Deformable Convolution is a convolution variant that learns per-position sampling offsets, giving the kernel an adaptive receptive field. It is widely used in:

  • Object detection (e.g., Deformable DETR)
  • Image segmentation
  • Background removal models (e.g., BiRefNet, RMBG-2.0)

References:

  • Dai et al., "Deformable Convolutional Networks" (ICCV 2017), arXiv:1703.06211
  • Zhu et al., "Deformable ConvNets v2: More Deformable, Better Results" (CVPR 2019), arXiv:1811.11168
  • torchvision.ops.deform_conv2d (reference implementation)

API

impl Tensor {
    /// Performs Deformable Convolution 2D (DCNv2).
    ///
    /// # Arguments
    /// * `offset` - Offset tensor of shape [batch, 2*offset_groups*kH*kW, out_h, out_w]
    /// * `weight` - Convolution weight of shape [out_channels, in_channels/groups, kH, kW]
    /// * `mask` - Optional modulation mask of shape [batch, offset_groups*kH*kW, out_h, out_w]
    /// * `bias` - Optional bias of shape [out_channels]
    /// * `stride` - Stride (stride_h, stride_w)
    /// * `padding` - Padding (pad_h, pad_w)
    /// * `dilation` - Dilation (dilation_h, dilation_w)
    /// * `groups` - Number of convolution groups
    /// * `offset_groups` - Number of offset groups
    ///
    /// # Returns
    /// Output tensor of shape [batch, out_channels, out_h, out_w]
    pub fn deform_conv2d(
        &self,
        offset: &Tensor,
        weight: &Tensor,
        mask: Option<&Tensor>,
        bias: Option<&Tensor>,
        stride: (usize, usize),
        padding: (usize, usize),
        dilation: (usize, usize),
        groups: usize,
        offset_groups: usize,
    ) -> Result<Tensor>;
}
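
Here out_h and out_w follow the standard convolution output-size arithmetic. The small standalone helper below (hypothetical, for illustration only, not part of this PR) makes the relationship explicit:

// Standard convolution output-size arithmetic (illustrative helper, not part of this PR).
fn conv_out_dim(in_dim: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (in_dim + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // Matches the test configuration used later: h = w = 4, k = 3, stride = 1, padding = 1, dilation = 1.
    assert_eq!(conv_out_dim(4, 3, 1, 1, 1), 4); // out_h = out_w = 4
}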

Implementation Details

Architecture

The implementation follows Candle's existing conv2d pattern:

  1. Storage Layer: Implements the deform_im2col kernel, which generates the columns matrix with deformable sampling
  2. Tensor Layer: Completes the full workflow (sketched below):
    • Calls the storage layer's deform_im2col to generate the columns
    • Multiplies weight × columns
    • Adds the bias and reshapes the output
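
Conceptually, the tensor-layer step reduces to a matmul plus a reshape. The sketch below is a minimal, self-contained illustration using only public candle_core ops, with a random columns tensor standing in for the real deform_im2col output and batching/groups ignored for brevity; it is not the actual implementation:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (out_c, in_c, kh, kw, out_h, out_w) = (2usize, 2usize, 3usize, 3usize, 4usize, 4usize);

    // Stand-in for the deform_im2col output: [in_c * kH * kW, out_h * out_w].
    let columns = Tensor::randn(0f32, 1.0, (in_c * kh * kw, out_h * out_w), &device)?;
    let weight = Tensor::randn(0f32, 0.1, (out_c, in_c, kh, kw), &device)?;
    let bias = Tensor::zeros((out_c,), DType::F32, &device)?;

    // weight [out_c, in_c*kH*kW] x columns [in_c*kH*kW, out_h*out_w] -> [out_c, out_h*out_w]
    let out = weight.reshape((out_c, in_c * kh * kw))?.matmul(&columns)?;
    // Reshape to [batch, out_c, out_h, out_w] and add the per-channel bias.
    let out = out
        .reshape((1, out_c, out_h, out_w))?
        .broadcast_add(&bias.reshape((1, out_c, 1, 1))?)?;
    println!("{:?}", out.dims()); // [1, 2, 4, 4]
    Ok(())
}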

Backend Support

Backend   Status      Data Types
CPU       Supported   F32, F64, F16, BF16
Metal     Supported   F32, F16, BF16
CUDA      Supported   F32, F64, F16, BF16

Key Algorithm: Bilinear Interpolation

Since offsets are floating-point values, bilinear interpolation is used to sample from the input feature map:

h_low = floor(h), h_high = h_low + 1
w_low = floor(w), w_high = w_low + 1

lh = h - h_low, lw = w - w_low
hh = 1 - lh,    hw = 1 - lw

output = hh*hw*v1 + hh*lw*v2 + lh*hw*v3 + lh*lw*v4
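
Here v1..v4 are the input values at (h_low, w_low), (h_low, w_high), (h_high, w_low) and (h_high, w_high), and taps that fall outside the input contribute zero. A minimal standalone sketch of this sampling step (illustrative only, not the actual CPU/CUDA/Metal kernel code):

// Bilinear sampling at a fractional position (h, w) in a single-channel [height, width] map.
// Out-of-bounds taps contribute zero, matching the torchvision reference behavior.
fn bilinear_sample(input: &[f32], height: usize, width: usize, h: f32, w: f32) -> f32 {
    if h <= -1.0 || h >= height as f32 || w <= -1.0 || w >= width as f32 {
        return 0.0;
    }
    let (h_low, w_low) = (h.floor() as i64, w.floor() as i64);
    let (h_high, w_high) = (h_low + 1, w_low + 1);

    let (lh, lw) = (h - h_low as f32, w - w_low as f32);
    let (hh, hw) = (1.0 - lh, 1.0 - lw);

    // Fetch one tap, returning zero when it lies outside the feature map.
    let at = |y: i64, x: i64| -> f32 {
        if y >= 0 && (y as usize) < height && x >= 0 && (x as usize) < width {
            input[y as usize * width + x as usize]
        } else {
            0.0
        }
    };
    let (v1, v2, v3, v4) = (at(h_low, w_low), at(h_low, w_high), at(h_high, w_low), at(h_high, w_high));

    hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4
}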

CUDA Half-Precision Handling

For __half and __nv_bfloat16 types, intermediate calculations are performed in float to avoid type ambiguity issues with CUDA's constructor overloads. This follows the same pattern used in upsample_bilinear2d.
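
Expressed in Rust with the half crate, the pattern looks roughly like this (an illustrative sketch only; the actual kernel is CUDA C++):

use half::f16;

// Sketch of the half-precision pattern: read f16 taps, accumulate the weighted sum in f32,
// and convert back to f16 only for the final result.
fn weighted_sum_f16(taps: [f16; 4], weights: [f32; 4]) -> f16 {
    let acc: f32 = taps.iter().zip(weights.iter()).map(|(v, w)| v.to_f32() * w).sum();
    f16::from_f32(acc)
}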

Files Changed

candle-core

  • src/conv.rs - Added ParamsDeformConv2D struct and Tensor::deform_conv2d() method
  • src/backend.rs - Added deform_conv2d to BackendStorage trait
  • src/storage.rs - Added storage layer dispatch
  • src/cpu_backend/mod.rs - CPU backend implementation
  • src/cpu_backend/deform_conv2d.rs - CPU kernel implementation (new file)
  • src/metal_backend/mod.rs - Metal backend implementation
  • src/cuda_backend/mod.rs - CUDA backend implementation
  • src/dummy_cuda_backend.rs - Dummy CUDA backend stub
  • src/dummy_metal_backend.rs - Dummy Metal backend stub
  • tests/deform_conv2d_tests.rs - Comprehensive test suite (new file)
  • benches/benchmarks/deform_conv2d.rs - Performance benchmark (new file)

candle-kernels

  • src/deform_conv2d.cu - CUDA kernel implementation (new file)
  • src/lib.rs - Kernel registration

candle-metal-kernels

  • src/metal_src/deform_conv2d.metal - Metal shader implementation (new file)
  • src/kernels/deform_conv2d.rs - Metal kernel bindings (new file)
  • src/kernels/mod.rs - Module registration
  • src/kernel.rs - Kernel registration
  • src/lib.rs - Export registration
  • src/source.rs - Source registration

Testing

Test Cases

The test suite includes:

  1. Basic test - No mask, no bias, validates against PyTorch output
  2. With mask (DCNv2) - Tests modulation mask functionality
  3. With bias - Tests bias addition
  4. With stride - Tests stride=2 configuration
  5. With dilation - Tests dilation=2 configuration
  6. Offset groups - Tests multiple offset groups
  7. Batch processing - Tests batch_size > 1
  8. Full config - Tests mask + bias together
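
For illustration, a basic case has roughly the following structure (a hypothetical sketch; the real tests also compare every output element against PyTorch reference values):

use candle_core::{Device, Result, Tensor};

#[test]
fn deform_conv2d_basic_shape() -> Result<()> {
    let device = Device::Cpu;
    // Same configuration as the reference data below: batch=1, in_c=2, out_c=2, h=w=4, k=3.
    let input = Tensor::randn(0f32, 1.0, (1, 2, 4, 4), &device)?;
    let weight = Tensor::randn(0f32, 0.1, (2, 2, 3, 3), &device)?;
    let offset = Tensor::randn(0f32, 0.5, (1, 18, 4, 4), &device)?;

    let out = input.deform_conv2d(&offset, &weight, None, None, (1, 1), (1, 1), (1, 1), 1, 1)?;
    assert_eq!(out.dims(), &[1, 2, 4, 4]);
    Ok(())
}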

Numerical Consistency

Test data was generated using PyTorch/torchvision:

import torch
from torchvision.ops import deform_conv2d
torch.manual_seed(42)
# batch=1, in_c=2, out_c=2, h=w=4, k=3, stride=1, padding=1
input = torch.randn(1, 2, 4, 4)
weight = torch.randn(2, 2, 3, 3) * 0.1
offset = torch.randn(1, 18, 4, 4) * 0.5
output = deform_conv2d(input, offset, weight, stride=1, padding=1)

All tests pass with max absolute error < 1e-4.
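
The reported error metrics are plain elementwise statistics over the flattened outputs; a sketch of how they might be computed (hypothetical helper, not the actual test code):

// Max absolute error, mean absolute error, and RMSE between a result and its reference.
fn error_stats(actual: &[f32], expected: &[f32]) -> (f32, f32, f32) {
    assert_eq!(actual.len(), expected.len());
    let n = actual.len() as f32;
    let (mut max_abs, mut sum_abs, mut sum_sq) = (0f32, 0f32, 0f32);
    for (a, e) in actual.iter().zip(expected.iter()) {
        let d = (a - e).abs();
        max_abs = max_abs.max(d);
        sum_abs += d;
        sum_sq += d * d;
    }
    (max_abs, sum_abs / n, (sum_sq / n).sqrt())
}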

Test Results

CPU Tests

cargo test --test deform_conv2d_tests -- --nocapture

    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.12s
     Running tests/deform_conv2d_tests.rs (target/debug/deps/deform_conv2d_tests-c5c50bf05fde3bd7)

running 8 tests

=== deform_conv2d_basic (Cpu) ===
  Elements: 32
  Max absolute error: 1.89e-6
  Mean absolute error: 7.25e-7
  RMSE: 8.95e-7

=== deform_conv2d_with_mask (Cpu) ===
  Elements: 32
  Max absolute error: 1.52e-6
  Mean absolute error: 4.74e-7
  RMSE: 6.21e-7

=== deform_conv2d_with_bias (Cpu) ===
  Elements: 32
  Max absolute error: 1.91e-6
  Mean absolute error: 7.26e-7
  RMSE: 8.95e-7
test deform_conv2d_basic_cpu ... ok
test deform_conv2d_batch_cpu ... ok
test deform_conv2d_with_stride_cpu ... ok
test deform_conv2d_with_mask_cpu ... ok
test deform_conv2d_with_bias_cpu ... ok
test deform_conv2d_with_dilation_cpu ... ok
test deform_conv2d_full_cpu ... ok
test deform_conv2d_offset_groups_cpu ... ok

test result: ok. 8 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

Metal Tests

cargo test --test deform_conv2d_tests --features metal -- --nocapture
    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.14s
     Running tests/deform_conv2d_tests.rs (target/debug/deps/deform_conv2d_tests-4a64453b113e2137)

running 16 tests

=== deform_conv2d_basic (Cpu) ===
  Elements: 32
  Max absolute error: 1.89e-6
  Mean absolute error: 7.25e-7
  RMSE: 8.95e-7

=== deform_conv2d_with_mask (Cpu) ===
  Elements: 32
  Max absolute error: 1.52e-6
  Mean absolute error: 4.74e-7
  RMSE: 6.21e-7

=== deform_conv2d_with_bias (Cpu) ===
  Elements: 32
  Max absolute error: 1.91e-6
  Mean absolute error: 7.26e-7
  RMSE: 8.95e-7
test deform_conv2d_basic_cpu ... ok
test deform_conv2d_batch_cpu ... ok
test deform_conv2d_with_mask_cpu ... ok
test deform_conv2d_with_bias_cpu ... ok
test deform_conv2d_full_cpu ... ok
test deform_conv2d_with_dilation_cpu ... ok
test deform_conv2d_with_stride_cpu ... ok
test deform_conv2d_offset_groups_cpu ... ok
test deform_conv2d_offset_groups_metal ... ok
test deform_conv2d_batch_metal ... ok
test deform_conv2d_full_metal ... ok

=== deform_conv2d_basic (Metal(MetalDevice(DeviceId(2)))) ===
  Elements: 32
  Max absolute error: 1.89e-6
  Mean absolute error: 7.38e-7
  RMSE: 9.01e-7

=== deform_conv2d_with_mask (Metal(MetalDevice(DeviceId(7)))) ===
  Elements: 32
  Max absolute error: 1.55e-6
  Mean absolute error: 4.75e-7
  RMSE: 6.23e-7

=== deform_conv2d_with_bias (Metal(MetalDevice(DeviceId(3)))) ===
  Elements: 32
  Max absolute error: 1.91e-6
  Mean absolute error: 7.40e-7
  RMSE: 9.02e-7
test deform_conv2d_basic_metal ... ok
test deform_conv2d_with_mask_metal ... ok
test deform_conv2d_with_bias_metal ... ok
test deform_conv2d_with_dilation_metal ... ok
test deform_conv2d_with_stride_metal ... ok

test result: ok. 16 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.13s

CUDA Tests

cargo test --test deform_conv2d_tests --features cuda -- --nocapture
    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.12s
     Running tests/deform_conv2d_tests.rs (target/debug/deps/deform_conv2d_tests-24bfc1cb6064f3e3)

running 16 tests

=== deform_conv2d_basic (Cpu) ===
  Elements: 32
  Max absolute error: 1.89e-6
  Mean absolute error: 7.25e-7
  RMSE: 8.95e-7

=== deform_conv2d_with_bias (Cpu) ===
  Elements: 32
  Max absolute error: 1.91e-6
  Mean absolute error: 7.26e-7
  RMSE: 8.95e-7

=== deform_conv2d_with_mask (Cpu) ===
  Elements: 32
  Max absolute error: 1.52e-6
  Mean absolute error: 4.74e-7
  RMSE: 6.21e-7
test deform_conv2d_batch_cpu ... ok
test deform_conv2d_basic_cpu ... ok
test deform_conv2d_with_bias_cpu ... ok
test deform_conv2d_full_cpu ... ok
test deform_conv2d_offset_groups_cpu ... ok
test deform_conv2d_with_mask_cpu ... ok
test deform_conv2d_with_dilation_cpu ... ok
test deform_conv2d_with_stride_cpu ... ok

=== deform_conv2d_basic (Cuda(CudaDevice(DeviceId(3)))) ===
  Elements: 32
  Max absolute error: 1.92e-6
  Mean absolute error: 7.35e-7
  RMSE: 9.02e-7

=== deform_conv2d_with_mask (Cuda(CudaDevice(DeviceId(5)))) ===
  Elements: 32
  Max absolute error: 1.61e-6
  Mean absolute error: 4.80e-7
  RMSE: 6.30e-7
test deform_conv2d_basic_gpu ... ok
test deform_conv2d_with_mask_gpu ... ok

=== deform_conv2d_with_bias (Cuda(CudaDevice(DeviceId(7)))) ===
  Elements: 32
  Max absolute error: 1.91e-6
  Mean absolute error: 7.38e-7
  RMSE: 9.01e-7
test deform_conv2d_with_bias_gpu ... ok
test deform_conv2d_with_dilation_gpu ... ok
test deform_conv2d_full_gpu ... ok
test deform_conv2d_offset_groups_gpu ... ok
test deform_conv2d_with_stride_gpu ... ok
test deform_conv2d_batch_gpu ... ok

test result: ok. 16 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.09s

Usage Example

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    
    // Input: [batch=1, channels=64, height=32, width=32]
    let input = Tensor::randn(0f32, 1.0, (1, 64, 32, 32), &device)?;
    
    // Weight: [out_channels=64, in_channels=64, kH=3, kW=3]
    let weight = Tensor::randn(0f32, 0.1, (64, 64, 3, 3), &device)?;
    
    // Offset: [batch=1, 2*offset_groups*kH*kW=18, out_h=32, out_w=32]
    let offset = Tensor::randn(0f32, 0.5, (1, 18, 32, 32), &device)?;
    
    // Optional mask for DCNv2: [batch=1, offset_groups*kH*kW=9, out_h=32, out_w=32]
    let mask = Tensor::rand(0f32, 1.0, (1, 9, 32, 32), &device)?;
    
    // Optional bias: [out_channels=64]
    let bias = Tensor::zeros((64,), candle_core::DType::F32, &device)?;
    
    let output = input.deform_conv2d(
        &offset,
        &weight,
        Some(&mask),    // DCNv2 modulation mask
        Some(&bias),    // bias
        (1, 1),         // stride
        (1, 1),         // padding
        (1, 1),         // dilation
        1,              // groups
        1,              // offset_groups
    )?;
    
    println!("Output shape: {:?}", output.dims()); // [1, 64, 32, 32]
    Ok(())
}

Limitations

  • Backward pass is not implemented (forward-only)
  • Groups > 1 uses sequential processing (not optimized for large group counts)

Performance Benchmarks

Benchmarks were run with cargo bench --bench bench_main -- deform_conv2d.
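
For context, a benchmark along these lines can be written with the criterion harness; the sketch below is illustrative only (it is not the actual benches/benchmarks/deform_conv2d.rs) and uses the CPU configuration listed in the table that follows:

use candle_core::{Device, Tensor};
use criterion::{criterion_group, criterion_main, Criterion};

fn bench_deform_conv2d(c: &mut Criterion) {
    let device = Device::Cpu;
    // [1, 64, 32, 32] input, [64, 64, 3, 3] weight, with mask (the CPU benchmark configuration).
    let input = Tensor::randn(0f32, 1.0, (1, 64, 32, 32), &device).unwrap();
    let weight = Tensor::randn(0f32, 0.1, (64, 64, 3, 3), &device).unwrap();
    let offset = Tensor::randn(0f32, 0.5, (1, 18, 32, 32), &device).unwrap();
    let mask = Tensor::rand(0f32, 1.0, (1, 9, 32, 32), &device).unwrap();

    c.bench_function("deform_conv2d_cpu_f32", |b| {
        b.iter(|| {
            input
                .deform_conv2d(&offset, &weight, Some(&mask), None, (1, 1), (1, 1), (1, 1), 1, 1)
                .unwrap()
        })
    });
}

criterion_group!(benches, bench_deform_conv2d);
criterion_main!(benches);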

CPU Performance

Config: [1, 64, 32, 32] input, [64, 64, 3, 3] weight, with mask

Platform   CPU               Data Type   Time       Throughput
macOS      Apple M4 Pro      F32         34.91 ms   7.16 MiB/s
Linux      Intel i9-13900K   F32         52.07 ms   4.80 MiB/s

Metal Performance (Apple M4 Pro)

Config: [1, 256, 64, 64] input, [256, 256, 3, 3] weight, with mask

Data Type   Time      Throughput
F32         1.14 ms   3.43 GiB/s
F16         1.04 ms   1.88 GiB/s
BF16        1.17 ms   1.66 GiB/s

CUDA Performance (NVIDIA RTX A6000)

Config: [1, 256, 64, 64] input, [256, 256, 3, 3] weight, with mask

Data Type   Time        Throughput
F32         414.65 µs   9.42 GiB/s
F16         197.44 µs   9.89 GiB/s
BF16        200.57 µs   9.74 GiB/s

Note: the CPU benchmark uses a smaller tensor size ([1, 64, 32, 32]) because CPU execution is much slower; the GPU benchmarks (Metal/CUDA) use the larger [1, 256, 64, 64] configuration for a more realistic workload.

Checklist

  • CPU implementation
  • Metal implementation
  • CUDA implementation
  • F32/F64 support
  • F16/BF16 support (all backends)
  • DCNv2 mask support
  • Bias support
  • Stride/padding/dilation support
  • Groups support
  • Offset groups support
  • Comprehensive test suite
  • Numerical consistency with PyTorch
