keep rotary_embedding only (#2997)

This commit is contained in:
Yineng Zhang
2025-01-20 13:21:36 +08:00
committed by GitHub
parent 1a820e38a2
commit 44a9669770

View File

@@ -144,28 +144,14 @@ class RotaryEmbedding(CustomOp):
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim,
offsets,
)
else:
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def forward_xpu(
@@ -178,28 +164,14 @@ class RotaryEmbedding(CustomOp):
from vllm._ipex_ops import ipex_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim,
offsets,
)
else:
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def forward_hpu(