fix: trtllm-gen attention take zero-init workspace (#10330)
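The trtllm-gen MLA decode kernel expects its workspace to be zero-initialized, but the CUDA-graph decode path allocated a dedicated workspace with torch.empty, which returns uninitialized memory, and threaded it through the decode metadata. The change below deletes the per-metadata workspace field and has every decode call use the backend's shared workspace_buffer instead. A minimal sketch of the distinction, assuming a CUDA device; the buffer size is illustrative, not the backend's real workspace_size:

import torch

workspace_size = 128 * 1024 * 1024  # illustrative size only

# torch.empty returns whatever bytes happen to sit in the allocation, so a
# kernel that assumes zeroed scratch memory can read stale garbage:
uninitialized = torch.empty(workspace_size, dtype=torch.int8, device="cuda")

# torch.zeros provides the zero-init contract the kernel relies on:
zeroed = torch.zeros(workspace_size, dtype=torch.int8, device="cuda")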
@@ -58,7 +58,6 @@ class TRTLLMMLAPrefillMetadata:


 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
-    workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
     max_seq_len: Optional[int] = None
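With the workspace field gone, the metadata carries only paging info. For reference, a sketch of the resulting dataclass; the field names, types, and docstring come from the hunk above, while the dataclass decorator and imports are assumed from context:

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class TRTLLMMLADecodeMetadata:
    """Metadata for TRTLLM MLA decode operations."""
    block_kv_indices: Optional[torch.Tensor] = None
    max_seq_len: Optional[int] = None

Both construction sites later in the diff drop the leading workspace argument to match this two-field layout.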
@@ -187,9 +186,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.decode_cuda_graph_workspace = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )

         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
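Deleting the dedicated decode_cuda_graph_workspace is safe for graph capture because replay reuses the device addresses recorded at capture time; the only requirement is that the workspace is a single long-lived allocation at a fixed address, which the backend's shared buffer already satisfies. A minimal sketch of that property, assuming a CUDA device and (per the commit title) a zero-initialized shared buffer; the size and the captured op are illustrative:

import torch

# One persistent, zeroed buffer, as if allocated once at backend init.
shared_workspace = torch.zeros(1 << 20, dtype=torch.int8, device="cuda")

static_out = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # The captured op keeps reading shared_workspace's storage on every
    # replay; the buffer only has to stay alive and unmoved.
    static_out.add_(shared_workspace[:8].to(torch.float32))

g.replay()  # reuses the workspace address recorded during capture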
@@ -240,7 +236,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         max_seq_len_val = int(seq_lens.max().item())

         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace,
             block_kv_indices,
             max_seq_len_val,
         )
@@ -339,7 +334,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

             max_seq_len_val = int(max_seq)
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
-                self.workspace_buffer, block_kv_indices, max_seq_len_val
+                block_kv_indices, max_seq_len_val
             )
             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
@@ -513,7 +508,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache,
-            workspace_buffer=metadata.workspace,
+            workspace_buffer=self.workspace_buffer,
             qk_nope_head_dim=self.qk_nope_head_dim,
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,