remove useless backend forward in rotary_embedding (#4500)
This commit is contained in:
@@ -169,76 +169,6 @@ class RotaryEmbedding(CustomOp):
|
||||
)
|
||||
return query, key
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
|
||||
ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
def forward_hpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
positions = positions.flatten()
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
|
||||
# to query hidden dimension, so the original tensors need to be
|
||||
# expanded
|
||||
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
|
||||
# and expansion of cos/sin tensors via concatenation
|
||||
# GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
|
||||
# and expansion of cos/sin tensors via repeat_interleave
|
||||
rope_mode: RotaryPosEmbeddingMode
|
||||
if self.is_neox_style:
|
||||
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
||||
cos = torch.cat((cos, cos), dim=-1)
|
||||
sin = torch.cat((sin, sin), dim=-1)
|
||||
else:
|
||||
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
||||
sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
|
||||
cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
||||
s += f", max_position_embeddings={self.max_position_embeddings}"
|
||||
|
||||
Reference in New Issue
Block a user