Support overlap-spec-v2 with trtllm_mla attention backend (#11821)
This commit is contained in:
@@ -24,6 +24,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.server_args import get_global_server_args
|
||||
from sglang.srt.utils import is_cuda, is_flashinfer_available
|
||||
from sglang.srt.utils.common import cached_triton_kernel
|
||||
|
||||
if is_flashinfer_available():
|
||||
import flashinfer
|
||||
@@ -50,6 +51,7 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
||||
TRTLLM_BLOCK_CONSTRAINT = 128
|
||||
|
||||
|
||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
|
||||
@triton.jit
|
||||
def pad_draft_extend_query_kernel(
|
||||
q_ptr, # Input query tensor [total_seq_len, num_heads, head_dim]
|
||||
@@ -123,6 +125,7 @@ def pad_draft_extend_query_kernel(
|
||||
)
|
||||
|
||||
|
||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
|
||||
@triton.jit
|
||||
def unpad_draft_extend_output_kernel(
|
||||
raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
|
||||
@@ -389,7 +392,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
if (
|
||||
not forward_mode.is_decode_or_idle()
|
||||
and not forward_mode.is_target_verify()
|
||||
and not forward_mode.is_draft_extend()
|
||||
and not forward_mode.is_draft_extend(include_v2=True)
|
||||
):
|
||||
return super().init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
@@ -429,7 +432,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
block_kv_indices,
|
||||
max_seq_len_val,
|
||||
)
|
||||
if forward_mode.is_draft_extend():
|
||||
if forward_mode.is_draft_extend(include_v2=True):
|
||||
num_tokens_per_bs = num_tokens // bs
|
||||
metadata.max_seq_len_q = num_tokens_per_bs + 1
|
||||
metadata.sum_seq_lens_q = num_tokens_per_bs * bs
|
||||
@@ -462,7 +465,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
if (
|
||||
not forward_mode.is_decode_or_idle()
|
||||
and not forward_mode.is_target_verify()
|
||||
and not forward_mode.is_draft_extend()
|
||||
and not forward_mode.is_draft_extend(include_v2=True)
|
||||
):
|
||||
return super().init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
@@ -481,7 +484,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
|
||||
if forward_mode.is_draft_extend():
|
||||
if forward_mode.is_draft_extend(include_v2=True):
|
||||
accept_length = spec_info.accept_length[:bs]
|
||||
if spec_info.accept_length_cpu:
|
||||
metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
|
||||
@@ -523,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
||||
):
|
||||
if self.disable_chunked_prefix_cache:
|
||||
super().init_forward_metadata(forward_batch)
|
||||
@@ -544,7 +547,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
elif (
|
||||
forward_batch.forward_mode.is_decode_or_idle()
|
||||
or forward_batch.forward_mode.is_target_verify()
|
||||
or forward_batch.forward_mode.is_draft_extend()
|
||||
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
||||
):
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
@@ -573,7 +576,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
|
||||
block_kv_indices, max_seq_len_val
|
||||
)
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
if forward_batch.forward_mode.is_draft_extend(include_v2=True):
|
||||
max_seq = forward_batch.seq_lens_cpu.max().item()
|
||||
|
||||
sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
|
||||
@@ -922,7 +925,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
if (
|
||||
forward_batch.forward_mode.is_target_verify()
|
||||
or forward_batch.forward_mode.is_draft_extend()
|
||||
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
||||
):
|
||||
metadata = (
|
||||
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
||||
@@ -994,7 +997,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
# Reshape output directly without slicing
|
||||
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
if forward_batch.forward_mode.is_draft_extend(include_v2=True):
|
||||
raw_out = self.unpad_draft_extend_output(
|
||||
raw_out,
|
||||
metadata.cu_seqlens_q,
|
||||
|
||||
Reference in New Issue
Block a user