[2/2] Speed up trtllm_mla attention backend (#10474)
This commit is contained in:
@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import (
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
from sglang.srt.utils import is_cuda, is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
import flashinfer
|
||||
@@ -32,6 +32,11 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import concat_mla_absorb_q
|
||||
|
||||
# Constants
|
||||
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
||||
|
||||
@@ -482,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
q_rope_reshaped = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
||||
if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
|
||||
query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
|
||||
else:
|
||||
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
||||
else:
|
||||
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
Reference in New Issue
Block a user