Co-authored-by: averyhuang <averyh@nvidia.com>
@@ -1,7 +1,8 @@
 from __future__ import annotations

 """
-Support attention backend for TRTLLM MLA kernels from flashinfer.
+Support attention backend for TRTLLM MHA kernels from flashinfer.
+The kernel supports sm100 only, with sliding window and attention sink features.
 """

 from dataclasses import dataclass
@@ -57,11 +58,6 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):

         # MHA-specific dimensions
         self.max_context_len = model_runner.model_config.context_len
-        self.sliding_window_size = (
-            model_runner.sliding_window_size
-            if model_runner.sliding_window_size is not None
-            else -1  # -1 indicates full attention
-        )
         self.hidden_size = config.hidden_size

         # Runtime parameters
@@ -117,10 +113,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         metadata = TRTLLMMHAMetadata()

         # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)

         # Precompute maximum sequence length
-        metadata.max_seq_len_k = seq_lens.max().item()
+        metadata.max_seq_len_k = self.max_context_len

         # Precompute page table
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
@@ -149,7 +145,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         metadata = self.decode_cuda_graph_metadata[bs]
         max_len = seq_lens_cpu.max().item()
         max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-        metadata.max_seq_len_k = max_len
+        metadata.max_seq_len_k = self.max_context_len

         metadata.cache_seqlens_int32.copy_(seq_lens)
         page_indices = self.req_to_token[
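These two CUDA-graph hunks follow the usual capture/replay constraint: tensor contents can be rewritten in place before every replay, but Python-side scalars are frozen at capture time, so the key-length bound has to be a conservative constant such as the model's context length rather than seq_lens.max(). A minimal sketch of that pattern, with illustrative buffer names:

    import torch

    max_context_len = 8192   # assumed model limit, standing in for model_config.context_len
    capture_bs = 4           # batch size the graph was captured with

    # Long-lived buffers owned by the captured graph; replays may only mutate
    # their contents, never their shapes.
    cache_seqlens_int32 = torch.zeros(capture_bs, dtype=torch.int32)

    def prepare_replay(seq_lens: torch.Tensor) -> int:
        # Real per-request lengths are copied into the captured buffer ...
        cache_seqlens_int32.copy_(seq_lens[:capture_bs].to(torch.int32))
        # ... while the scalar passed to the kernel stays at the fixed upper bound,
        # because a value like seq_lens.max().item() would be baked in at capture time.
        return max_context_len

    print(prepare_replay(torch.tensor([17, 33, 5, 128])))  # 8192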
@@ -217,6 +213,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MHA kernel."""
         cache_loc = forward_batch.out_cache_loc
@@ -228,7 +225,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         # shape conversion:
-        # [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
         ).permute(0, 2, 1, 3)
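The updated comment reflects that the KV cache is paged, so the leading dimension after the view is the number of pages, not the batch size. A toy check of the same view-plus-permute, with made-up sizes:

    import torch

    num_pages, page_size, num_kv_heads, head_dim = 3, 16, 2, 64

    # The pool hands back a token-major buffer; reinterpret it page by page,
    # then swap the page_size and num_kv_heads axes into the kernel's layout.
    k_cache = torch.randn(num_pages * page_size, num_kv_heads, head_dim)
    k_cache = k_cache.view(num_pages, page_size, num_kv_heads, head_dim).permute(0, 2, 1, 3)

    print(k_cache.shape)            # torch.Size([3, 2, 16, 64])
    print(k_cache.is_contiguous())  # False: permute only rearranges strides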
@@ -237,7 +234,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         ).permute(0, 2, 1, 3)
         kv_cache = (k_cache, v_cache)

-        # TODO: bmm1_scale and bmm2_scale might require modification
+        # TODO: add support for quantization
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -246,6 +243,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )
         bmm1_scale = q_scale * k_scale * layer.scaling
         bmm2_scale = 1.0
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)

         # Call TRT-LLM kernel
         # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
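bmm1_scale folds the (eventual) Q/K dequantization factors and the softmax scaling into the first batched matmul, so Q·K^T is scaled exactly once; bmm2_scale stays at 1.0 until quantization lands. A small worked example, using 1/sqrt(head_dim) as a stand-in for layer.scaling:

    import math

    head_dim = 128
    q_scale, k_scale = 1.0, 1.0                 # no quantization yet, per the TODO above
    softmax_scale = 1.0 / math.sqrt(head_dim)   # typical value of layer.scaling

    bmm1_scale = q_scale * k_scale * softmax_scale   # applied to Q @ K^T
    bmm2_scale = 1.0                                 # softmax(P) @ V is left unscaled
    print(bmm1_scale)  # ~0.0884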
@@ -258,8 +257,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             max_seq_len=self.forward_metadata.max_seq_len_k,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
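The new sinks argument matches the comment above: each head carries one extra logit that enters only the softmax denominator and has no value row. A plain PyTorch version of that softmax, as a sketch for intuition rather than the kernel's actual implementation:

    import torch

    def softmax_with_sink(scores: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
        # scores: [num_heads, kv_len] attention logits for one query token
        # sink:   [num_heads] learned per-head sink logit (kept in float32)
        m = torch.maximum(scores.max(dim=-1, keepdim=True).values, sink[:, None])
        exp_scores = torch.exp(scores - m)
        denom = exp_scores.sum(dim=-1, keepdim=True) + torch.exp(sink[:, None] - m)
        # Weights now sum to less than 1; the "missing" mass went to the sink.
        return exp_scores / denom

    w = softmax_with_sink(torch.randn(8, 5), torch.zeros(8))
    print(w.sum(dim=-1))  # each row sums to < 1.0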
@@ -272,6 +272,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        **kwargs,
     ):
         cache_loc = forward_batch.out_cache_loc
         if save_kv_cache and k is not None:
@@ -279,6 +280,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
@@ -288,8 +290,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         ).permute(0, 2, 1, 3)
         kv_cache = (k_cache, v_cache)

-        # TODO: bmm1_scale and bmm2_scale might require modification
-        # TODO: Change once quantization is supported
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+        # TODO: add support for quantization
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -312,8 +315,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             batch_size=forward_batch.batch_size,
             cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
             cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
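Both kernel calls now read the window from the layer instead of a backend-wide attribute, which allows per-layer sliding windows. As a reminder of the window_left convention used here, -1 means full causal attention and a non-negative value bounds how far back a query may look; a toy mask illustrating that convention:

    import torch

    def sliding_window_mask(seq_len: int, window_left: int) -> torch.Tensor:
        # window_left == -1: plain causal attention; otherwise each query also
        # ignores keys more than `window_left` positions behind it.
        idx = torch.arange(seq_len)
        causal = idx[None, :] <= idx[:, None]
        if window_left < 0:
            return causal
        return causal & (idx[:, None] - idx[None, :] <= window_left)

    print(sliding_window_mask(5, -1).int())
    print(sliding_window_mask(5, 2).int())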
@@ -1443,13 +1443,13 @@ class ModelRunner:
             )

             return CutlassMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mla":
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

             return TRTLLMMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mha":
+        elif backend_str == "trtllm_mha":
             if self.use_mla_backend:
                 raise ValueError(
                     "trtllm_mha backend can only be used with non-MLA models."
@@ -1460,7 +1460,7 @@ class ModelRunner:

             return TRTLLMHAAttnBackend(self)

-        elif self.server_args.attention_backend == "intel_amx":
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
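The dispatch branches now test backend_str instead of self.server_args.attention_backend, so the same selection code can resolve whichever backend name is being requested (server args also carry separate prefill and decode backend fields, per the validation below). A stripped-down sketch of that shape; the helper name and the trtllm_mha import path are assumptions, not the real ModelRunner method:

    # Hypothetical sketch only; the real method, its name, and its full branch
    # list live in ModelRunner.
    def create_attention_backend(model_runner, backend_str: str):
        if backend_str == "trtllm_mla":
            if not model_runner.use_mla_backend:
                raise ValueError("trtllm_mla backend can only be used with MLA models.")
            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
            return TRTLLMMLABackend(model_runner)
        elif backend_str == "trtllm_mha":
            if model_runner.use_mla_backend:
                raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
            # assumed symmetric with the MLA import above
            from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
            return TRTLLMHAAttnBackend(model_runner)
        raise ValueError(f"Unsupported attention backend: {backend_str}")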
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks)
+        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
         output, _ = self.o_proj(attn_output)
         return output

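The only functional change here is casting the learned per-head sink logits to float32 at call time while the parameter stays in the model dtype, presumably because the downstream kernel expects fp32 sinks. A minimal illustration:

    import torch
    import torch.nn as nn

    num_heads = 8
    sinks = nn.Parameter(torch.zeros(num_heads, dtype=torch.bfloat16))  # stored in model dtype

    sinks_for_kernel = sinks.to(torch.float32)   # cast at call time, as in the diff
    print(sinks.dtype, sinks_for_kernel.dtype)   # torch.bfloat16 torch.float32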
@@ -445,7 +445,11 @@ class ServerArgs:
                 "trtllm_mla backend does not support speculative decoding yet."
             )

-        if self.attention_backend == "trtllm_mha":
+        if (
+            self.attention_backend == "trtllm_mha"
+            or self.decode_attention_backend == "trtllm_mha"
+            or self.prefill_attention_backend == "trtllm_mha"
+        ):
             if not is_sm100_supported():
                 raise ValueError(
                     "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
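The validation is widened so that requesting trtllm_mha through any of the three backend fields (global, decode-only, or prefill-only) triggers the SM100 requirement. An equivalent standalone form of the condition, with an illustrative function name:

    def check_trtllm_mha_requires_sm100(
        attention_backend, decode_attention_backend, prefill_attention_backend, sm100_supported: bool
    ) -> None:
        # Mirrors the widened check above: any of the three fields selecting
        # "trtllm_mha" demands a Blackwell (SM100) GPU.
        requested = (attention_backend, decode_attention_backend, prefill_attention_backend)
        if "trtllm_mha" in requested and not sm100_supported:
            raise ValueError(
                "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). "
                "Please use a different backend."
            )

    check_trtllm_mha_requires_sm100("trtllm_mha", None, None, sm100_supported=True)  # passes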
@@ -459,11 +463,18 @@ class ServerArgs:

             if self.speculative_algorithm is not None:
                 raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
+                    "trtllm_mha backend does not support speculative decoding yet."
                 )

         model_arch = self.get_hf_config().architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            self.attention_backend = "triton"
+            if self.attention_backend is None:
+                # default is triton, but we could have trtllm_mha as an option
+                self.attention_backend = "triton"
+            assert (
+                self.attention_backend == "trtllm_mha"
+                or self.attention_backend == "triton"
+            )

         # Check if FlashInfer MXFP4 MoE is enabled
         from sglang.srt.utils import get_bool_env_var
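For GptOssForCausalLM the backend is no longer forced to triton unconditionally: an explicit user choice is kept, an unset value defaults to triton, and anything other than triton or trtllm_mha is rejected by the assert. The same logic as a small standalone function, with an illustrative name:

    from typing import Optional

    def resolve_gpt_oss_attention_backend(requested: Optional[str]) -> str:
        # default is triton, but trtllm_mha is also accepted, mirroring the assert above
        backend = "triton" if requested is None else requested
        assert backend in ("trtllm_mha", "triton")
        return backend

    print(resolve_gpt_oss_attention_backend(None))          # triton
    print(resolve_gpt_oss_attention_backend("trtllm_mha"))  # trtllm_mha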