use default for torch.ops (#4835)
This commit is contained in:
@@ -14,14 +14,14 @@ def rmsnorm(
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
def fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
|
||||
|
||||
|
||||
def gemma_rmsnorm(
|
||||
@@ -32,14 +32,16 @@ def gemma_rmsnorm(
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.gemma_rmsnorm.default(
|
||||
out, input, weight, eps, get_cuda_stream()
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def gemma_fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
|
||||
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
|
||||
input, residual, weight, eps, get_cuda_stream()
|
||||
)
|
||||
|
||||
@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
if cos_sin_cache.dtype != torch.float32:
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
|
||||
q=query.view(query.shape[0], -1, head_size),
|
||||
k=key.view(key.shape[0], -1, head_size),
|
||||
q_rope=query.view(query.shape[0], -1, head_size),
|
||||
|
||||
Reference in New Issue
Block a user