minor: global workspace buffer for trtllm-gen mha from flashinfer (#8952)

This commit is contained in:
eigen
2025-08-08 03:12:12 -04:00
committed by GitHub
parent 0d1e27a0c5
commit 08fab2b0c4

View File

@@ -25,6 +25,9 @@ if TYPE_CHECKING:
# Constants
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
# Reuse this workspace buffer across all TRTLLM MHA wrappers
global_workspace_buffer = None
@dataclass
class TRTLLMMHAMetadata:
@@ -69,9 +72,15 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
# Workspace allocation
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
self.workspace_buffer = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
# Allocate buffers
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
self.workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
# CUDA graph state
self.decode_cuda_graph_metadata = {}