keep rotary_embedding only (#2997)

This commit is contained in:
Yineng Zhang
2025-01-20 13:21:36 +08:00
committed by GitHub
parent 1a820e38a2
commit 44a9669770

View File

@@ -144,28 +144,14 @@ class RotaryEmbedding(CustomOp):
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding() ops.rotary_embedding(
# are in-place operations that update the query and key tensors. positions,
if offsets is not None: query,
ops.batched_rotary_embedding( key,
positions, self.head_size,
query, self.cos_sin_cache,
key, self.is_neox_style,
self.head_size, )
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim,
offsets,
)
else:
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key return query, key
def forward_xpu( def forward_xpu(
@@ -178,28 +164,14 @@ class RotaryEmbedding(CustomOp):
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding() ops.rotary_embedding(
# are in-place operations that update the query and key tensors. positions,
if offsets is not None: query,
ops.batched_rotary_embedding( key,
positions, self.head_size,
query, self.cos_sin_cache,
key, self.is_neox_style,
self.head_size, )
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim,
offsets,
)
else:
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key return query, key
def forward_hpu( def forward_hpu(