perf: optimize qwen-vl with symm mem allreduce (#11381)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
Yuan Luo
2025-10-10 22:24:45 +08:00
committed by GitHub
parent a1a20b4c7c
commit 3b9d97f335
5 changed files with 82 additions and 17 deletions

View File

@@ -1008,6 +1008,17 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
return cache
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
x_t = x[0].clone()
x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
return x_t
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
@@ -1020,12 +1031,14 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
mrope_interleaved: bool = False,
) -> None:
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
if self.mrope_section:
expected_sum = rotary_dim // 2
actual_sum = sum(self.mrope_section)
@@ -1086,15 +1099,18 @@ class MRotaryEmbedding(RotaryEmbedding):
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
if self.mrope_interleaved:
cos = apply_interleaved_rope(cos, self.mrope_section)
sin = apply_interleaved_rope(sin, self.mrope_section)
else:
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
@@ -1773,6 +1789,7 @@ def get_rope(
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(