keep rotary_embedding only (#2997)
This commit is contained in:
@@ -144,20 +144,6 @@ class RotaryEmbedding(CustomOp):
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
|
||||||
# are in-place operations that update the query and key tensors.
|
|
||||||
if offsets is not None:
|
|
||||||
ops.batched_rotary_embedding(
|
|
||||||
positions,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
self.head_size,
|
|
||||||
self.cos_sin_cache,
|
|
||||||
self.is_neox_style,
|
|
||||||
self.rotary_dim,
|
|
||||||
offsets,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ops.rotary_embedding(
|
ops.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
@@ -178,20 +164,6 @@ class RotaryEmbedding(CustomOp):
|
|||||||
from vllm._ipex_ops import ipex_ops as ops
|
from vllm._ipex_ops import ipex_ops as ops
|
||||||
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
|
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
|
||||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
|
||||||
# are in-place operations that update the query and key tensors.
|
|
||||||
if offsets is not None:
|
|
||||||
ops.batched_rotary_embedding(
|
|
||||||
positions,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
self.head_size,
|
|
||||||
self.cos_sin_cache,
|
|
||||||
self.is_neox_style,
|
|
||||||
self.rotary_dim,
|
|
||||||
offsets,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ops.rotary_embedding(
|
ops.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
|
|||||||
Reference in New Issue
Block a user