diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index d9868b307..59bc12219 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -25,6 +25,9 @@ if TYPE_CHECKING: # Constants DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB +# Reuse this workspace buffer across all TRTLLM MHA wrappers +global_workspace_buffer = None + @dataclass class TRTLLMMHAMetadata: @@ -69,9 +72,15 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend): # Workspace allocation self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024 - self.workspace_buffer = torch.empty( - self.workspace_size, dtype=torch.int8, device=self.device - ) + # Allocate buffers + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + self.workspace_size, + dtype=torch.uint8, + device=model_runner.device, + ) + self.workspace_buffer = global_workspace_buffer # CUDA graph state self.decode_cuda_graph_metadata = {}