Support overlap-spec-v2 with trtllm_mla attention backend (#11821)
This commit is contained in:
@@ -24,6 +24,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.server_args import get_global_server_args
|
from sglang.srt.server_args import get_global_server_args
|
||||||
from sglang.srt.utils import is_cuda, is_flashinfer_available
|
from sglang.srt.utils import is_cuda, is_flashinfer_available
|
||||||
|
from sglang.srt.utils.common import cached_triton_kernel
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
import flashinfer
|
import flashinfer
|
||||||
@@ -50,6 +51,7 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
|||||||
TRTLLM_BLOCK_CONSTRAINT = 128
|
TRTLLM_BLOCK_CONSTRAINT = 128
|
||||||
|
|
||||||
|
|
||||||
|
@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def pad_draft_extend_query_kernel(
|
def pad_draft_extend_query_kernel(
|
||||||
q_ptr, # Input query tensor [total_seq_len, num_heads, head_dim]
|
q_ptr, # Input query tensor [total_seq_len, num_heads, head_dim]
|
||||||
@@ -123,6 +125,7 @@ def pad_draft_extend_query_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def unpad_draft_extend_output_kernel(
|
def unpad_draft_extend_output_kernel(
|
||||||
raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
|
raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
|
||||||
@@ -389,7 +392,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
if (
|
if (
|
||||||
not forward_mode.is_decode_or_idle()
|
not forward_mode.is_decode_or_idle()
|
||||||
and not forward_mode.is_target_verify()
|
and not forward_mode.is_target_verify()
|
||||||
and not forward_mode.is_draft_extend()
|
and not forward_mode.is_draft_extend(include_v2=True)
|
||||||
):
|
):
|
||||||
return super().init_forward_metadata_capture_cuda_graph(
|
return super().init_forward_metadata_capture_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
@@ -429,7 +432,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
block_kv_indices,
|
block_kv_indices,
|
||||||
max_seq_len_val,
|
max_seq_len_val,
|
||||||
)
|
)
|
||||||
if forward_mode.is_draft_extend():
|
if forward_mode.is_draft_extend(include_v2=True):
|
||||||
num_tokens_per_bs = num_tokens // bs
|
num_tokens_per_bs = num_tokens // bs
|
||||||
metadata.max_seq_len_q = num_tokens_per_bs + 1
|
metadata.max_seq_len_q = num_tokens_per_bs + 1
|
||||||
metadata.sum_seq_lens_q = num_tokens_per_bs * bs
|
metadata.sum_seq_lens_q = num_tokens_per_bs * bs
|
||||||
@@ -462,7 +465,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
if (
|
if (
|
||||||
not forward_mode.is_decode_or_idle()
|
not forward_mode.is_decode_or_idle()
|
||||||
and not forward_mode.is_target_verify()
|
and not forward_mode.is_target_verify()
|
||||||
and not forward_mode.is_draft_extend()
|
and not forward_mode.is_draft_extend(include_v2=True)
|
||||||
):
|
):
|
||||||
return super().init_forward_metadata_replay_cuda_graph(
|
return super().init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
@@ -481,7 +484,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
|
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
|
||||||
if forward_mode.is_draft_extend():
|
if forward_mode.is_draft_extend(include_v2=True):
|
||||||
accept_length = spec_info.accept_length[:bs]
|
accept_length = spec_info.accept_length[:bs]
|
||||||
if spec_info.accept_length_cpu:
|
if spec_info.accept_length_cpu:
|
||||||
metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
|
metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
|
||||||
@@ -523,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_extend()
|
forward_batch.forward_mode.is_extend()
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
||||||
):
|
):
|
||||||
if self.disable_chunked_prefix_cache:
|
if self.disable_chunked_prefix_cache:
|
||||||
super().init_forward_metadata(forward_batch)
|
super().init_forward_metadata(forward_batch)
|
||||||
@@ -544,7 +547,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
elif (
|
elif (
|
||||||
forward_batch.forward_mode.is_decode_or_idle()
|
forward_batch.forward_mode.is_decode_or_idle()
|
||||||
or forward_batch.forward_mode.is_target_verify()
|
or forward_batch.forward_mode.is_target_verify()
|
||||||
or forward_batch.forward_mode.is_draft_extend()
|
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
||||||
):
|
):
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
|
|
||||||
@@ -573,7 +576,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
|
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
|
||||||
block_kv_indices, max_seq_len_val
|
block_kv_indices, max_seq_len_val
|
||||||
)
|
)
|
||||||
if forward_batch.forward_mode.is_draft_extend():
|
if forward_batch.forward_mode.is_draft_extend(include_v2=True):
|
||||||
max_seq = forward_batch.seq_lens_cpu.max().item()
|
max_seq = forward_batch.seq_lens_cpu.max().item()
|
||||||
|
|
||||||
sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
|
sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
|
||||||
@@ -922,7 +925,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_target_verify()
|
forward_batch.forward_mode.is_target_verify()
|
||||||
or forward_batch.forward_mode.is_draft_extend()
|
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
||||||
):
|
):
|
||||||
metadata = (
|
metadata = (
|
||||||
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
||||||
@@ -994,7 +997,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
|
|
||||||
# Reshape output directly without slicing
|
# Reshape output directly without slicing
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_draft_extend():
|
if forward_batch.forward_mode.is_draft_extend(include_v2=True):
|
||||||
raw_out = self.unpad_draft_extend_output(
|
raw_out = self.unpad_draft_extend_output(
|
||||||
raw_out,
|
raw_out,
|
||||||
metadata.cu_seqlens_q,
|
metadata.cu_seqlens_q,
|
||||||
|
|||||||
Reference in New Issue
Block a user