[Reland] perf: optimize qwen-vl with symm mem allreduce (#11457)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -3,13 +3,13 @@ MiB = 1024 * 1024
|
|||||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||||
9: {
|
9: {
|
||||||
2: 64 * MiB, # 64 MB
|
2: 64 * MiB, # 64 MB
|
||||||
4: 32 * MiB, # 32 MB
|
4: 64 * MiB, # 64 MB
|
||||||
6: 64 * MiB, # 64 MB
|
6: 128 * MiB, # 128 MB
|
||||||
8: 64 * MiB, # 64 MB
|
8: 128 * MiB, # 128 MB
|
||||||
},
|
},
|
||||||
10: {
|
10: {
|
||||||
2: 64 * MiB, # 64 MB
|
2: 64 * MiB, # 64 MB
|
||||||
4: 32 * MiB, # 32 MB
|
4: 64 * MiB, # 64 MB
|
||||||
6: 128 * MiB, # 128 MB
|
6: 128 * MiB, # 128 MB
|
||||||
8: 128 * MiB, # 128 MB
|
8: 128 * MiB, # 128 MB
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -615,8 +615,11 @@ class GroupCoordinator:
|
|||||||
|
|
||||||
def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
|
def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
|
||||||
pynccl_comm = self.pynccl_comm
|
pynccl_comm = self.pynccl_comm
|
||||||
|
symm_mem_comm = self.symm_mem_comm
|
||||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||||
pynccl_comm.all_reduce(input_)
|
pynccl_comm.all_reduce(input_)
|
||||||
|
elif symm_mem_comm is not None and not symm_mem_comm.disabled:
|
||||||
|
symm_mem_comm.all_reduce(input_)
|
||||||
else:
|
else:
|
||||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||||
|
|
||||||
|
|||||||
@@ -1008,6 +1008,17 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
|
|||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
|
||||||
|
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
||||||
|
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
|
||||||
|
interleaved [THWTHWTHW...TT], preserving frequency continuity.
|
||||||
|
"""
|
||||||
|
x_t = x[0].clone()
|
||||||
|
x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
|
||||||
|
x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
|
||||||
|
return x_t
|
||||||
|
|
||||||
|
|
||||||
class MRotaryEmbedding(RotaryEmbedding):
|
class MRotaryEmbedding(RotaryEmbedding):
|
||||||
"""Rotary Embedding with Multimodal Sections."""
|
"""Rotary Embedding with Multimodal Sections."""
|
||||||
|
|
||||||
@@ -1020,12 +1031,14 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
mrope_section: Optional[List[int]] = None,
|
mrope_section: Optional[List[int]] = None,
|
||||||
|
mrope_interleaved: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mrope_section = mrope_section
|
self.mrope_section = mrope_section
|
||||||
|
self.mrope_interleaved = mrope_interleaved
|
||||||
if self.mrope_section:
|
if self.mrope_section:
|
||||||
expected_sum = rotary_dim // 2
|
expected_sum = rotary_dim // 2
|
||||||
actual_sum = sum(self.mrope_section)
|
actual_sum = sum(self.mrope_section)
|
||||||
@@ -1086,15 +1099,18 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
if positions.ndim == 2:
|
if positions.ndim == 2:
|
||||||
assert self.mrope_section
|
assert self.mrope_section
|
||||||
|
if self.mrope_interleaved:
|
||||||
cos = torch.cat(
|
cos = apply_interleaved_rope(cos, self.mrope_section)
|
||||||
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
|
sin = apply_interleaved_rope(sin, self.mrope_section)
|
||||||
dim=-1,
|
else:
|
||||||
)
|
cos = torch.cat(
|
||||||
sin = torch.cat(
|
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
|
||||||
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
|
dim=-1,
|
||||||
dim=-1,
|
)
|
||||||
)
|
sin = torch.cat(
|
||||||
|
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
query_shape = query.shape
|
query_shape = query.shape
|
||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
@@ -1768,6 +1784,7 @@ def get_rope(
|
|||||||
is_neox_style,
|
is_neox_style,
|
||||||
dtype,
|
dtype,
|
||||||
mrope_section=rope_scaling["mrope_section"],
|
mrope_section=rope_scaling["mrope_section"],
|
||||||
|
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rotary_emb = RotaryEmbedding(
|
rotary_emb = RotaryEmbedding(
|
||||||
|
|||||||
Reference in New Issue
Block a user