From 61305291430ab9ca1e30436c21511eab703452a3 Mon Sep 17 00:00:00 2001
From: yhyang201 <47235274+yhyang201@users.noreply.github.com>
Date: Wed, 1 Oct 2025 13:17:05 +0800
Subject: [PATCH] Quick Fix: fix Qwen3-VL launch failure caused by
 MRotaryEmbedding arg (#10985)

---
 python/sglang/srt/layers/rotary_embedding.py |  4 ++++
 python/sglang/srt/models/qwen3_moe.py        | 12 ++++++++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 2c7267529..91e58f6a0 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -1065,6 +1065,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward().

@@ -1075,6 +1076,9 @@
             query: [num_tokens, num_heads * head_size]
             key: [num_tokens, num_kv_heads * head_size]
         """
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "save kv cache is not supported for MRotaryEmbedding."
         assert positions.ndim == 1 or positions.ndim == 2
         num_tokens = positions.shape[-1]

diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index d9ac4684e..c4842416c 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -358,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -427,6 +431,7 @@
                     forward_batch=forward_batch,
                 )
                 if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
                 else None
             ),
         )
@@ -439,7 +444,10 @@
             return hidden_states
         attn_output = self.attn(
             *inner_state,
-            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
        )
         output, _ = self.o_proj(attn_output)
         return output
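
Note on the change: Qwen3-VL routes attention through Qwen3MoeAttention but uses MRotaryEmbedding, whose forward_native() does not write the KV cache, so passing it the fused-set-kv-buffer argument crashed at launch. The patch (a) makes MRotaryEmbedding accept the argument but assert it is None, and (b) gates the fused fast path off for that embedding class, falling back to save_kv_cache=True in the attention call. The sketch below distills that control flow into a standalone snippet; the class and helper names here are stand-ins, not sglang's actual API, and only the gating logic mirrors the diff.

from typing import Optional


class RotaryEmbedding:
    """Stand-in for sglang's RotaryEmbedding base class."""


class MRotaryEmbedding(RotaryEmbedding):
    """Stand-in for the multimodal rotary embedding used by Qwen3-VL."""


def enable_fused_set_kv_buffer() -> bool:
    # Placeholder for sglang's per-batch feature check; assume the
    # server would otherwise enable the fused path.
    return True


def plan_attention(rotary_emb: RotaryEmbedding) -> dict:
    # Mirrors the patch: take the fused set-kv-buffer fast path only when
    # the rotary embedding is NOT an MRotaryEmbedding, since that class
    # cannot write the KV cache inside its rotary kernel.
    compatible_with_fused_kv_buffer = not isinstance(rotary_emb, MRotaryEmbedding)
    use_fused = enable_fused_set_kv_buffer() and compatible_with_fused_kv_buffer
    return {
        # Non-None only on the fused path (a real FusedSetKVBufferArg in sglang).
        "fused_set_kv_buffer_arg": object() if use_fused else None,
        # When the fused path is off, the attention op must save KV itself.
        "save_kv_cache": not use_fused,
    }


if __name__ == "__main__":
    print(plan_attention(RotaryEmbedding()))   # fused path on, save_kv_cache False
    print(plan_attention(MRotaryEmbedding()))  # fused path off, save_kv_cache True

The two flags always move together: disabling the fused argument without restoring save_kv_cache would silently skip the KV-cache write, which is why the diff touches both call sites.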