[Fix]Fix capture fail bug for DeepSeek (#6275)
This commit is contained in:
@@ -47,6 +47,13 @@ from sglang.srt.utils import (
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
# Module-level flag recording whether the current forward pass is running in
# capture mode. It starts out False and is flipped by the
# ``model_capture_mode`` context manager.
is_capture_mode = False


def get_is_capture_mode() -> bool:
    """Return the current value of the module-level ``is_capture_mode`` flag.

    The flag is looked up at call time, so callers always see the live value.
    NOTE(review): presumably this accessor exists because
    ``from <module> import is_capture_mode`` would bind a one-time copy of the
    flag rather than track later updates -- confirm against the callers.
    """
    return is_capture_mode
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||
for sub in model._modules.values():
|
||||
@@ -311,17 +318,12 @@ class CudaGraphRunner:
|
||||
|
||||
@contextmanager
def model_capture_mode(self):
    """Flag the runner as being in capture mode for the span of a ``with`` block.

    On entry this sets ``capture_mode = True`` on ``self.model_runner.model``
    and on ``self.model_runner.token_to_kv_pool`` (each only if the object
    already has a ``capture_mode`` attribute) and sets the module-level
    ``is_capture_mode`` flag to ``True``. On exit all of them are set back to
    ``False``.

    The reset is performed in a ``finally`` clause: ``contextmanager`` re-raises
    an exception from the ``with`` body at the ``yield``, so without it a
    failure inside the block would leave every flag stuck at ``True``.
    """
    global is_capture_mode

    def _set_object_flags(value):
        # Only objects that already expose ``capture_mode`` are touched; the
        # attribute is never created on an object that lacks it.
        for holder in (
            self.model_runner.model,
            self.model_runner.token_to_kv_pool,
        ):
            if hasattr(holder, "capture_mode"):
                holder.capture_mode = value

    _set_object_flags(True)
    is_capture_mode = True
    try:
        yield
    finally:
        # Clear the global first so it is reset even if touching the
        # per-object attributes raises.
        is_capture_mode = False
        _set_object_flags(False)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
@@ -612,6 +614,7 @@ class CudaGraphRunner:
|
||||
|
||||
# Replay
|
||||
self.graphs[self.bs].replay()
|
||||
|
||||
output = self.output_buffers[self.bs]
|
||||
if isinstance(output, LogitsProcessorOutput):
|
||||
return LogitsProcessorOutput(
|
||||
|
||||
Reference in New Issue
Block a user