From 08fab2b0c4208151e72fbb7a0455168ec78be033 Mon Sep 17 00:00:00 2001 From: eigen <52445717+yyihuang@users.noreply.github.com> Date: Fri, 8 Aug 2025 03:12:12 -0400 Subject: [PATCH] minor: global workspace buffer for trtllm-gen mha from flashinfer (#8952) --- .../srt/layers/attention/trtllm_mha_backend.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index d9868b307..59bc12219 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -25,6 +25,9 @@ if TYPE_CHECKING: # Constants DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB +# Reuse this workspace buffer across all TRTLLM MHA wrappers +global_workspace_buffer = None + @dataclass class TRTLLMMHAMetadata: @@ -69,9 +72,15 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend): # Workspace allocation self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024 - self.workspace_buffer = torch.empty( - self.workspace_size, dtype=torch.int8, device=self.device - ) + # Allocate buffers + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + self.workspace_size, + dtype=torch.uint8, + device=model_runner.device, + ) + self.workspace_buffer = global_workspace_buffer # CUDA graph state self.decode_cuda_graph_metadata = {}