[Fix]Fix capture fail bug for DeepSeek (#6275)

This commit is contained in:
Baizhou Zhang
2025-05-21 11:11:20 -07:00
committed by GitHub
parent 55f6005f53
commit d4c038daed
4 changed files with 20 additions and 13 deletions

View File

@@ -47,6 +47,13 @@ from sglang.srt.utils import (
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
# Detect whether the current forward pass is in capture mode
is_capture_mode = False
def get_is_capture_mode():
return is_capture_mode
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
@@ -311,17 +318,12 @@ class CudaGraphRunner:
@contextmanager
def model_capture_mode(self):
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = True
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
self.model_runner.token_to_kv_pool.capture_mode = True
global is_capture_mode
is_capture_mode = True
yield
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = False
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
self.model_runner.token_to_kv_pool.capture_mode = False
is_capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention or self.enable_sp_layernorm:
@@ -612,6 +614,7 @@ class CudaGraphRunner:
# Replay
self.graphs[self.bs].replay()
output = self.output_buffers[self.bs]
if isinstance(output, LogitsProcessorOutput):
return LogitsProcessorOutput(