Revise MRotaryEmbedding's forward (#11859)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com> Co-authored-by: 羽癫 <yudian.zy@antgroup.com> Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
@@ -1280,7 +1280,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
def forward_native(
|
def _forward_native(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -1340,6 +1340,27 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Forward pass with optional Triton kernel acceleration.
|
||||||
|
Args:
|
||||||
|
positions:
|
||||||
|
[num_tokens,] (text only) or
|
||||||
|
[3, num_tokens] (T/H/W positions with multimodal inputs)
|
||||||
|
query: [num_tokens, num_heads * head_size]
|
||||||
|
key: [num_tokens, num_kv_heads * head_size]
|
||||||
|
"""
|
||||||
|
assert positions.ndim == 1 or positions.ndim == 2
|
||||||
|
|
||||||
|
if positions.ndim == 2 and self.mrope_section and _is_cuda:
|
||||||
|
return self._forward_triton(positions, query, key)
|
||||||
|
else:
|
||||||
|
return self._forward_native(positions, query, key)
|
||||||
|
|
||||||
|
def _forward_triton(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert positions.ndim == 1 or positions.ndim == 2
|
assert positions.ndim == 1 or positions.ndim == 2
|
||||||
assert key is not None
|
assert key is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user