Rename files in sgl kernel to avoid nested folder structure (#4213)
Co-authored-by: zhyncs <me@zhyncs.com>
sgl-kernel/python/sgl_kernel/elementwise.py (new file, 152 lines added)
@@ -0,0 +1,152 @@
from typing import Optional

import torch
from sgl_kernel.utils import get_cuda_stream


# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(input)
    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
    return out


def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)


def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(input)
    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
    return out


def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
        input, residual, weight, eps, get_cuda_stream()
    )


def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
    assert (
        input.shape[:-1] == output.shape[:-1]
    ), f"{input.shape[:-1]} != {output.shape[:-1]}"
    assert (
        input.shape[-1] == 2 * output.shape[-1]
    ), f"{input.shape[-1]} != {2 * output.shape[-1]}"


def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
    return out


def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
    return out


def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
    return out


def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is applied in place to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    positions = positions.int()
    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
        q=query.view(query.shape[0], -1, head_size),
        k=key.view(key.shape[0], -1, head_size),
        q_rope=query.view(query.shape[0], -1, head_size),
        k_rope=key.view(key.shape[0], -1, head_size),
        cos_sin_cache=cos_sin_cache,
        pos_ids=positions,
        interleave=(not is_neox),
        cuda_stream=get_cuda_stream(),
    )
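A minimal usage sketch for the norm and activation wrappers above. This is not part of the commit; it assumes a CUDA-enabled build of sgl_kernel with these kernels compiled, and the import path simply mirrors the new file location (the functions may also be re-exported at the package root). Shapes and dtypes are illustrative only.

    import torch
    from sgl_kernel.elementwise import rmsnorm, silu_and_mul

    # Hypothetical shapes: 8 tokens, hidden size 4096, fp16 on GPU.
    hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    weight = torch.ones(4096, dtype=torch.float16, device="cuda")

    # rmsnorm allocates the output unless an `out` buffer is supplied.
    normed = rmsnorm(hidden, weight, eps=1e-6)

    # silu_and_mul expects the gate and up projections concatenated on the last
    # dimension, so the result has half the input's last dimension.
    gate_up = torch.randn(8, 2 * 11008, dtype=torch.float16, device="cuda")
    act = silu_and_mul(gate_up)  # -> shape (8, 11008)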
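Similarly, a hedged sketch of calling apply_rope_with_cos_sin_cache_inplace, again not part of the commit: the cache construction below is one common way to build a float32 cos/sin table with cosine in the first half and sine in the second half of rotary_dim, as the docstring requires; the base 10000.0 and all shapes are assumptions for illustration.

    import torch
    from sgl_kernel.elementwise import apply_rope_with_cos_sin_cache_inplace

    nnz, num_heads, head_size, max_seq_len = 8, 32, 128, 4096
    q = torch.randn(nnz, num_heads * head_size, dtype=torch.float16, device="cuda")
    k = torch.randn(nnz, num_heads * head_size, dtype=torch.float16, device="cuda")
    positions = torch.arange(nnz, device="cuda")

    # Build a (max_seq_len, head_size) float32 cache: [cos | sin] along the last dim.
    inv_freq = 1.0 / (
        10000.0
        ** (torch.arange(0, head_size, 2, device="cuda", dtype=torch.float32) / head_size)
    )
    t = torch.arange(max_seq_len, device="cuda", dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)

    # q and k are rotated in place; is_neox=True rotates half-dims rather than interleaved dims.
    apply_rope_with_cos_sin_cache_inplace(positions, q, k, head_size, cos_sin_cache, is_neox=True)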