[kernel] Use sgl_kernel rope (#3169)

Co-authored-by: zhyncs <me@zhyncs.com>
Author: Byron Hsu
Date: 2025-01-27 22:33:11 -08:00
Committed by: GitHub
Parent: 81262c7b72
Commit: 988d0a4bfc
2 changed files with 45 additions and 16 deletions


@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.utils import is_cuda_available
+
+_is_cuda_available = is_cuda_available()
+
+if _is_cuda_available:
+    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
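
The guarded import above is the whole dispatch mechanism: when the sgl_kernel wheel and a CUDA device are present, the fused rope kernel is imported once at module load. For context, a helper like is_cuda_available is typically a thin wrapper over torch's device query; the sketch below is an assumption for illustration, not the actual sglang.srt.utils implementation.

    # Hypothetical sketch of sglang.srt.utils.is_cuda_available (assumption,
    # not the verified implementation).
    import torch

    def is_cuda_available() -> bool:
        # True only for a CUDA build of torch with at least one visible GPU.
        return torch.version.cuda is not None and torch.cuda.is_available()
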
@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(dtype)
+        # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
+        if not _is_cuda_available:
+            cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
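
The NOTE explains the changed cast: the sgl_kernel path reads the cos/sin cache in FP32 directly, so the cast to the model dtype now happens only on the fallback path. The standalone snippet below (not code from this commit) illustrates how much rotation-angle precision a bf16 cache would give up.

    # Standalone illustration of the FP32-cache rationale (not commit code).
    import torch

    inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
    positions = torch.arange(32768, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)           # [num_positions, rotary_dim // 2]
    cache_fp32 = torch.cat([angles.cos(), angles.sin()], dim=-1)
    cache_bf16 = cache_fp32.to(torch.bfloat16).float()  # round-trip through bf16

    # bf16 keeps ~8 mantissa bits, so the worst-case elementwise error on
    # values in [-1, 1] is on the order of 2e-3.
    print((cache_fp32 - cache_bf16).abs().max())
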
@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from vllm import _custom_ops as ops
-
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-        ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
+        if _is_cuda_available:
+            apply_rope_with_cos_sin_cache_inplace(
+                positions=positions,
+                query=query,
+                key=key,
+                head_size=self.head_size,
+                cos_sin_cache=self.cos_sin_cache,
+                is_neox=self.is_neox_style,
+            )
+        else:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+            ops.rotary_embedding(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
         return query, key
 
     def forward_xpu(
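
For intuition about what the fused call replaces, here is a plain-PyTorch sketch of a neox-style rotation driven by the same cos/sin cache layout (first half cos, second half sin). The function name and exact cache layout are assumptions for illustration; apply_rope_with_cos_sin_cache_inplace performs the equivalent update on query and key in place, without the per-call dtype cast and temporaries of the fallback path.

    # Reference sketch (assumption, not sgl_kernel source) of a neox-style
    # rotary update using a [max_positions, head_size] cos/sin cache.
    import torch

    def rope_neox_reference(
        positions: torch.Tensor,      # [num_tokens]
        query: torch.Tensor,          # [num_tokens, num_heads * head_size]
        cos_sin_cache: torch.Tensor,  # [max_positions, head_size]
        head_size: int,
    ) -> torch.Tensor:
        cos, sin = cos_sin_cache[positions].float().chunk(2, dim=-1)
        q = query.view(query.shape[0], -1, head_size).float()
        # Neox style rotates the two halves of each head, not interleaved pairs.
        q1, q2 = q.chunk(2, dim=-1)
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
        out = torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)
        return out.view(query.shape).to(query.dtype)

Applying the same function to key with the same positions completes the update that the single fused kernel call performs.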