8 Commits

Author SHA1 Message Date
maxiao1
46da95569f Merge branch 'v0.5.4_dev_linhai' into 'v0.5.4_dev'
V0.5.4 dev linhai

See merge request OpenDAS/sglang!9
2025-11-04 12:04:48 +00:00
linhai1
ee77577211 V0.5.4 dev linhai 2025-11-04 12:04:47 +00:00
maxiao1
a9e0e668c4 Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'
enable custom_allreduce

See merge request OpenDAS/sglang!6
2025-11-04 09:33:25 +00:00
maxiao1
0c5532b0c1 enable custom_allreduce 2025-11-04 09:33:24 +00:00
maxiao1
785e5e900b Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'
调用vllm里custom all reduce

See merge request OpenDAS/sglang!5
2025-11-03 08:30:29 +00:00
maxiao1
47e4d92348 Merge branch 'v0.5.4_dev_qwennext' into 'v0.5.4_dev'
适配qwen3-next

See merge request OpenDAS/sglang!4
2025-11-03 05:18:16 +00:00
maxiao1
0fbecc4364 Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'
change sgl_kernel WARP_SIZE to 64

See merge request OpenDAS/sglang!3
2025-11-03 03:33:05 +00:00
maxiao
477fddf28d 适配qwen3-next 2025-10-30 18:03:07 +08:00
14 changed files with 972 additions and 22 deletions

View File

