use default for torch.ops (#4835)

Author: Yineng Zhang
Date: 2025-03-27 19:09:58 -07:00
Committed by: GitHub
Parent: 10a9ab7b07
Commit: 31dfff7da7
7 changed files with 51 additions and 47 deletions
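The change is mechanical: every wrapper that called an sgl_kernel op through its OpOverloadPacket (torch.ops.sgl_kernel.<op>(...)) now calls the explicit default overload (torch.ops.sgl_kernel.<op>.default(...)), which binds the registered overload directly instead of resolving it on each call. Below is a minimal sketch of the two call forms, using a hypothetical torch.library.custom_op named demo::scale as a stand-in, since sgl_kernel's real ops are registered by its compiled extension:

import torch

# Hypothetical stand-in op (PyTorch 2.4+); sgl_kernel's actual ops are
# registered from C++, not with torch.library.custom_op.
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

x = torch.randn(4)

# OpOverloadPacket call: overload resolution happens on every invocation.
y1 = torch.ops.demo.scale(x, 2.0)

# OpOverload call: binds the .default overload directly, as this commit
# does for the sgl_kernel Python wrappers.
y2 = torch.ops.demo.scale.default(x, 2.0)

assert torch.equal(y1, y2)

Both forms dispatch to the same kernel; the .default form simply skips the packet's per-call overload lookup.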

@@ -14,14 +14,14 @@ def rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
     return out


 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)


 def gemma_rmsnorm(
@@ -32,14 +32,16 @@ def gemma_rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(
+        out, input, weight, eps, get_cuda_stream()
+    )
     return out


 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
+    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, get_cuda_stream()
     )
@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
-    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
+    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),