[Feature] Reuse flashinfer workspace for PD-Multiplexing. (#11540)

This commit is contained in:
ykcombat
2025-10-18 02:35:06 +08:00
committed by GitHub
parent 2bc3fcd420
commit f440baa136
3 changed files with 13 additions and 2 deletions

View File

@@ -34,7 +34,9 @@ def create_flashinfer_backend(runner):
or not runner.plan_stream_for_flashinfer or not runner.plan_stream_for_flashinfer
): ):
runner.plan_stream_for_flashinfer = torch.cuda.Stream() runner.plan_stream_for_flashinfer = torch.cuda.Stream()
return FlashInferAttnBackend(runner) return FlashInferAttnBackend(
runner, init_new_workspace=runner.init_new_workspace
)
else: else:
from sglang.srt.layers.attention.flashinfer_mla_backend import ( from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend, FlashInferMLAAttnBackend,

View File

@@ -118,6 +118,7 @@ class FlashInferAttnBackend(AttentionBackend):
skip_prefill: bool = False, skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None, kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None, kv_last_page_len_buf: Optional[torch.Tensor] = None,
init_new_workspace: bool = False,
): ):
super().__init__() super().__init__()
@@ -192,7 +193,14 @@ class FlashInferAttnBackend(AttentionBackend):
dtype=torch.uint8, dtype=torch.uint8,
device=model_runner.device, device=model_runner.device,
) )
self.workspace_buffer = global_workspace_buffer if init_new_workspace:
self.workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
else:
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None: if kv_indptr_buf is None:
self.kv_indptr = [ self.kv_indptr = [

View File

@@ -284,6 +284,7 @@ class ModelRunner:
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
self.attention_chunk_size = model_config.attention_chunk_size self.attention_chunk_size = model_config.attention_chunk_size
self.forward_pass_id = 0 self.forward_pass_id = 0
self.init_new_workspace = False
# Apply the rank zero filter to logger # Apply the rank zero filter to logger
if server_args.show_time_cost: if server_args.show_time_cost: