[Fix] Fix capture failure bug for DeepSeek (#6275)

This commit is contained in:
Baizhou Zhang
2025-05-21 11:11:20 -07:00
committed by GitHub
parent 55f6005f53
commit d4c038daed
4 changed files with 20 additions and 13 deletions

View File

@@ -754,6 +754,8 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
if self.q_lora_rank is not None:
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -761,7 +763,7 @@ class DeepseekV2AttentionMLA(nn.Module):
k_nope = latent_cache[..., : self.kv_lora_rank]
# overlap qk norm
-        if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+        if self.alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q = self.q_a_layernorm(q)