@@ -19,14 +19,15 @@ logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var( use_vllm_custom_allreduce = get_bool_env_var(
"USE_VLLM_CUSTOM_ALLREDUCE", default="false" "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
) )
use_dcu_custom_allreduce= get_bool_env_var(
"USE_DCU_CUSTOM_ALLREDUCE", default="false"
)
if not is_hpu(): if not is_hpu():
# ROCm does not use vllm custom allreduce # ROCm does not use vllm custom allreduce
# if use_vllm_custom_allreduce and not is_hip(): if use_vllm_custom_allreduce and not is_hip():
if use_vllm_custom_allreduce:
try: try:
import vllm._C # noqa: F401 import vllm._C # noqa: F401
print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
except ImportError as e: except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
else: else:
@@ -35,12 +36,15 @@ if not is_hpu():
except ImportError as e: except ImportError as e:
logger.warning("Failed to import from custom_ar with %r", e) logger.warning("Failed to import from custom_ar with %r", e)
if use_dcu_custom_allreduce:
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
# if not is_hip() and not is_npu(): if not is_hip() and not is_npu():
if not is_npu():
if use_vllm_custom_allreduce: if use_vllm_custom_allreduce:
custom_op = torch.ops._C_custom_ar custom_op = torch.ops._C_custom_ar
print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
else: else:
custom_op = sgl_kernel.allreduce custom_op = sgl_kernel.allreduce
@@ -79,8 +83,79 @@ if not is_npu():
) -> None: ) -> None:
custom_op.register_graph_buffers(fa, handles, offsets) custom_op.register_graph_buffers(fa, handles, offsets)
elif is_hip and use_dcu_custom_allreduce:
# custom ar
def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
rank: int, fully_connected: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
fully_connected)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes: int) -> None:
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
reg_buffer_sz_bytes)
def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa)
def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()
def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa: int, handles: list[list[int]],
offsets: list[list[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
def open_mem_handle(mem_handle: torch.Tensor):
return torch.ops._C_custom_ar.open_mem_handle(mem_handle)
def free_shared_buffer(ptr: int) -> None:
torch.ops._C_custom_ar.free_shared_buffer(ptr)
def read_cache(
keys: torch.Tensor,
values: torch.Tensor,
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache_dtype: str
) -> None:
torch.ops._C_cache_ops.read_cache(keys, values, key_caches,
value_caches, slot_mapping,
kv_cache_dtype)
def write_cache_multi_layers(
keys: torch.Tensor,
values: torch.Tensor,
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache_dtype: str
) -> None:
torch.ops._C_cache_ops.write_cache_multi_layers(keys, values, key_caches,
value_caches, slot_mapping,
kv_cache_dtype)
else: else:
# ROCM custom allreduce # sgl_kernel ROCM custom allreduce
def init_custom_ar( def init_custom_ar(
meta: torch.Tensor, meta: torch.Tensor,

View File

@@ -27,10 +27,11 @@ _is_hip = is_hip()
try: try:
# if ops.use_vllm_custom_allreduce and not _is_hip: if ops.use_vllm_custom_allreduce and not _is_hip:
if ops.use_vllm_custom_allreduce:
# Use vLLM custom allreduce # Use vLLM custom allreduce
ops.meta_size() ops.meta_size()
elif ops.use_dcu_custom_allreduce:
ops.meta_size()
else: else:
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM) # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import sgl_kernel # noqa: F401 import sgl_kernel # noqa: F401
@@ -420,3 +421,274 @@ class CustomAllreduce:
def __del__(self): def __del__(self):
self.close() self.close()
class DCUCustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]
# max_size: max supported allreduce size
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 512) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-GPU environment
logger.info("Custom allreduce is disabled because "
"of missing custom allreduce library")
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
" spans across nodes.")
return
rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group)
# if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
if world_size > 16:
return
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
# assert current_platform.is_cuda_alike()
# fully_connected = current_platform.is_fully_connected(
# physical_device_ids)
if _is_cuda or _is_hip:
fully_connected = is_full_nvlink(physical_device_ids, world_size)
# if world_size > 2 and not fully_connected:
if not fully_connected:
max_size = 32 * 8192 * 2
# if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
# logger.warning(
# "Custom allreduce is disabled because it's not supported on"
# " more than two PCIe-only GPUs. To silence this warning, "
# "specify disable_custom_all_reduce=True explicitly.")
# return
logger.warning(
"We are using PCIe's custom allreduce."
"If the performance is poor, we can add "
"--disable-custom-all-reduce in the instruction.")
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not _is_hip and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
return
self.disabled = False
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group,
uncached=True)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device=self.device)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.fully_connected = fully_connected
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
self.fully_connected)
ops.register_buffer(self._ptr, self.buffer_ptrs)
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
return inp_size <= self.max_size
def all_reduce(self,
inp: torch.Tensor,
*,
out: torch.Tensor = None,
registered: bool = False):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if out is None:
out = torch.empty_like(inp)
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
self.max_size)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=False)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
if ops is not None:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
def __del__(self):
self.close()
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: list[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer) # type: ignore
else:
pointers.append(ops.open_mem_handle(h))
return pointers
@staticmethod
def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = 0) -> None:
if rank is None:
rank = dist.get_rank(group=group)
if ops is not None:
ops.free_shared_buffer(pointers[rank])

View File

