diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py
index 89922b062..49eb9cc7d 100644
--- a/python/sglang/srt/layers/attention/attention_registry.py
+++ b/python/sglang/srt/layers/attention/attention_registry.py
@@ -99,7 +99,6 @@ def create_triton_backend(runner):
     return TritonAttnBackend(runner)
 
 
-
 @register_attention_backend("torch_native")
 def create_torch_native_backend(runner):
     from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
     return FlashMLABackend(runner)
 
 
+@register_attention_backend("dcu_mla")
+def create_dcu_mla_backend(runner):
+    from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend
+
+    return DCUMLABackend(runner)
 @register_attention_backend("fa3")
 def create_flashattention_v3_backend(runner):
diff --git a/python/sglang/srt/layers/attention/dcu_mla_backend.py b/python/sglang/srt/layers/attention/dcu_mla_backend.py
new file mode 100644
index 000000000..cb1626d90
--- /dev/null
+++ b/python/sglang/srt/layers/attention/dcu_mla_backend.py
@@ -0,0 +1,484 @@
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+import torch
+import triton
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+try:
+    from flash_mla import (
+        flash_mla_with_kvcache,
+        flash_mla_with_kvcache_quantization,
+        get_mla_metadata
+    )
+    _has_flash_mla = True
+except Exception:
+    try:
+        from vllm.attention.ops.flashmla import (
+            flash_mla_with_kvcache,
+            get_mla_metadata
+        )
+        _has_flash_mla = False
+    except Exception:
+        raise ImportError(
+            "Cannot import FlashMLA. Please install one of the following to use the dcu_mla backend:\n"
+            "    pip install flash-mla\n"
+            "  or\n"
+            "    pip install vllm"
+        )
+
+PAGE_SIZE = 64  # fixed at 64
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInput
+
+@dataclass
+class VllmMLADecodeMetadata:
+    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    num_splits: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+class DCUMLABackend(AttentionBackend):
+
+    def __init__(
+        self,
+        model_runner: "ModelRunner",
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__()
+
+        if model_runner.server_args.page_size != PAGE_SIZE:
+            raise ValueError(
+                f"dcu_mla backend requires page_size={PAGE_SIZE}, "
+                f"but got {model_runner.server_args.page_size}"
+            )
+
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+
+        self.device = model_runner.device
+        self.max_context_len = model_runner.model_config.context_len
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
+        self.forward_metadata: Optional[VllmMLADecodeMetadata] = None
+
+        self.skip_prefill = skip_prefill
+        if not skip_prefill:
+            # Initially used the triton backend; may be replaced later
+            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+            # self.triton_backend = TritonAttnBackend(
+            #     model_runner,
+            #     skip_prefill=False,
+            #     kv_indptr_buf=kv_indptr_buf,
+            # )
+            # prefill now uses flash attn
+            from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+            self.flashattn_backend = FlashAttentionBackend(
+                model_runner,
+                skip_prefill=False,
+            )
+
+    def _build_decode_metadata(
+        self,
+        forward_batch: ForwardBatch,
+        seq_lens: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+
+        bs = forward_batch.batch_size
+        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+        # paging scheme follows the vLLM reference implementation (official blog)
+        block_kv_indices = torch.full(
+            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
+        )
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            forward_batch.req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+        )
+
+        mla_metadata, num_splits = get_mla_metadata(
+            seq_lens.to(torch.int32), self.num_q_heads, 1
+        )
+        return (mla_metadata, num_splits), num_splits, block_kv_indices
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # decode uses flashmla
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata, num_splits_t, block_kv_indices
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, seq_lens)
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata, num_splits_t, block_kv_indices
+            )
+        else:
+            # prefill/extend used the triton backend -> switched to flash attn
+            if not self.skip_prefill:
+                # self.triton_backend.init_forward_metadata(forward_batch)
+                self.flashattn_backend.init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        if self.num_draft_tokens:
+            mla_metadata, num_splits = get_mla_metadata(
+                torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
+                self.num_draft_tokens * self.num_q_heads,
+                1,
+            )
+        else:
+            mla_metadata, num_splits = get_mla_metadata(
+                torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
+                self.num_q_heads,
+                1,
+            )
+
+        self.cuda_graph_mla_metadata = mla_metadata
+        self.cuda_graph_num_splits = num_splits
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional["SpecInput"],
+    ):
+        if forward_mode.is_decode_or_idle():
+            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32), num_q_heads, 1
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata = VllmMLADecodeMetadata(
+                self.cuda_graph_mla_metadata,
+                self.cuda_graph_num_splits[: bs + 1],
+                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+            )
+        elif forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata = VllmMLADecodeMetadata(
+                self.cuda_graph_mla_metadata,
+                self.cuda_graph_num_splits[: bs + 1],
+                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+            )
+        else:
+            if not self.skip_prefill:
+                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
+                #     bs,
+                #     num_tokens,
+                #     req_pool_indices,
+                #     seq_lens,
+                #     encoder_lens,
+                #     forward_mode,
+                #     spec_info,
+                # )
+                self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
+                    bs,
+                    num_tokens,
+                    req_pool_indices,
+                    seq_lens,
+                    encoder_lens,
+                    forward_mode,
+                    spec_info,
+                )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional["SpecInput"],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
+            seq_lens = seq_lens[:bs]
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32), num_q_heads, 1
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
+            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+        elif forward_mode.is_target_verify():
+            seq_lens = seq_lens[:bs] + self.num_draft_tokens
+            seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
+            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+        else:
+            if not self.skip_prefill:
+                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
+                #     bs,
+                #     req_pool_indices,
+                #     seq_lens,
+                #     seq_lens_sum,
+                #     encoder_lens,
+                #     forward_mode,
+                #     spec_info,
+                #     seq_lens_cpu,
+                # )
+                self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
+                    bs,
+                    req_pool_indices,
+                    seq_lens,
+                    seq_lens_sum,
+                    encoder_lens,
+                    forward_mode,
+                    spec_info,
+                    seq_lens_cpu,
+                )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def _call_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
+                     block_table: torch.Tensor, cache_seqlens: torch.Tensor,
+                     scaling: float):
+        o, _ = flash_mla_with_kvcache(
+            q=reshape_q,
+            k_cache=k_cache_reshaped,
+            block_table=block_table,
+            cache_seqlens=cache_seqlens,
+            head_dim_v=self.kv_lora_rank,
+            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+            num_splits=self.forward_metadata.num_splits,
+            softmax_scale=scaling,
+            causal=True,
+        )
+        return o
+
+    def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
+                         block_table: torch.Tensor, cache_seqlens: torch.Tensor,
+                         scaling: float):
+        assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
+        o, _ = flash_mla_with_kvcache_quantization(
+            q=reshape_q,
+            k_cache=k_cache_reshaped,
+            block_table=block_table,
+            cache_seqlens=cache_seqlens,
+            head_dim_v=self.kv_lora_rank,
+            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+            num_splits=self.forward_metadata.num_splits,
+            softmax_scale=scaling,
+            causal=True,
+            is_fp8_kvcache=True,
+        )
+        return o
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: "RadixAttention",
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+
+        if self.data_type in (
+            getattr(torch, "float8_e4m3fn", None),
+            getattr(torch, "float8_e4m3fnuz", None),
+            getattr(torch, "float8_e5m2", None),
+            getattr(torch, "float8_e5m2fnuz", None),
+        ):
+            o = self._call_fp8_decode(
+                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
+                forward_batch.seq_lens.to(torch.int32), layer.scaling,
+            )
+        else:
+            o = self._call_decode(
+                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
+                forward_batch.seq_lens.to(torch.int32), layer.scaling,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: "RadixAttention",
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        sinks=None,
+    ):
+        if (
+            forward_batch.forward_mode == ForwardMode.EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+        ):
+            # flash_attn does not support fp8, so extend cannot run correctly with an fp8 KV cache
+            if not self.skip_prefill:
+                # return self.triton_backend.forward_extend(
+                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
+                # )
+                return self.flashattn_backend.forward_extend(
+                    q, k, v, layer, forward_batch, save_kv_cache, sinks
+                )
+            else:
+                raise RuntimeError("skip_prefill is set, but forward_extend was called")
+
+        cache_loc = forward_batch.out_cache_loc
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+
+        if self.data_type in (
+            getattr(torch, "float8_e4m3fn", None),
+            getattr(torch, "float8_e4m3fnuz", None),
+            getattr(torch, "float8_e5m2", None),
+            getattr(torch, "float8_e5m2fnuz", None),
+        ):
+            o = self._call_fp8_decode(
+                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
+                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                layer.scaling,
+            )
+        else:
+            o = self._call_decode(
+                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
+                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                layer.scaling,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+
diff --git a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
index 775e03bb2..72ab94ad4 100644
--- a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from sgl_kernel.sparse_flash_attn import (
     convert_vertical_slash_indexes,
     convert_vertical_slash_indexes_mergehead,
diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 8d8f789d0..f71eccee4 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 
 
 @dataclass
diff --git a/python/sglang/srt/layers/attention/flashattention_interface.py b/python/sglang/srt/layers/attention/flashattention_interface.py
new file mode 100644
index 000000000..b0c8b7b05
--- /dev/null
+++ b/python/sglang/srt/layers/attention/flashattention_interface.py
@@ -0,0 +1,94 @@
+from flash_attn import (
+    flash_attn_varlen_func as flash_attn_varlen_func_interface,
+    flash_attn_with_kvcache as flash_attn_with_kvcache_interface
+)
+from typing import Optional, Union
+
+import torch
+
+def flash_attn_with_kvcache(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    attention_chunk: Optional[int] = None,
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+    sinks=None,
+    ver=3,
+):
+    return flash_attn_with_kvcache_interface(
+        q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
+        k_cache=k_cache,
+        v_cache=v_cache,
+        block_table=page_table,
+        cache_seqlens=cache_seqlens,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size=window_size,
+        softcap=softcap,
+        return_softmax_lse=return_softmax_lse,
+        num_splits=num_splits,
+    )
+
+def flash_attn_varlen_func(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q=None,
+    max_seqlen_k=None,
+    seqused_q=None,
+    seqused_k=None,
+    page_table=None,
+    softmax_scale=None,
+    causal=False,
+    qv=None,
+    q_descale=None,
+    k_descale=None,
+    v_descale=None,
+    window_size=(-1, -1),
+    attention_chunk=0,
+    softcap=0.0,
+    num_splits=1,
+    pack_gqa=None,
+    sm_margin=0,
+    return_softmax_lse=False,
+    sinks=None,
+    ver=3,
+):
+    return flash_attn_varlen_func_interface(
+        q=q,
+        k=k,
+        v=v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_q,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_q,
+        softmax_scale=softmax_scale,
+        causal=causal,
+    )
\ No newline at end of file
diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index 7da15cc47..03dc4728b 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -45,7 +45,8 @@ if _is_hip:
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )
 else:
-    from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    # from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache
 
 
 @dataclass(frozen=True)
diff --git a/python/sglang/srt/layers/attention/xpu_backend.py b/python/sglang/srt/layers/attention/xpu_backend.py
index 5ab4a160c..e5cea93cb 100644
--- a/python/sglang/srt/layers/attention/xpu_backend.py
+++ b/python/sglang/srt/layers/attention/xpu_backend.py
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 
 
 class XPUAttentionBackend(AttentionBackend):
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index ce2a9c4c5..af4c4434e 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -165,6 +165,7 @@ MLA_ATTENTION_BACKENDS = [
     "triton",
     "flashmla",
     "cutlass_mla",
+    "dcu_mla",
     "trtllm_mla",
     "ascend",
     "nsa",
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 891dd51e8..fce685ea0 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -342,6 +342,10 @@ def handle_attention_flashmla(attn, forward_batch):
     return _handle_attention_backend(attn, forward_batch, "flashmla")
 
 
+def handle_attention_dcu_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "dcu_mla")
+
+
 def handle_attention_cutlass_mla(attn, forward_batch):
     return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
 
@@ -3577,6 +3581,7 @@ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
 AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
 AttentionBackendRegistry.register("fa3", handle_attention_fa3)
 AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("dcu_mla", handle_attention_dcu_mla)
 AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
 AttentionBackendRegistry.register("fa4", handle_attention_fa4)
 AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7d26160a6..f1054eb1c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -102,6 +102,8 @@ ATTENTION_BACKEND_CHOICES = [
     "torch_native",
     "flex_attention",
     "nsa",
+    # transplanted from vllm
+    "dcu_mla",
     # NVIDIA specific
     "cutlass_mla",
     "fa3",
@@ -1077,9 +1079,11 @@ class ServerArgs:
         if (
             self.attention_backend == "flashmla"
             or self.decode_attention_backend == "flashmla"
+            or self.attention_backend == "dcu_mla"
+            or self.decode_attention_backend == "dcu_mla"
         ):
             logger.warning(
-                "FlashMLA only supports a page_size of 64, change page_size to 64."
+                "FlashMLA/DCU MLA only supports a page_size of 64; changing page_size to 64."
             )
             self.page_size = 64
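
Note (not part of the patch): the decode path in dcu_mla_backend.py sizes its per-request block table by rounding the longest sequence up to whole KV pages of PAGE_SIZE = 64 via triton.cdiv, then fills it with -1 before create_flashmla_kv_indices_triton writes the real page indices. A minimal pure-Python sketch of that sizing logic, using a hypothetical helper name (padded_num_pages) purely for illustration:

```python
# Minimal sketch (illustration only, not part of the patch): how the dcu_mla
# decode metadata sizes the block_kv_indices table. PAGE_SIZE is fixed at 64,
# which is why server_args forces page_size=64 for this backend.
import math

PAGE_SIZE = 64

def padded_num_pages(seq_lens):
    # Same rounding as triton.cdiv(seq_lens.max().item(), PAGE_SIZE) in
    # _build_decode_metadata: the longest sequence, rounded up to whole pages.
    return math.ceil(max(seq_lens) / PAGE_SIZE)

# A batch with lengths 10, 64 and 130 needs 3 pages per row; the (bs, 3)
# block table is pre-filled with -1 before the Triton kernel writes the
# actual page indices for each request.
print(padded_num_pages([10, 64, 130]))  # -> 3
```

With the patch applied, the backend should be selectable through the existing --attention-backend dcu_mla server option (the value is added to ATTENTION_BACKEND_CHOICES above), and page_size is forced to 64 exactly as for flashmla.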