remove useless backend forward in rotary_embedding (#4500)
This commit is contained in:
@@ -169,76 +169,6 @@ class RotaryEmbedding(CustomOp):
|
|||||||
)
|
)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
def forward_xpu(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
offsets: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
from vllm._ipex_ops import ipex_ops as ops
|
|
||||||
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
|
|
||||||
ops.rotary_embedding(
|
|
||||||
positions,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
self.head_size,
|
|
||||||
self.cos_sin_cache,
|
|
||||||
self.is_neox_style,
|
|
||||||
)
|
|
||||||
return query, key
|
|
||||||
|
|
||||||
def forward_hpu(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
offsets: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
from habana_frameworks.torch.hpex.kernels import (
|
|
||||||
RotaryPosEmbeddingMode,
|
|
||||||
apply_rotary_pos_emb,
|
|
||||||
)
|
|
||||||
|
|
||||||
positions = positions.flatten()
|
|
||||||
if offsets is not None:
|
|
||||||
positions = positions + offsets
|
|
||||||
num_tokens = positions.shape[0]
|
|
||||||
cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
|
|
||||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
|
||||||
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
|
|
||||||
# to query hidden dimension, so the original tensors need to be
|
|
||||||
# expanded
|
|
||||||
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
|
|
||||||
# and expansion of cos/sin tensors via concatenation
|
|
||||||
# GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
|
|
||||||
# and expansion of cos/sin tensors via repeat_interleave
|
|
||||||
rope_mode: RotaryPosEmbeddingMode
|
|
||||||
if self.is_neox_style:
|
|
||||||
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
|
||||||
cos = torch.cat((cos, cos), dim=-1)
|
|
||||||
sin = torch.cat((sin, sin), dim=-1)
|
|
||||||
else:
|
|
||||||
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
|
||||||
sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
|
|
||||||
cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])
|
|
||||||
|
|
||||||
query_shape = query.shape
|
|
||||||
query = query.view(num_tokens, -1, self.head_size)
|
|
||||||
query_rot = query[..., : self.rotary_dim]
|
|
||||||
query_pass = query[..., self.rotary_dim :]
|
|
||||||
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
|
||||||
|
|
||||||
key_shape = key.shape
|
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
|
||||||
key_rot = key[..., : self.rotary_dim]
|
|
||||||
key_pass = key[..., self.rotary_dim :]
|
|
||||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
|
||||||
return query, key
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
||||||
s += f", max_position_embeddings={self.max_position_embeddings}"
|
s += f", max_position_embeddings={self.max_position_embeddings}"
|
||||||
|
|||||||
Reference in New Issue
Block a user