@@ -53,6 +53,7 @@ from sglang.srt.utils import (
is_xpu, is_xpu,
supports_custom_op, supports_custom_op,
) )
from sglang.srt import _custom_ops as ops
_is_npu = is_npu() _is_npu = is_npu()
_is_cpu = is_cpu() _is_cpu = is_cpu()
@@ -303,7 +304,7 @@ class GroupCoordinator:
# Lazy import to avoid documentation build error # Lazy import to avoid documentation build error
from sglang.srt.distributed.device_communicators.custom_all_reduce import ( from sglang.srt.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce, CustomAllreduce, DCUCustomAllreduce
) )
from sglang.srt.distributed.device_communicators.pymscclpp import ( from sglang.srt.distributed.device_communicators.pymscclpp import (
PyMscclppCommunicator, PyMscclppCommunicator,
@@ -347,11 +348,17 @@ class GroupCoordinator:
else: else:
ca_max_size = 8 * 1024 * 1024 ca_max_size = 8 * 1024 * 1024
try: try:
self.ca_comm = CustomAllreduce( if is_hip() and ops.use_dcu_custom_allreduce:
group=self.cpu_group, self.ca_comm = DCUCustomAllreduce(
device=self.device, group=self.cpu_group,
max_size=ca_max_size, device=self.device,
) )
else:
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
max_size=ca_max_size,
)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"Setup Custom allreduce failed with {e}. To silence this " f"Setup Custom allreduce failed with {e}. To silence this "

View File

@@ -99,7 +99,6 @@ def create_triton_backend(runner):
return TritonAttnBackend(runner) return TritonAttnBackend(runner)
@register_attention_backend("torch_native") @register_attention_backend("torch_native")
def create_torch_native_backend(runner): def create_torch_native_backend(runner):
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
return FlashMLABackend(runner) return FlashMLABackend(runner)
@register_attention_backend("dcu_mla")
def create_dcu_mla_backend(runner):
from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend
return DCUMLABackend(runner)
@register_attention_backend("fa3") @register_attention_backend("fa3")
def create_flashattention_v3_backend(runner): def create_flashattention_v3_backend(runner):

View File

@@ -0,0 +1,484 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
import triton
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
try:
from flash_mla import (
flash_mla_with_kvcache,
flash_mla_with_kvcache_quantization,
get_mla_metadata
)
_has_flash_mla = True
except Exception:
try:
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata
)
_has_flash_mla = False
except Exception:
raise ImportError(
"Can not import FlashMLA。Please perform the following operations to use flashmla:\n"
" pip install flash-mla\n"
" or\n"
" pip install vllm"
)
PAGE_SIZE = 64 # 强制64
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInput
@dataclass
class VllmMLADecodeMetadata:
flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
num_splits: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
class DCUMLABackend(AttentionBackend):
def __init__(
self,
model_runner: "ModelRunner",
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None,
):
super().__init__()
if model_runner.server_args.page_size != PAGE_SIZE:
raise ValueError(
f"dcu_mla backend requires page_size={PAGE_SIZE}, "
f"but got the {model_runner.server_args.page_size}"
)
self.num_q_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.v_head_dim = model_runner.model_config.v_head_dim
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.device = model_runner.device
self.max_context_len = model_runner.model_config.context_len
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.forward_metadata: Union[VllmMLADecodeMetadata] = None
self.skip_prefill = skip_prefill
if not skip_prefill:
# 先用triton backend后面考虑替换
# from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
# self.triton_backend = TritonAttnBackend(
# model_runner,
# skip_prefill=False,
# kv_indptr_buf=kv_indptr_buf,
# )
# prefill改用flash attn
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
self.flashattn_backend = FlashAttentionBackend(
model_runner,
skip_prefill=False,
)
def _build_decode_metadata(
self,
forward_batch: ForwardBatch,
seq_lens: torch.Tensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
bs = forward_batch.batch_size
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
# 参考vllm官方博客分页
block_kv_indices = torch.full(
(bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), self.num_q_heads, 1
)
return (mla_metadata, num_splits), num_splits, block_kv_indices
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
# decode用flashmla
(mla_metadata, num_splits), num_splits_t, block_kv_indices = (
self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
)
self.forward_metadata = VllmMLADecodeMetadata(
mla_metadata, num_splits_t, block_kv_indices
)
elif forward_batch.forward_mode.is_target_verify():
seq_lens = forward_batch.seq_lens + self.num_draft_tokens
(mla_metadata, num_splits), num_splits_t, block_kv_indices = (
self._build_decode_metadata(forward_batch, seq_lens)
)
self.forward_metadata = VllmMLADecodeMetadata(
mla_metadata, num_splits_t, block_kv_indices
)
else:
# prefill/extend用triton backend -> 改用flash attn
if not self.skip_prefill:
# self.triton_backend.init_forward_metadata(forward_batch)
self.flashattn_backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:
cuda_graph_kv_indices = torch.full(
(max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
1,
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = block_kv_indices
if self.num_draft_tokens:
mla_metadata, num_splits = get_mla_metadata(
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
self.num_draft_tokens * self.num_q_heads,
1,
)
else:
mla_metadata, num_splits = get_mla_metadata(
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
self.num_q_heads,
1,
)
self.cuda_graph_mla_metadata = mla_metadata
self.cuda_graph_num_splits = num_splits
self.cuda_graph_kv_indices = cuda_graph_kv_indices
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional["SpecInput"],
):
if forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), num_q_heads, 1
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = VllmMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
elif forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = VllmMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
else:
if not self.skip_prefill:
# self.triton_backend.init_forward_metadata_capture_cuda_graph(
# bs,
# num_tokens,
# req_pool_indices,
# seq_lens,
# encoder_lens,
# forward_mode,
# spec_info,
# )
self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional["SpecInput"],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
assert seq_lens_cpu is not None
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), num_q_heads, 1
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
elif forward_mode.is_target_verify():
seq_lens = seq_lens[:bs] + self.num_draft_tokens
seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
else:
if not self.skip_prefill:
# self.triton_backend.init_forward_metadata_replay_cuda_graph(
# bs,
# req_pool_indices,
# seq_lens,
# seq_lens_sum,
# encoder_lens,
# forward_mode,
# spec_info,
# seq_lens_cpu,
# )
self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def _call_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
block_table: torch.Tensor, cache_seqlens: torch.Tensor,
scaling: float):
o, _ = flash_mla_with_kvcache(
q=reshape_q,
k_cache=k_cache_reshaped,
block_table=block_table,
cache_seqlens=cache_seqlens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=scaling,
causal=True,
)
return o
def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
block_table: torch.Tensor, cache_seqlens: torch.Tensor,
scaling: float):
assert _has_flash_mla, "FP8 KV cache 需要flash_mla包"
o, _ = flash_mla_with_kvcache_quantization(
q=reshape_q,
k_cache=k_cache_reshaped,
block_table=block_table,
cache_seqlens=cache_seqlens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=scaling,
causal=True,
is_fp8_kvcache=True,
)
return o
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
if self.data_type in (
getattr(torch, "float8_e4m3fn", None),
getattr(torch, "float8_e4m3fnuz", None),
getattr(torch, "float8_e5m2", None),
getattr(torch, "float8_e5m2fnuz", None),
):
o = self._call_fp8_decode(
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32), layer.scaling,
)
else:
o = self._call_decode(
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32), layer.scaling,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
sinks=None,
):
if (
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
):
# flash_attn不支持fp8fp8无法正常执行extend
if not self.skip_prefill:
# return self.triton_backend.forward_extend(
# q, k, v, layer, forward_batch, save_kv_cache, sinks
# )
return self.flashattn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, sinks
)
else:
raise RuntimeError("skip prefill but use forward_extend")
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
if self.data_type in (
getattr(torch, "float8_e4m3fn", None),
getattr(torch, "float8_e4m3fnuz", None),
getattr(torch, "float8_e5m2", None),
getattr(torch, "float8_e5m2fnuz", None),
):
o = self._call_fp8_decode(
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
layer.scaling,
)
else:
o = self._call_decode(
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
layer.scaling,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

View File

@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache # from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.sparse_flash_attn import ( from sgl_kernel.sparse_flash_attn import (
convert_vertical_slash_indexes, convert_vertical_slash_indexes,
convert_vertical_slash_indexes_mergehead, convert_vertical_slash_indexes_mergehead,

View File

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel import merge_state_v2 from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache # from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass @dataclass

View File

@@ -0,0 +1,94 @@
from flash_attn import (
flash_attn_varlen_func as flash_attn_varlen_func_interface,
flash_attn_with_kvcache as flash_attn_with_kvcache_interface
)
from typing import Optional, Union
import torch
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
attention_chunk: Optional[int] = None,
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
sinks=None,
ver=3,
):
return flash_attn_with_kvcache_interface(
q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
k_cache=k_cache,
v_cache=v_cache,
block_table=page_table,
cache_seqlens=cache_seqlens,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
return_softmax_lse=return_softmax_lse,
num_splits=num_splits,
)
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q=None,
max_seqlen_k=None,
seqused_q=None,
seqused_k=None,
page_table=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
num_splits=1,
pack_gqa=None,
sm_margin=0,
return_softmax_lse=False,
sinks=None,
ver=3,
):
return flash_attn_varlen_func_interface(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=softmax_scale,
causal=causal,
)

