From d4c038daede43544d107f81cb5b6337c7a13803a Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Wed, 21 May 2025 11:11:20 -0700
Subject: [PATCH] [Fix]Fix capture fail bug for DeepSeek (#6275)

---
 python/sglang/srt/mem_cache/memory_pool.py    |  5 +++--
 .../srt/model_executor/cuda_graph_runner.py   | 19 +++++++++++--------
 python/sglang/srt/models/deepseek_v2.py       |  4 +++-
 python/sglang/srt/models/mllama.py            |  5 +++--
 4 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 8c88b9436..79fb1b3b4 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -266,7 +266,6 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()
 
         self.layer_transfer_counter = None
-        self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if is_cuda else None
 
@@ -385,6 +384,8 @@ class MHATokenToKVPool(KVCache):
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
     ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
@@ -398,7 +399,7 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode and self.alt_stream is not None:
+        if get_is_capture_mode() and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 40f136deb..308bf92dd 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -47,6 +47,13 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -311,17 +318,12 @@ class CudaGraphRunner:
 
     @contextmanager
     def model_capture_mode(self):
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = True
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = True
+        global is_capture_mode
+        is_capture_mode = True
 
         yield
 
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = False
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = False
+        is_capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
@@ -612,6 +614,7 @@ class CudaGraphRunner:
 
         # Replay
         self.graphs[self.bs].replay()
+
        output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
             return LogitsProcessorOutput(
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 62ecce141..8fadf590d 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -754,6 +754,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -761,7 +763,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             k_nope = latent_cache[..., : self.kv_lora_rank]
 
             # overlap qk norm
-            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py
index 6439f9327..fed9e4b59 100644
--- a/python/sglang/srt/models/mllama.py
+++ b/python/sglang/srt/models/mllama.py
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
             prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
-        self.capture_mode = False
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pixel_values = torch.cat(
@@ -969,6 +968,8 @@ class MllamaForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
         )
@@ -977,7 +978,7 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_mask = None
         cross_attention_states = None
 
-        if self.capture_mode:
+        if get_is_capture_mode():
             # NOTE: when doing cuda graph capture, we do not want to skip cross attention
             # Make is a constant value to avoid cuda graph capture issue
             skip_cross_attention = False
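
Reviewer note: below is a minimal, self-contained sketch of the pattern this patch
converges on — one module-level flag in cuda_graph_runner, toggled by the
model_capture_mode() context manager and read through get_is_capture_mode() at every
call site, replacing both the per-object capture_mode attributes and
torch.cuda.is_current_stream_capturing(). Names mirror the patch; the demo() driver
and the try/finally hardening are illustrative additions, not part of the change.

    from contextlib import contextmanager

    # Module-level flag, as introduced in cuda_graph_runner.py by this patch.
    is_capture_mode = False


    def get_is_capture_mode() -> bool:
        # Call sites read through the getter: a `from`-import of the bare flag
        # would bind its value at import time and always observe False.
        return is_capture_mode


    @contextmanager
    def model_capture_mode():
        # Toggled around graph capture, as CudaGraphRunner.model_capture_mode does.
        # The try/finally is extra hardening beyond the patch itself.
        global is_capture_mode
        is_capture_mode = True
        try:
            yield
        finally:
            is_capture_mode = False


    def demo():  # hypothetical driver, for illustration only
        assert not get_is_capture_mode()
        with model_capture_mode():
            # Capture-only branches (alt-stream KV-cache copy, MLA qk-norm
            # overlap, mllama's constant skip_cross_attention) key off this state.
            assert get_is_capture_mode()
        assert not get_is_capture_mode()

This also suggests why each modified forward imports get_is_capture_mode lazily
inside the function body: presumably to avoid a circular import between the model
and memory-pool modules and cuda_graph_runner.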