diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 63a6c9a12..56632092e 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -169,76 +169,6 @@ class RotaryEmbedding(CustomOp):
             )
         return query, key
 
-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
-        ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
-        return query, key
-
-    def forward_hpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from habana_frameworks.torch.hpex.kernels import (
-            RotaryPosEmbeddingMode,
-            apply_rotary_pos_emb,
-        )
-
-        positions = positions.flatten()
-        if offsets is not None:
-            positions = positions + offsets
-        num_tokens = positions.shape[0]
-        cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
-        # to query hidden dimension, so the original tensors need to be
-        # expanded
-        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
-        # and expansion of cos/sin tensors via concatenation
-        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
-        # and expansion of cos/sin tensors via repeat_interleave
-        rope_mode: RotaryPosEmbeddingMode
-        if self.is_neox_style:
-            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
-            cos = torch.cat((cos, cos), dim=-1)
-            sin = torch.cat((sin, sin), dim=-1)
-        else:
-            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
-            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
-            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])
-
-        query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        return query, key
-
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
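
For reference, not part of the patch: the comments in the removed forward_hpu explain why cos/sin must be expanded differently for the two RoPE layouts (concatenation for GPT-NeoX "blockwise", repeat_interleave for GPT-J "pairwise"). Below is a minimal PyTorch sketch of that distinction; the helper names (rotate_neox, rotate_gptj, apply_rope) are illustrative only and are not sglang or HPU APIs.

import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Blockwise (GPT-NeoX) layout: the rotated pair is (x[i], x[i + d/2]).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # Pairwise (GPT-J) layout: the rotated pair is (x[2i], x[2i + 1]).
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
    # cos/sin carry rotary_dim // 2 entries per position and must be widened
    # to rotary_dim so they broadcast against x, matching the removed HPU
    # path: cat for the blockwise layout, repeat_interleave for the pairwise.
    if is_neox_style:
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        return x * cos + rotate_neox(x) * sin
    cos = torch.repeat_interleave(cos, 2, dim=-1)
    sin = torch.repeat_interleave(sin, 2, dim=-1)
    return x * cos + rotate_gptj(x) * sin

Both layouts apply the same 2-D rotation per frequency; only the pairing of elements within the head dimension differs, which is why the cos/sin expansion must match the layout.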