Revise MRotaryEmbedding's forward (#11859)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: 羽癫 <yudian.zy@antgroup.com>
Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
Yuan Luo
2025-10-21 10:38:29 +08:00
committed by GitHub
parent 9c0b1eb5ad
commit 74de76c685

View File

@@ -1280,7 +1280,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def forward_native(
def _forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
@@ -1340,6 +1340,27 @@ class MRotaryEmbedding(RotaryEmbedding):
query: torch.Tensor,
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass with optional Triton kernel acceleration.
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
if positions.ndim == 2 and self.mrope_section and _is_cuda:
return self._forward_triton(positions, query, key)
else:
return self._forward_native(positions, query, key)
def _forward_triton(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None