[Fix] Fix capture failure bug for DeepSeek (#6275)
This commit is contained in:
@@ -266,7 +266,6 @@ class MHATokenToKVPool(KVCache):
|
||||
self._create_buffers()
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
self.capture_mode = False
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
self.alt_stream = self.device_module.Stream() if is_cuda else None
|
||||
|
||||
@@ -385,6 +384,8 @@ class MHATokenToKVPool(KVCache):
|
||||
k_scale: Optional[float] = None,
|
||||
v_scale: Optional[float] = None,
|
||||
):
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
if k_scale is not None:
|
||||
@@ -398,7 +399,7 @@ class MHATokenToKVPool(KVCache):
|
||||
cache_k = cache_k.view(self.store_dtype)
|
||||
cache_v = cache_v.view(self.store_dtype)
|
||||
|
||||
if self.capture_mode and self.alt_stream is not None:
|
||||
if get_is_capture_mode() and self.alt_stream is not None:
|
||||
# Overlap the copy of K and V cache for small batch size
|
||||
current_stream = self.device_module.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
|
||||
@@ -47,6 +47,13 @@ from sglang.srt.utils import (
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
# Flag recording whether the current forward pass is in capture mode
|
||||
is_capture_mode = False
|
||||
|
||||
|
||||
def get_is_capture_mode():
|
||||
return is_capture_mode
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||
for sub in model._modules.values():
|
||||
@@ -311,17 +318,12 @@ class CudaGraphRunner:
|
||||
|
||||
@contextmanager
|
||||
def model_capture_mode(self):
|
||||
if hasattr(self.model_runner.model, "capture_mode"):
|
||||
self.model_runner.model.capture_mode = True
|
||||
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
|
||||
self.model_runner.token_to_kv_pool.capture_mode = True
|
||||
global is_capture_mode
|
||||
is_capture_mode = True
|
||||
|
||||
yield
|
||||
|
||||
if hasattr(self.model_runner.model, "capture_mode"):
|
||||
self.model_runner.model.capture_mode = False
|
||||
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
|
||||
self.model_runner.token_to_kv_pool.capture_mode = False
|
||||
is_capture_mode = False
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
@@ -612,6 +614,7 @@ class CudaGraphRunner:
|
||||
|
||||
# Replay
|
||||
self.graphs[self.bs].replay()
|
||||
|
||||
output = self.output_buffers[self.bs]
|
||||
if isinstance(output, LogitsProcessorOutput):
|
||||
return LogitsProcessorOutput(
|
||||
|
||||
@@ -754,6 +754,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||
@@ -761,7 +763,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
k_nope = latent_cache[..., : self.kv_lora_rank]
|
||||
|
||||
# overlap qk norm
|
||||
if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
|
||||
if self.alt_stream is not None and get_is_capture_mode():
|
||||
current_stream = torch.cuda.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
q = self.q_a_layernorm(q)
|
||||
|
||||
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
prefix="multi_modal_projector",
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.text_config)
|
||||
self.capture_mode = False
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||
pixel_values = torch.cat(
|
||||
@@ -969,6 +968,8 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
|
||||
self._batch_image_inputs(forward_batch)
|
||||
)
|
||||
@@ -977,7 +978,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
cross_attention_mask = None
|
||||
cross_attention_states = None
|
||||
|
||||
if self.capture_mode:
|
||||
if get_is_capture_mode():
|
||||
# NOTE: when doing cuda graph capture, we do not want to skip cross attention
|
||||
# Make it a constant value to avoid CUDA graph capture issues
|
||||
skip_cross_attention = False
|
||||
|
||||
Reference in New Issue
Block a user