fix: trtllm-gen attention take zero-init workspace (#10330)
This commit is contained in:
@@ -58,7 +58,6 @@ class TRTLLMMLAPrefillMetadata:
|
|||||||
class TRTLLMMLADecodeMetadata:
|
class TRTLLMMLADecodeMetadata:
|
||||||
"""Metadata for TRTLLM MLA decode operations."""
|
"""Metadata for TRTLLM MLA decode operations."""
|
||||||
|
|
||||||
workspace: Optional[torch.Tensor] = None
|
|
||||||
block_kv_indices: Optional[torch.Tensor] = None
|
block_kv_indices: Optional[torch.Tensor] = None
|
||||||
max_seq_len: Optional[int] = None
|
max_seq_len: Optional[int] = None
|
||||||
|
|
||||||
@@ -187,9 +186,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.decode_cuda_graph_kv_indices = torch.full(
|
self.decode_cuda_graph_kv_indices = torch.full(
|
||||||
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
|
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
self.decode_cuda_graph_workspace = torch.empty(
|
|
||||||
self.workspace_size, dtype=torch.int8, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
|
super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
|
||||||
|
|
||||||
@@ -240,7 +236,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
max_seq_len_val = int(seq_lens.max().item())
|
max_seq_len_val = int(seq_lens.max().item())
|
||||||
|
|
||||||
metadata = TRTLLMMLADecodeMetadata(
|
metadata = TRTLLMMLADecodeMetadata(
|
||||||
self.decode_cuda_graph_workspace,
|
|
||||||
block_kv_indices,
|
block_kv_indices,
|
||||||
max_seq_len_val,
|
max_seq_len_val,
|
||||||
)
|
)
|
||||||
@@ -339,7 +334,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
|
|
||||||
max_seq_len_val = int(max_seq)
|
max_seq_len_val = int(max_seq)
|
||||||
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
|
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
|
||||||
self.workspace_buffer, block_kv_indices, max_seq_len_val
|
block_kv_indices, max_seq_len_val
|
||||||
)
|
)
|
||||||
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
|
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
|
||||||
else:
|
else:
|
||||||
@@ -513,7 +508,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
|
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
|
||||||
query=query,
|
query=query,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
workspace_buffer=metadata.workspace,
|
workspace_buffer=self.workspace_buffer,
|
||||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||||
kv_lora_rank=self.kv_lora_rank,
|
kv_lora_rank=self.kv_lora_rank,
|
||||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||||
|
|||||||
Reference in New Issue
Block a user