View File

@@ -45,7 +45,8 @@ if _is_hip:
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
) )
else: else:
from sgl_kernel.flash_attn import flash_attn_with_kvcache # from sgl_kernel.flash_attn import flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache
@dataclass(frozen=True) @dataclass(frozen=True)

View File

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel import merge_state_v2 from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache # from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
class XPUAttentionBackend(AttentionBackend): class XPUAttentionBackend(AttentionBackend):

View File

@@ -165,6 +165,7 @@ MLA_ATTENTION_BACKENDS = [
"triton", "triton",
"flashmla", "flashmla",
"cutlass_mla", "cutlass_mla",
"dcu_mla",
"trtllm_mla", "trtllm_mla",
"ascend", "ascend",
"nsa", "nsa",

View File

@@ -342,6 +342,10 @@ def handle_attention_flashmla(attn, forward_batch):
return _handle_attention_backend(attn, forward_batch, "flashmla") return _handle_attention_backend(attn, forward_batch, "flashmla")
def handle_attention_dcu_mla(attn, forward_batch):
return _handle_attention_backend(attn, forward_batch, "dcu_mla")
def handle_attention_cutlass_mla(attn, forward_batch): def handle_attention_cutlass_mla(attn, forward_batch):
return _handle_attention_backend(attn, forward_batch, "cutlass_mla") return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
@@ -3577,6 +3581,7 @@ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer) AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
AttentionBackendRegistry.register("fa3", handle_attention_fa3) AttentionBackendRegistry.register("fa3", handle_attention_fa3)
AttentionBackendRegistry.register("flashmla", handle_attention_flashmla) AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
AttentionBackendRegistry.register("dcu_mla", handle_attention_dcu_mla)
AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla) AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
AttentionBackendRegistry.register("fa4", handle_attention_fa4) AttentionBackendRegistry.register("fa4", handle_attention_fa4)
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla) AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)

