[Fix] Fix capture-fail bug for DeepSeek (#6275)
This commit is contained in:
@@ -754,6 +754,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -761,7 +763,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             k_nope = latent_cache[..., : self.kv_lora_rank]

             # overlap qk norm
-            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
Reference in New Issue
Block a user