fix custom op version compatibility (#2988)
This commit is contained in:
@@ -27,7 +27,7 @@ runtime_common = [
 ]
 
 srt = [
     "sglang[runtime_common]", "cuda-python",
-    "sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+    "sgl-kernel>=0.0.2.post14", "torch", "vllm==0.6.4.post1",
     "flashinfer==0.1.6"
 ]
@@ -8,6 +8,8 @@ import torch
 import torch.nn as nn
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.custom_op_util import register_custom_op
+
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -51,7 +53,7 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-@CustomOp.register("rotary_embedding")
+@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
Reference in New Issue
Block a user