View File

@@ -396,7 +396,7 @@ class Qwen3GatedDeltaNet(nn.Module):
def _forward_input_proj(self, hidden_states: torch.Tensor): def _forward_input_proj(self, hidden_states: torch.Tensor):
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0 DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
seq_len, _ = hidden_states.shape seq_len, _ = hidden_states.shape
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD: if seq_len < DUAL_STREAM_TOKEN_THRESHOLD and self.alt_stream is not None:
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream) self.alt_stream.wait_stream(current_stream)
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)

View File

@@ -102,6 +102,8 @@ ATTENTION_BACKEND_CHOICES = [
"torch_native", "torch_native",
"flex_attention", "flex_attention",
"nsa", "nsa",
# ransplant from vllm
"dcu_mla",
# NVIDIA specific # NVIDIA specific
"cutlass_mla", "cutlass_mla",
"fa3", "fa3",
@@ -1077,9 +1079,11 @@ class ServerArgs:
if ( if (
self.attention_backend == "flashmla" self.attention_backend == "flashmla"
or self.decode_attention_backend == "flashmla" or self.decode_attention_backend == "flashmla"
or self.attention_backend == "dcu_mla"
or self.decode_attention_backend == "dcu_mla"
): ):
logger.warning( logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64." "FlashMLA/DCU MLA only supports a page_size of 64, change page_size to 64."
) )
self.page_size = 64 self.page_size = 64