[2/2] Speed up trtllm_mla attention backend (#10474)
This commit is contained in:
@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import (
|
|||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_cuda, is_flashinfer_available
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
import flashinfer
|
import flashinfer
|
||||||
@@ -32,6 +32,11 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.speculative.spec_info import SpecInfo
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
if _is_cuda:
|
||||||
|
from sgl_kernel import concat_mla_absorb_q
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
||||||
|
|
||||||
@@ -482,6 +487,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
q_rope_reshaped = q_rope.view(
|
q_rope_reshaped = q_rope.view(
|
||||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
)
|
)
|
||||||
|
if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
|
||||||
|
query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
|
||||||
|
else:
|
||||||
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
||||||
else:
|
else:
|
||||||
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
||||||
|
|||||||
Reference in New Issue
Block a user