Revise MRotaryEmbedding's forward (#11859)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com> Co-authored-by: 羽癫 <yudian.zy@antgroup.com> Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
@@ -1280,7 +1280,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def forward_native(
|
||||
def _forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
@@ -1340,6 +1340,27 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass with optional Triton kernel acceleration.
|
||||
Args:
|
||||
positions:
|
||||
[num_tokens,] (text only) or
|
||||
[3, num_tokens] (T/H/W positions with multimodal inputs)
|
||||
query: [num_tokens, num_heads * head_size]
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
|
||||
if positions.ndim == 2 and self.mrope_section and _is_cuda:
|
||||
return self._forward_triton(positions, query, key)
|
||||
else:
|
||||
return self._forward_native(positions, query, key)
|
||||
|
||||
def _forward_triton(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
Reference in New Issue
Block a user