From 44a966977083f3a7d7cc2a268f46a63e76d049a8 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 20 Jan 2025 13:21:36 +0800 Subject: [PATCH] keep rotary_embedding only (#2997) --- python/sglang/srt/layers/rotary_embedding.py | 60 ++++++-------------- 1 file changed, 16 insertions(+), 44 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 964152905..43478f39d 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -144,28 +144,14 @@ class RotaryEmbedding(CustomOp): from vllm import _custom_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, - offsets, - ) - else: - ops.rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - ) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def forward_xpu( @@ -178,28 +164,14 @@ class RotaryEmbedding(CustomOp): from vllm._ipex_ops import ipex_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, - offsets, - ) - else: - ops.rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - ) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def forward_hpu(