fix custom op version compatibility (#2988)
This commit is contained in:
@@ -8,6 +8,8 @@ import torch
 import torch.nn as nn
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
+

 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -51,7 +53,7 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)


-@CustomOp.register("rotary_embedding")
+@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
Reference in New Issue
Block a user