Compare commits
8 Commits
v0.5.4_dev
...
v0.5.4_dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
46da95569f | ||
|
|
ee77577211 | ||
|
|
a9e0e668c4 | ||
|
|
0c5532b0c1 | ||
|
|
785e5e900b | ||
|
|
47e4d92348 | ||
|
|
0fbecc4364 | ||
|
|
477fddf28d |
@@ -99,7 +99,6 @@ def create_triton_backend(runner):
|
|||||||
|
|
||||||
return TritonAttnBackend(runner)
|
return TritonAttnBackend(runner)
|
||||||
|
|
||||||
|
|
||||||
@register_attention_backend("torch_native")
|
@register_attention_backend("torch_native")
|
||||||
def create_torch_native_backend(runner):
|
def create_torch_native_backend(runner):
|
||||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||||
@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
|
|||||||
|
|
||||||
return FlashMLABackend(runner)
|
return FlashMLABackend(runner)
|
||||||
|
|
||||||
|
@register_attention_backend("dcu_mla")
|
||||||
|
def create_dcu_mla_backend(runner):
|
||||||
|
from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend
|
||||||
|
|
||||||
|
return DCUMLABackend(runner)
|
||||||
|
|
||||||
@register_attention_backend("fa3")
|
@register_attention_backend("fa3")
|
||||||
def create_flashattention_v3_backend(runner):
|
def create_flashattention_v3_backend(runner):
|
||||||
|
|||||||
484
python/sglang/srt/layers/attention/dcu_mla_backend.py
Normal file
484
python/sglang/srt/layers/attention/dcu_mla_backend.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_mla import (
|
||||||
|
flash_mla_with_kvcache,
|
||||||
|
flash_mla_with_kvcache_quantization,
|
||||||
|
get_mla_metadata
|
||||||
|
)
|
||||||
|
_has_flash_mla = True
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
from vllm.attention.ops.flashmla import (
|
||||||
|
flash_mla_with_kvcache,
|
||||||
|
get_mla_metadata
|
||||||
|
)
|
||||||
|
_has_flash_mla = False
|
||||||
|
except Exception:
|
||||||
|
raise ImportError(
|
||||||
|
"Can not import FlashMLA。Please perform the following operations to use flashmla:\n"
|
||||||
|
" pip install flash-mla\n"
|
||||||
|
" or\n"
|
||||||
|
" pip install vllm"
|
||||||
|
)
|
||||||
|
|
||||||
|
PAGE_SIZE = 64 # 强制64
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInput
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VllmMLADecodeMetadata:
|
||||||
|
flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||||
|
num_splits: Optional[torch.Tensor] = None
|
||||||
|
block_kv_indices: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
class DCUMLABackend(AttentionBackend):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_runner: "ModelRunner",
|
||||||
|
skip_prefill: bool = False,
|
||||||
|
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||||
|
kv_last_page_len_buf: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if model_runner.server_args.page_size != PAGE_SIZE:
|
||||||
|
raise ValueError(
|
||||||
|
f"dcu_mla backend requires page_size={PAGE_SIZE}, "
|
||||||
|
f"but got the {model_runner.server_args.page_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_q_heads = (
|
||||||
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
|
)
|
||||||
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
|
||||||
|
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
|
||||||
|
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
|
||||||
|
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
|
||||||
|
self.v_head_dim = model_runner.model_config.v_head_dim
|
||||||
|
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
|
||||||
|
|
||||||
|
self.data_type = model_runner.kv_cache_dtype
|
||||||
|
self.q_data_type = model_runner.dtype
|
||||||
|
|
||||||
|
self.device = model_runner.device
|
||||||
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||||
|
|
||||||
|
self.forward_metadata: Union[VllmMLADecodeMetadata] = None
|
||||||
|
|
||||||
|
self.skip_prefill = skip_prefill
|
||||||
|
if not skip_prefill:
|
||||||
|
# 先用triton backend,后面考虑替换
|
||||||
|
# from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||||
|
# self.triton_backend = TritonAttnBackend(
|
||||||
|
# model_runner,
|
||||||
|
# skip_prefill=False,
|
||||||
|
# kv_indptr_buf=kv_indptr_buf,
|
||||||
|
# )
|
||||||
|
# prefill改用flash attn
|
||||||
|
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
|
||||||
|
self.flashattn_backend = FlashAttentionBackend(
|
||||||
|
model_runner,
|
||||||
|
skip_prefill=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_decode_metadata(
|
||||||
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
|
||||||
|
|
||||||
|
# 参考vllm官方博客分页
|
||||||
|
block_kv_indices = torch.full(
|
||||||
|
(bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
|
||||||
|
)
|
||||||
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
None,
|
||||||
|
block_kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
max_seqlen_pad,
|
||||||
|
)
|
||||||
|
|
||||||
|
mla_metadata, num_splits = get_mla_metadata(
|
||||||
|
seq_lens.to(torch.int32), self.num_q_heads, 1
|
||||||
|
)
|
||||||
|
return (mla_metadata, num_splits), num_splits, block_kv_indices
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
|
||||||
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
|
# decode用flashmla
|
||||||
|
(mla_metadata, num_splits), num_splits_t, block_kv_indices = (
|
||||||
|
self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
|
||||||
|
)
|
||||||
|
self.forward_metadata = VllmMLADecodeMetadata(
|
||||||
|
mla_metadata, num_splits_t, block_kv_indices
|
||||||
|
)
|
||||||
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
|
seq_lens = forward_batch.seq_lens + self.num_draft_tokens
|
||||||
|
(mla_metadata, num_splits), num_splits_t, block_kv_indices = (
|
||||||
|
self._build_decode_metadata(forward_batch, seq_lens)
|
||||||
|
)
|
||||||
|
self.forward_metadata = VllmMLADecodeMetadata(
|
||||||
|
mla_metadata, num_splits_t, block_kv_indices
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# prefill/extend用triton backend -> 改用flash attn
|
||||||
|
if not self.skip_prefill:
|
||||||
|
# self.triton_backend.init_forward_metadata(forward_batch)
|
||||||
|
self.flashattn_backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
def init_cuda_graph_state(
|
||||||
|
self,
|
||||||
|
max_bs: int,
|
||||||
|
max_num_tokens: int,
|
||||||
|
block_kv_indices: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
if block_kv_indices is None:
|
||||||
|
cuda_graph_kv_indices = torch.full(
|
||||||
|
(max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
|
||||||
|
1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cuda_graph_kv_indices = block_kv_indices
|
||||||
|
|
||||||
|
if self.num_draft_tokens:
|
||||||
|
mla_metadata, num_splits = get_mla_metadata(
|
||||||
|
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
|
||||||
|
self.num_draft_tokens * self.num_q_heads,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mla_metadata, num_splits = get_mla_metadata(
|
||||||
|
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
|
||||||
|
self.num_q_heads,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cuda_graph_mla_metadata = mla_metadata
|
||||||
|
self.cuda_graph_num_splits = num_splits
|
||||||
|
self.cuda_graph_kv_indices = cuda_graph_kv_indices
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
|
self,
|
||||||
|
bs: int,
|
||||||
|
num_tokens: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional["SpecInput"],
|
||||||
|
):
|
||||||
|
if forward_mode.is_decode_or_idle():
|
||||||
|
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
|
||||||
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
None,
|
||||||
|
self.cuda_graph_kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
|
)
|
||||||
|
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
|
||||||
|
mla_metadata, num_splits = get_mla_metadata(
|
||||||
|
seq_lens.to(torch.int32), num_q_heads, 1
|
||||||
|
)
|
||||||
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||||
|
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||||
|
self.forward_metadata = VllmMLADecodeMetadata(
|
||||||
|
self.cuda_graph_mla_metadata,
|
||||||
|
self.cuda_graph_num_splits[: bs + 1],
|
||||||
|
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
|
||||||
|
)
|
||||||
|
elif forward_mode.is_target_verify():
|
||||||
|
seq_lens = seq_lens + self.num_draft_tokens
|
||||||
|
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
|
||||||
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
None,
|
||||||
|
self.cuda_graph_kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
|
)
|
||||||
|
mla_metadata, num_splits = get_mla_metadata(
|
||||||
|
seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
|
||||||
|
)
|
||||||
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||||
|
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||||
|
self.forward_metadata = VllmMLADecodeMetadata(
|
||||||
|
self.cuda_graph_mla_metadata,
|
||||||
|
self.cuda_graph_num_splits[: bs + 1],
|
||||||
|
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not self.skip_prefill:
|
||||||
|
# self.triton_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
|
# bs,
|
||||||
|
# num_tokens,
|
||||||
|
# req_pool_indices,
|
||||||
|
# seq_lens,
|
||||||
|
# encoder_lens,
|
||||||
|
# forward_mode,
|
||||||
|
# spec_info,
|
||||||
|
# )
|
||||||
|
self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
|
bs,
|
||||||
|
num_tokens,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
encoder_lens,
|
||||||
|
forward_mode,
|
||||||
|
spec_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional["SpecInput"],
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
if forward_mode.is_decode_or_idle():
|
||||||
|
assert seq_lens_cpu is not None
|
||||||
|
seq_lens = seq_lens[:bs]
|
||||||
|
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||||
|
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
|
||||||
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices[:bs],
|
||||||
|
seq_lens,
|
||||||
|
None,
|
||||||
|
self.cuda_graph_kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
|
)
|
||||||
|
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
|
||||||
|
mla_metadata, num_splits = get_mla_metadata(
|
||||||
|
seq_lens.to(torch.int32), num_q_heads, 1
|
||||||
|
)
|
||||||
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||||
|
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||||
|
self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
|
||||||
|
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
|
||||||
|
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
|
||||||
|
:bs, :max_seqlen_pad
|
||||||
|
]
|
||||||
|
elif forward_mode.is_target_verify():
|
||||||
|
seq_lens = seq_lens[:bs] + self.num_draft_tokens
|
||||||
|
seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
|
||||||
|
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
|
||||||
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices[:bs],
|
||||||
|
seq_lens,
|
||||||
|
None,
|
||||||
|
self.cuda_graph_kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
|
)
|
||||||
|
mla_metadata, num_splits = get_mla_metadata(
|
||||||
|
seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
|
||||||
|
)
|
||||||
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||||
|
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||||
|
self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
|
||||||
|
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
|
||||||
|
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
|
||||||
|
:bs, :max_seqlen_pad
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
if not self.skip_prefill:
|
||||||
|
# self.triton_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
|
# bs,
|
||||||
|
# req_pool_indices,
|
||||||
|
# seq_lens,
|
||||||
|
# seq_lens_sum,
|
||||||
|
# encoder_lens,
|
||||||
|
# forward_mode,
|
||||||
|
# spec_info,
|
||||||
|
# seq_lens_cpu,
|
||||||
|
# )
|
||||||
|
self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
|
bs,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
encoder_lens,
|
||||||
|
forward_mode,
|
||||||
|
spec_info,
|
||||||
|
seq_lens_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def _call_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
|
||||||
|
block_table: torch.Tensor, cache_seqlens: torch.Tensor,
|
||||||
|
scaling: float):
|
||||||
|
o, _ = flash_mla_with_kvcache(
|
||||||
|
q=reshape_q,
|
||||||
|
k_cache=k_cache_reshaped,
|
||||||
|
block_table=block_table,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
head_dim_v=self.kv_lora_rank,
|
||||||
|
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
|
||||||
|
num_splits=self.forward_metadata.num_splits,
|
||||||
|
softmax_scale=scaling,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
|
||||||
|
block_table: torch.Tensor, cache_seqlens: torch.Tensor,
|
||||||
|
scaling: float):
|
||||||
|
assert _has_flash_mla, "FP8 KV cache 需要flash_mla包"
|
||||||
|
o, _ = flash_mla_with_kvcache_quantization(
|
||||||
|
q=reshape_q,
|
||||||
|
k_cache=k_cache_reshaped,
|
||||||
|
block_table=block_table,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
head_dim_v=self.kv_lora_rank,
|
||||||
|
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
|
||||||
|
num_splits=self.forward_metadata.num_splits,
|
||||||
|
softmax_scale=scaling,
|
||||||
|
causal=True,
|
||||||
|
is_fp8_kvcache=True,
|
||||||
|
)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def forward_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: "RadixAttention",
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache: bool = True,
|
||||||
|
):
|
||||||
|
cache_loc = forward_batch.out_cache_loc
|
||||||
|
|
||||||
|
if k is not None:
|
||||||
|
assert v is not None
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
|
||||||
|
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
|
||||||
|
|
||||||
|
if self.data_type in (
|
||||||
|
getattr(torch, "float8_e4m3fn", None),
|
||||||
|
getattr(torch, "float8_e4m3fnuz", None),
|
||||||
|
getattr(torch, "float8_e5m2", None),
|
||||||
|
getattr(torch, "float8_e5m2fnuz", None),
|
||||||
|
):
|
||||||
|
o = self._call_fp8_decode(
|
||||||
|
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
|
||||||
|
forward_batch.seq_lens.to(torch.int32), layer.scaling,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o = self._call_decode(
|
||||||
|
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
|
||||||
|
forward_batch.seq_lens.to(torch.int32), layer.scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
|
def forward_extend(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: "RadixAttention",
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache: bool = True,
|
||||||
|
sinks=None,
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
forward_batch.forward_mode == ForwardMode.EXTEND
|
||||||
|
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
|
||||||
|
):
|
||||||
|
# flash_attn不支持fp8,fp8无法正常执行extend
|
||||||
|
if not self.skip_prefill:
|
||||||
|
# return self.triton_backend.forward_extend(
|
||||||
|
# q, k, v, layer, forward_batch, save_kv_cache, sinks
|
||||||
|
# )
|
||||||
|
return self.flashattn_backend.forward_extend(
|
||||||
|
q, k, v, layer, forward_batch, save_kv_cache, sinks
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("skip prefill but use forward_extend")
|
||||||
|
|
||||||
|
cache_loc = forward_batch.out_cache_loc
|
||||||
|
if k is not None:
|
||||||
|
assert v is not None
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||||
|
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
|
||||||
|
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
|
||||||
|
|
||||||
|
if self.data_type in (
|
||||||
|
getattr(torch, "float8_e4m3fn", None),
|
||||||
|
getattr(torch, "float8_e4m3fnuz", None),
|
||||||
|
getattr(torch, "float8_e5m2", None),
|
||||||
|
getattr(torch, "float8_e5m2fnuz", None),
|
||||||
|
):
|
||||||
|
o = self._call_fp8_decode(
|
||||||
|
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
|
||||||
|
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
|
||||||
|
layer.scaling,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o = self._call_decode(
|
||||||
|
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
|
||||||
|
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
|
||||||
|
layer.scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
|
|
||||||
@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
from sgl_kernel.sparse_flash_attn import (
|
from sgl_kernel.sparse_flash_attn import (
|
||||||
convert_vertical_slash_indexes,
|
convert_vertical_slash_indexes,
|
||||||
convert_vertical_slash_indexes_mergehead,
|
convert_vertical_slash_indexes_mergehead,
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
from sgl_kernel import merge_state_v2
|
from sgl_kernel import merge_state_v2
|
||||||
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -0,0 +1,94 @@
|
|||||||
|
from flash_attn import (
|
||||||
|
flash_attn_varlen_func as flash_attn_varlen_func_interface,
|
||||||
|
flash_attn_with_kvcache as flash_attn_with_kvcache_interface
|
||||||
|
)
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def flash_attn_with_kvcache(
|
||||||
|
q,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
k=None,
|
||||||
|
v=None,
|
||||||
|
qv=None,
|
||||||
|
rotary_cos=None,
|
||||||
|
rotary_sin=None,
|
||||||
|
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
|
||||||
|
cache_batch_idx: Optional[torch.Tensor] = None,
|
||||||
|
cache_leftpad: Optional[torch.Tensor] = None,
|
||||||
|
page_table: Optional[torch.Tensor] = None,
|
||||||
|
cu_seqlens_q: Optional[torch.Tensor] = None,
|
||||||
|
cu_seqlens_k_new: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen_q: Optional[int] = None,
|
||||||
|
rotary_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
q_descale: Optional[torch.Tensor] = None,
|
||||||
|
k_descale: Optional[torch.Tensor] = None,
|
||||||
|
v_descale: Optional[torch.Tensor] = None,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=False,
|
||||||
|
window_size=(-1, -1), # -1 means infinite context window
|
||||||
|
attention_chunk: Optional[int] = None,
|
||||||
|
softcap=0.0, # 0.0 means deactivated
|
||||||
|
rotary_interleaved=True,
|
||||||
|
scheduler_metadata=None,
|
||||||
|
num_splits=0, # Can be tuned for speed
|
||||||
|
pack_gqa=None, # Can be tuned for speed
|
||||||
|
sm_margin=0, # Can be tuned if some SMs are used for communication
|
||||||
|
return_softmax_lse=False,
|
||||||
|
sinks=None,
|
||||||
|
ver=3,
|
||||||
|
):
|
||||||
|
return flash_attn_with_kvcache_interface(
|
||||||
|
q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
|
||||||
|
k_cache=k_cache,
|
||||||
|
v_cache=v_cache,
|
||||||
|
block_table=page_table,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
return_softmax_lse=return_softmax_lse,
|
||||||
|
num_splits=num_splits,
|
||||||
|
)
|
||||||
|
|
||||||
|
def flash_attn_varlen_func(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q=None,
|
||||||
|
max_seqlen_k=None,
|
||||||
|
seqused_q=None,
|
||||||
|
seqused_k=None,
|
||||||
|
page_table=None,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=False,
|
||||||
|
qv=None,
|
||||||
|
q_descale=None,
|
||||||
|
k_descale=None,
|
||||||
|
v_descale=None,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
attention_chunk=0,
|
||||||
|
softcap=0.0,
|
||||||
|
num_splits=1,
|
||||||
|
pack_gqa=None,
|
||||||
|
sm_margin=0,
|
||||||
|
return_softmax_lse=False,
|
||||||
|
sinks=None,
|
||||||
|
ver=3,
|
||||||
|
):
|
||||||
|
return flash_attn_varlen_func_interface(
|
||||||
|
q=q,
|
||||||
|
k=k,
|
||||||
|
v=v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_q,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_q,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
)
|
||||||
@@ -45,7 +45,8 @@ if _is_hip:
|
|||||||
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
# from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
||||||
|
from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
from sgl_kernel import merge_state_v2
|
from sgl_kernel import merge_state_v2
|
||||||
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
|
||||||
|
|
||||||
class XPUAttentionBackend(AttentionBackend):
|
class XPUAttentionBackend(AttentionBackend):
|
||||||
|
|||||||
@@ -165,6 +165,7 @@ MLA_ATTENTION_BACKENDS = [
|
|||||||
"triton",
|
"triton",
|
||||||
"flashmla",
|
"flashmla",
|
||||||
"cutlass_mla",
|
"cutlass_mla",
|
||||||
|
"dcu_mla",
|
||||||
"trtllm_mla",
|
"trtllm_mla",
|
||||||
"ascend",
|
"ascend",
|
||||||
"nsa",
|
"nsa",
|
||||||
|
|||||||
@@ -342,6 +342,10 @@ def handle_attention_flashmla(attn, forward_batch):
|
|||||||
return _handle_attention_backend(attn, forward_batch, "flashmla")
|
return _handle_attention_backend(attn, forward_batch, "flashmla")
|
||||||
|
|
||||||
|
|
||||||
|
def handle_attention_dcu_mla(attn, forward_batch):
|
||||||
|
return _handle_attention_backend(attn, forward_batch, "dcu_mla")
|
||||||
|
|
||||||
|
|
||||||
def handle_attention_cutlass_mla(attn, forward_batch):
|
def handle_attention_cutlass_mla(attn, forward_batch):
|
||||||
return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
|
return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
|
||||||
|
|
||||||
@@ -3577,6 +3581,7 @@ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
|
|||||||
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
|
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
|
||||||
AttentionBackendRegistry.register("fa3", handle_attention_fa3)
|
AttentionBackendRegistry.register("fa3", handle_attention_fa3)
|
||||||
AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
|
AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
|
||||||
|
AttentionBackendRegistry.register("dcu_mla", handle_attention_dcu_mla)
|
||||||
AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
|
AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
|
||||||
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
|
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
|
||||||
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
|
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
|
||||||
|
|||||||
@@ -396,7 +396,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
|||||||
def _forward_input_proj(self, hidden_states: torch.Tensor):
|
def _forward_input_proj(self, hidden_states: torch.Tensor):
|
||||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
|
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
|
||||||
seq_len, _ = hidden_states.shape
|
seq_len, _ = hidden_states.shape
|
||||||
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
|
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD and self.alt_stream is not None:
|
||||||
current_stream = torch.cuda.current_stream()
|
current_stream = torch.cuda.current_stream()
|
||||||
self.alt_stream.wait_stream(current_stream)
|
self.alt_stream.wait_stream(current_stream)
|
||||||
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
|
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
|
||||||
|
|||||||
@@ -102,6 +102,8 @@ ATTENTION_BACKEND_CHOICES = [
|
|||||||
"torch_native",
|
"torch_native",
|
||||||
"flex_attention",
|
"flex_attention",
|
||||||
"nsa",
|
"nsa",
|
||||||
|
# ransplant from vllm
|
||||||
|
"dcu_mla",
|
||||||
# NVIDIA specific
|
# NVIDIA specific
|
||||||
"cutlass_mla",
|
"cutlass_mla",
|
||||||
"fa3",
|
"fa3",
|
||||||
@@ -1077,9 +1079,11 @@ class ServerArgs:
|
|||||||
if (
|
if (
|
||||||
self.attention_backend == "flashmla"
|
self.attention_backend == "flashmla"
|
||||||
or self.decode_attention_backend == "flashmla"
|
or self.decode_attention_backend == "flashmla"
|
||||||
|
or self.attention_backend == "dcu_mla"
|
||||||
|
or self.decode_attention_backend == "dcu_mla"
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"FlashMLA only supports a page_size of 64, change page_size to 64."
|
"FlashMLA/DCU MLA only supports a page_size of 64, change page_size to 64."
|
||||||
)
|
)
|
||||||
self.page_size = 64
|
self.page_size = 64
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user