diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py
index c89fe809c..5cffd3b97 100644
--- a/python/sglang/srt/layers/attention/attention_registry.py
+++ b/python/sglang/srt/layers/attention/attention_registry.py
@@ -34,7 +34,9 @@ def create_flashinfer_backend(runner):
                 or not runner.plan_stream_for_flashinfer
             ):
                 runner.plan_stream_for_flashinfer = torch.cuda.Stream()
-        return FlashInferAttnBackend(runner)
+        return FlashInferAttnBackend(
+            runner, init_new_workspace=runner.init_new_workspace
+        )
     else:
         from sglang.srt.layers.attention.flashinfer_mla_backend import (
             FlashInferMLAAttnBackend,
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index aeb06bfa9..ab4398b0b 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -118,6 +118,7 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()
 
@@ -192,7 +193,14 @@ class FlashInferAttnBackend(AttentionBackend):
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 4ef8bc993..b1b8b7ff3 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -284,6 +284,7 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
         self.forward_pass_id = 0
+        self.init_new_workspace = False
 
         # Apply the rank zero filter to logger
         if server_args.show_time_cost:
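
Taken together, the three hunks thread a per-backend `init_new_workspace` flag from `ModelRunner` (default `False`) through the registry into `FlashInferAttnBackend`: when the flag is set, the backend allocates its own workspace tensor instead of aliasing the shared `global_workspace_buffer`. The sketch below is a minimal, self-contained illustration of that buffer-selection logic, not sglang's actual code; the `get_workspace` helper and the 384 MiB size constant are assumptions made for the example.

```python
# Minimal sketch (assumed names, not sglang's real API) of the workspace-selection
# logic this diff introduces: a backend either shares one process-wide FlashInfer
# workspace buffer, or allocates a private one when init_new_workspace=True.
import torch

# Assumption: illustrative workspace size; sglang reads it from
# global_config.flashinfer_workspace_size.
FLASHINFER_WORKSPACE_SIZE = 384 * 1024 * 1024  # bytes

_global_workspace_buffer = None  # lazily created, shared by default


def get_workspace(device: str, init_new_workspace: bool = False) -> torch.Tensor:
    """Return a FlashInfer-style workspace buffer (hypothetical helper).

    With init_new_workspace=True the caller gets its own buffer, so two
    backends (e.g. planning on separate CUDA streams) never share scratch
    memory; otherwise all callers alias one global allocation.
    """
    global _global_workspace_buffer
    if init_new_workspace:
        # Private buffer: independent of every other backend instance.
        return torch.empty(FLASHINFER_WORKSPACE_SIZE, dtype=torch.uint8, device=device)
    if _global_workspace_buffer is None:
        # First caller creates the shared buffer; later callers reuse it.
        _global_workspace_buffer = torch.empty(
            FLASHINFER_WORKSPACE_SIZE, dtype=torch.uint8, device=device
        )
    return _global_workspace_buffer
```

Defaulting the flag to `False` on `ModelRunner` keeps existing callers on the shared buffer; only a caller that explicitly sets `runner.init_new_workspace = True` before the backend is created pays for the extra allocation.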