[Feature] Reuse flashinfer workspace for PD-Multiplexing. (#11540)
@@ -34,7 +34,9 @@ def create_flashinfer_backend(runner):
                 or not runner.plan_stream_for_flashinfer
             ):
                 runner.plan_stream_for_flashinfer = torch.cuda.Stream()
-        return FlashInferAttnBackend(runner)
+        return FlashInferAttnBackend(
+            runner, init_new_workspace=runner.init_new_workspace
+        )
     else:
         from sglang.srt.layers.attention.flashinfer_mla_backend import (
             FlashInferMLAAttnBackend,
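For context: PD-multiplexing runs prefill and decode batches on separate CUDA streams of the same GPU, and FlashInfer kernels scratch into a workspace buffer, so two backends that plan concurrently must not share one. This hunk threads a per-runner flag into the backend constructor. A minimal sketch of that pattern, where `Runner`, `StubBackend`, and `create_backend` are hypothetical stand-ins for the real classes, not names from the commit:

# Sketch only: `Runner` and `StubBackend` stand in for ModelRunner and
# FlashInferAttnBackend; they are illustrative, not from the commit.
class StubBackend:
    def __init__(self, runner, init_new_workspace: bool = False):
        self.runner = runner
        self.init_new_workspace = init_new_workspace


class Runner:
    def __init__(self):
        # False by default: reuse the shared global workspace (old behavior).
        self.init_new_workspace = False


def create_backend(runner: Runner) -> StubBackend:
    # The call site changes from StubBackend(runner) to this, so the runner
    # decides whether the backend allocates a private workspace.
    return StubBackend(runner, init_new_workspace=runner.init_new_workspace)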
@@ -118,6 +118,7 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()

@@ -192,7 +193,14 @@ class FlashInferAttnBackend(AttentionBackend):
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
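This hunk is the core of the change: with `init_new_workspace` set, the backend allocates its own `torch.uint8` scratch tensor of `global_config.flashinfer_workspace_size` bytes instead of aliasing the process-wide `global_workspace_buffer`. A standalone sketch of that policy, assuming a hypothetical `workspace_for` helper and an assumed 128 MiB size (the commit takes the size from global_config):

import torch

WORKSPACE_SIZE = 128 * 1024 * 1024  # assumed; the commit reads global_config.flashinfer_workspace_size
_global_workspace = None  # lazily allocated, shared by default


def workspace_for(device: str, init_new_workspace: bool) -> torch.Tensor:
    """Return a FlashInfer-style scratch buffer for one backend instance."""
    global _global_workspace
    if init_new_workspace:
        # PD-multiplexing: this backend runs concurrently with another one,
        # so give it private scratch memory rather than the shared buffer.
        return torch.empty(WORKSPACE_SIZE, dtype=torch.uint8, device=device)
    if _global_workspace is None:
        _global_workspace = torch.empty(
            WORKSPACE_SIZE, dtype=torch.uint8, device=device
        )
    return _global_workspace

The trade-off is memory: each multiplexed backend pins an extra workspace-sized allocation, and keeping reuse of the global buffer as the default avoids paying that cost in the common single-stream case.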
@@ -284,6 +284,7 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
         self.forward_pass_id = 0
+        self.init_new_workspace = False

         # Apply the rank zero filter to logger
         if server_args.show_time_cost:
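Putting it together, a hypothetical caller on the PD-multiplexing path would flip the flag on the runner that shares the GPU before creating its backend, so the two streams never plan into the same scratch memory. This usage sketch reuses the `Runner` and `create_backend` names from the earlier sketch; the stream wiring is illustrative and requires a CUDA device:

import torch

prefill_runner, decode_runner = Runner(), Runner()
decode_runner.init_new_workspace = True  # second backend gets a private buffer

prefill_backend = create_backend(prefill_runner)
decode_backend = create_backend(decode_runner)

prefill_stream, decode_stream = torch.cuda.Stream(), torch.cuda.Stream()
with torch.cuda.stream(prefill_stream):
    ...  # plan/run prefill attention with prefill_backend
with torch.cuda.stream(decode_stream):
    ...  # plan/run decode attention concurrently with decode_backend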