diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 408a66257..b8d62c3fa 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -58,7 +58,6 @@ class TRTLLMMLAPrefillMetadata:
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
 
-    workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
     max_seq_len: Optional[int] = None
 
@@ -187,9 +186,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.decode_cuda_graph_workspace = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
 
         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
 
@@ -240,7 +236,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val = int(seq_lens.max().item())
 
             metadata = TRTLLMMLADecodeMetadata(
-                self.decode_cuda_graph_workspace,
                 block_kv_indices,
                 max_seq_len_val,
             )
@@ -339,7 +334,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val = int(max_seq)
 
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
-                self.workspace_buffer, block_kv_indices, max_seq_len_val
+                block_kv_indices, max_seq_len_val
             )
             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
@@ -513,7 +508,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache,
-            workspace_buffer=metadata.workspace,
+            workspace_buffer=self.workspace_buffer,
             qk_nope_head_dim=self.qk_nope_head_dim,
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
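For context, a minimal sketch of the decode-metadata dataclass and its construction as they look after this diff. The imports and the toy `block_kv_indices` values below are illustrative assumptions, not part of the change; in the actual backend the workspace tensor lives on the backend instance as `self.workspace_buffer` and is passed to the kernel directly rather than through `metadata.workspace`.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TRTLLMMLADecodeMetadata:
    """Metadata for TRTLLM MLA decode operations (no per-call workspace field)."""

    block_kv_indices: Optional[torch.Tensor] = None
    max_seq_len: Optional[int] = None


# Illustrative construction mirroring the updated call sites: the metadata now
# carries only the block KV indices and the max sequence length; the workspace
# buffer is owned by the backend and supplied to the kernel call separately.
block_kv_indices = torch.full((2, 4), -1, dtype=torch.int32)
metadata = TRTLLMMLADecodeMetadata(block_kv_indices, max_seq_len=256)
print(metadata.max_seq_len)
```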