[Fix] Fix capture fail bug for DeepSeek (#6275)
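This commit replaces the per-object `capture_mode` flags on `MHATokenToKVPool` and `MllamaForConditionalGeneration`, and the direct `torch.cuda.is_current_stream_capturing()` query in `DeepseekV2AttentionMLA`, with one module-level `is_capture_mode` flag in `cuda_graph_runner`. The flag is set by `CudaGraphRunner.model_capture_mode()` and read through a new `get_is_capture_mode()` helper, so every call site observes the same capture state during CUDA graph capture.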
@@ -266,7 +266,6 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()
 
         self.layer_transfer_counter = None
-        self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if is_cuda else None
 
@@ -385,6 +384,8 @@ class MHATokenToKVPool(KVCache):
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
     ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
@@ -398,7 +399,7 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode and self.alt_stream is not None:
+        if get_is_capture_mode() and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
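For context, the branch changed above overlaps the K and V cache writes on a second CUDA stream for small batches. A minimal standalone sketch of that fork/join idiom (the function and tensor names are illustrative, not SGLang's API; it requires a CUDA device):

import torch

def overlapped_kv_copy(dst_k, dst_v, src_k, src_v, alt_stream):
    # Fork: the alternate stream must wait for work already queued
    # on the current stream before it may touch the source tensors.
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)
    dst_k.copy_(src_k)              # K copy stays on the current stream
    with torch.cuda.stream(alt_stream):
        dst_v.copy_(src_v)          # V copy runs concurrently on alt_stream
    # Join: the current stream waits for alt_stream before any later
    # kernel reads dst_v, restoring ordinary single-stream semantics.
    current.wait_stream(alt_stream)

if torch.cuda.is_available():
    alt = torch.cuda.Stream()
    k = torch.randn(4, 1024, device="cuda")
    v = torch.randn(4, 1024, device="cuda")
    out_k, out_v = torch.empty_like(k), torch.empty_like(v)
    overlapped_kv_copy(out_k, out_v, k, v, alt)
    torch.cuda.synchronize()

The final `wait_stream` join is what makes it safe for subsequent work on the current stream to consume `dst_v`.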
@@ -47,6 +47,13 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -311,17 +318,12 @@ class CudaGraphRunner:
 
     @contextmanager
     def model_capture_mode(self):
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = True
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = True
+        global is_capture_mode
+        is_capture_mode = True
 
         yield
 
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = False
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = False
+        is_capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
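`model_capture_mode` now just toggles that one global. A self-contained sketch of the pattern (illustrative names, not SGLang's API; unlike the committed version, this sketch uses try/finally so the flag is reset even if capture raises):

from contextlib import contextmanager

_is_capture_mode = False  # single source of truth for "capturing right now?"

def get_capture_mode() -> bool:
    return _is_capture_mode

@contextmanager
def capture_scope():
    global _is_capture_mode
    _is_capture_mode = True
    try:
        yield
    finally:
        _is_capture_mode = False

# Usage: any code called inside the scope observes the flag.
with capture_scope():
    assert get_capture_mode()
assert not get_capture_mode()

Reading through a getter matters here: `from cuda_graph_runner import is_capture_mode` would copy the boolean at import time and never see later updates, which is presumably why the commit adds `get_is_capture_mode()` instead of exposing the variable directly.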
@@ -612,6 +614,7 @@ class CudaGraphRunner:
 
         # Replay
         self.graphs[self.bs].replay()
+
         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
             return LogitsProcessorOutput(
@@ -754,6 +754,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -761,7 +763,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             k_nope = latent_cache[..., : self.kv_lora_rank]
 
             # overlap qk norm
-            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
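Two different predicates used to answer the same question before this commit: `DeepseekV2AttentionMLA` asked CUDA directly via `torch.cuda.is_current_stream_capturing()`, while `MHATokenToKVPool` read a `self.capture_mode` attribute set by the graph runner. A plausible failure mode, given the commit title: warmup iterations run inside `model_capture_mode()` but before stream capture actually begins, so a stream-state query takes a different branch during warmup than during capture, and the captured graph no longer matches the warmed-up execution. Routing every check through `get_is_capture_mode()` makes the branch identical for the whole capture scope. The import inside `forward` is likely placed there to avoid a circular import with `cuda_graph_runner`.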
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
             prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
-        self.capture_mode = False
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pixel_values = torch.cat(
@@ -969,6 +968,8 @@ class MllamaForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
         )
@@ -977,7 +978,7 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_mask = None
         cross_attention_states = None
 
-        if self.capture_mode:
+        if get_is_capture_mode():
             # NOTE: when doing cuda graph capture, we do not want to skip cross attention
             # Make it a constant value to avoid cuda graph capture issue
             skip_cross_attention = False
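The Mllama change follows the same scheme, and the comment in the hunk states the constraint: a captured CUDA graph replays a fixed kernel sequence, so control flow evaluated at capture time is frozen into the graph. Forcing `skip_cross_attention = False` whenever `get_is_capture_mode()` is true keeps the cross-attention kernels in the captured graph, so replays that do have image inputs still execute them.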