Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions src/transformers/models/bamba/modular_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,47 @@
can_return_tuple,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from ...utils.import_utils import (
is_causal_conv1d_available,
is_einops_available,
is_kernels_available,
is_mamba_2_ssm_available,
)
from .configuration_bamba import BambaConfig


if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
selective_state_update = None
selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None
causal_conv1d_update, causal_conv1d_fn = None, None


def _lazy_load_mamba2_ssm():
global selective_state_update

if is_kernels_available() and is_einops_available():
from kernels import get_kernel

mamba_ssm = get_kernel("kernels-community/mamba-ssm")

selective_state_update = mamba_ssm.ops.triton.selective_state_update.selective_state_update
mamba_chunk_scan_combined = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined
mamba_split_conv1d_scan_combined = mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined

elif is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None

def _lazy_load_causal_conv1d():
global causal_conv1d_update, causal_conv1d_fn

if is_kernels_available() and is_einops_available():
from kernels import get_kernel

causal_conv1d = get_kernel("kernels-community/causal-conv1d")
causal_conv1d_fn = causal_conv1d.causal_conv1d_fn
causal_conv1d_update = causal_conv1d.causal_conv1d_update

elif is_causal_conv1d_available:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

Expand Down Expand Up @@ -285,11 +312,17 @@ def __init__(self, config: BambaConfig, layer_idx: int):

self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

_lazy_load_causal_conv1d()
_lazy_load_mamba2_ssm()

global is_fast_path_available
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
" https://github.com/Dao-AILab/causal-conv1d or install the kernels library using `pip install kernels`"
)
else:
logger.warning_once("The fast path for Bamba will be used when running the model on a GPU")
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,10 @@ def is_matplotlib_available() -> bool:
def is_mistral_common_available() -> bool:
return _is_package_available("mistral_common")

@lru_cache
def is_einops_available() -> bool:
return _is_package_available("einops")


def check_torch_load_is_safe() -> None:
if not is_torch_greater_or_equal("2.6"):
Expand Down