Quick Fix: fix Qwen3-VL launch failure caused by MRotaryEmbedding arg (#10985)
@@ -1065,6 +1065,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward().
 
@@ -1075,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
             query: [num_tokens, num_heads * head_size]
             key: [num_tokens, num_kv_heads * head_size]
         """
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "save kv cache is not supported for MRotaryEmbedding."
         assert positions.ndim == 1 or positions.ndim == 2
 
         num_tokens = positions.shape[-1]
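Judging from the commit title and the second file below, Qwen3-VL's attention goes through MRotaryEmbedding for its multimodal rope, while the shared Qwen3 MoE attention code passes a `fused_set_kv_buffer_arg` keyword to the rotary module. Before this hunk, MRotaryEmbedding's signature did not accept that keyword, so the first forward pass at server launch failed with a TypeError. A minimal sketch of that failure mode (the class here is a stand-in, not the real implementation):

```python
# Sketch only: illustrates the pre-fix failure mode, not the actual sglang call site.
from typing import Tuple
import torch


class OldMRotaryEmbedding:
    # Pre-fix signature: no `fused_set_kv_buffer_arg` parameter.
    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return query, key  # rotary math elided


rope = OldMRotaryEmbedding()
q = k = torch.zeros(4, 8)
positions = torch.arange(4)

# The attention layer now always passes the keyword (possibly None),
# which the old signature rejects:
try:
    rope.forward_native(positions, q, k, fused_set_kv_buffer_arg=None)
except TypeError as e:
    print(e)  # "... got an unexpected keyword argument 'fused_set_kv_buffer_arg'"
```

The hunk above accepts the keyword but asserts it is None, since MRotaryEmbedding's PyTorch-native path does not fuse the KV-cache write.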
@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -358,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
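The new attribute boils down to a negated isinstance check. A small illustrative helper (not part of the patch; it assumes both classes live in `sglang.srt.layers.rotary_embedding`, as the hunk header and the changed import suggest):

```python
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, RotaryEmbedding


def compatible_with_fused_kv_buffer(rotary_emb: RotaryEmbedding) -> bool:
    # Equivalent to `False if isinstance(rotary_emb, MRotaryEmbedding) else True`.
    # MRotaryEmbedding instances (Qwen3-VL's mrope) opt out of the fused
    # set-KV-buffer path; other rope implementations keep using it.
    return not isinstance(rotary_emb, MRotaryEmbedding)
```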
@@ -427,6 +431,7 @@ class Qwen3MoeAttention(nn.Module):
                     forward_batch=forward_batch,
                 )
                 if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
                 else None
             ),
         )
@@ -439,7 +444,10 @@ class Qwen3MoeAttention(nn.Module):
             return hidden_states
         attn_output = self.attn(
             *inner_state,
-            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
         )
         output, _ = self.o_proj(attn_output)
         return output
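Taken together, the two forward hunks keep the invariant that the KV cache is written exactly once per token: either fused into the rotary kernel (when the fused path is enabled and the rope layer supports it) or by the attention layer itself via `save_kv_cache=True`. A minimal sketch of that decision, where the hypothetical `fused_path_enabled` flag stands in for `enable_fused_set_kv_buffer(forward_batch)`:

```python
# Illustrative decision table, not sglang code. `fused_path_enabled` stands in for
# enable_fused_set_kv_buffer(forward_batch); `compatible` for the new attribute.
def kv_write_plan(fused_path_enabled: bool, compatible: bool) -> dict:
    use_fused = fused_path_enabled and compatible
    return {
        # Keyword handed to the rotary embedding: an arg object when fusing, else None.
        "pass_fused_set_kv_buffer_arg": use_fused,
        # The attention layer writes the cache only when the rotary kernel did not.
        "save_kv_cache_in_attention": not use_fused,
    }


for fused in (False, True):
    for compatible in (False, True):
        print(fused, compatible, kv_write_plan(fused, compatible))
# With MRotaryEmbedding (compatible=False) the attention layer always saves the
# cache itself, which restores the pre-fused behaviour for Qwen3-VL.
```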