Quick Fix: fix Qwen3-VL launch failure caused by MRotaryEmbedding arg (#10985)
This commit is contained in:
@@ -1065,6 +1065,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
|
||||
@@ -1075,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
query: [num_tokens, num_heads * head_size]
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
|
||||
assert (
|
||||
fused_set_kv_buffer_arg is None
|
||||
), "save kv cache is not supported for MRotaryEmbedding."
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
|
||||
num_tokens = positions.shape[-1]
|
||||
|
||||
@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
|
||||
from sglang.srt.layers.utils import get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -358,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.compatible_with_fused_kv_buffer = (
|
||||
False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
|
||||
)
|
||||
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
@@ -427,6 +431,7 @@ class Qwen3MoeAttention(nn.Module):
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
if enable_fused_set_kv_buffer(forward_batch)
|
||||
and self.compatible_with_fused_kv_buffer
|
||||
else None
|
||||
),
|
||||
)
|
||||
@@ -439,7 +444,10 @@ class Qwen3MoeAttention(nn.Module):
|
||||
return hidden_states
|
||||
attn_output = self.attn(
|
||||
*inner_state,
|
||||
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
|
||||
save_kv_cache=not (
|
||||
enable_fused_set_kv_buffer(forward_batch)
|
||||
and self.compatible_with_fused_kv_buffer
|
||||
),
|
||||
)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user