minor: global workspace buffer for trtllm-gen mha from flashinfer (#8952)
This commit is contained in:
@@ -25,6 +25,9 @@ if TYPE_CHECKING:
|
|||||||
# Constants
|
# Constants
|
||||||
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
||||||
|
|
||||||
|
# Reuse this workspace buffer across all TRTLLM MHA wrappers
|
||||||
|
global_workspace_buffer = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TRTLLMMHAMetadata:
|
class TRTLLMMHAMetadata:
|
||||||
@@ -69,9 +72,15 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
|||||||
|
|
||||||
# Workspace allocation
|
# Workspace allocation
|
||||||
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
|
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
|
||||||
self.workspace_buffer = torch.empty(
|
# Allocate buffers
|
||||||
self.workspace_size, dtype=torch.int8, device=self.device
|
global global_workspace_buffer
|
||||||
)
|
if global_workspace_buffer is None:
|
||||||
|
global_workspace_buffer = torch.empty(
|
||||||
|
self.workspace_size,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=model_runner.device,
|
||||||
|
)
|
||||||
|
self.workspace_buffer = global_workspace_buffer
|
||||||
|
|
||||||
# CUDA graph state
|
# CUDA graph state
|
||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user