From 31dfff7da7ade6703303a67bfe6ef52ead97640a Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 27 Mar 2025 19:09:58 -0700
Subject: [PATCH] use default for torch.ops (#4835)

---
 sgl-kernel/python/sgl_kernel/allreduce.py   | 30 ++++++++++-----------
 sgl-kernel/python/sgl_kernel/attention.py   |  2 +-
 sgl-kernel/python/sgl_kernel/elementwise.py | 18 +++++++------
 sgl-kernel/python/sgl_kernel/gemm.py        | 26 +++++++++---------
 sgl-kernel/python/sgl_kernel/moe.py         |  4 +--
 sgl-kernel/python/sgl_kernel/sampling.py    | 10 +++----
 sgl-kernel/python/sgl_kernel/speculative.py |  8 +++---
 7 files changed, 51 insertions(+), 47 deletions(-)

diff --git a/sgl-kernel/python/sgl_kernel/allreduce.py b/sgl-kernel/python/sgl_kernel/allreduce.py
index 0924e7f35..3a6e58d4f 100644
--- a/sgl-kernel/python/sgl_kernel/allreduce.py
+++ b/sgl-kernel/python/sgl_kernel/allreduce.py
@@ -12,49 +12,49 @@ if torch.version.hip is not None:
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )
 
     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)
 
     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
+        torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)
 
     def dispose(fa: int) -> None:
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)
 
     def meta_size() -> int:
-        return torch.ops.sgl_kernel.meta_size()
+        return torch.ops.sgl_kernel.meta_size.default()
 
     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
+        return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
 
     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return torch.ops.sgl_kernel.allocate_meta_buffer(size)
+        return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)
 
     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
+        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
 
 else:
     # TRTLLM custom allreduce
     def init_custom_reduce(
         rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
     ):
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             rank_id,
             num_devices,
             rank_data,
@@ -65,13 +65,13 @@ else:
         )
 
     def custom_dispose(fa):
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)
 
     def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernel.all_reduce(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)
 
     def get_graph_buffer_ipc_meta(fa):
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
 
     def register_graph_buffers(fa, handles, offsets):
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
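Why `.default`: `torch.ops.sgl_kernel.foo` returns an `OpOverloadPacket`, which picks an overload on every call, whereas `foo.default` is the concrete `OpOverload` bound to the op's (only) schema; calling the overload directly skips packet-level resolution and gives tracing tools an unambiguous target. A minimal self-contained sketch of the two call forms; the `demo` namespace and `add_one` op are hypothetical, defined here only so the snippet runs on its own:

import torch

# Register a toy custom op so both call forms can be demonstrated.
lib = torch.library.Library("demo", "DEF")
lib.define("add_one(Tensor x) -> Tensor")
lib.impl("add_one", lambda x: x + 1, "CPU")

x = torch.zeros(3)
packet = torch.ops.demo.add_one             # OpOverloadPacket: resolves an overload per call
overload = torch.ops.demo.add_one.default   # OpOverload: bound to one schema up front
assert torch.equal(packet(x), overload(x))  # identical results either way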
diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py
index 53fec4dd1..6ad1d347e 100644
--- a/sgl-kernel/python/sgl_kernel/attention.py
+++ b/sgl-kernel/python/sgl_kernel/attention.py
@@ -2,6 +2,6 @@ import torch
 
 
 def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
-    torch.ops.sgl_kernel.lightning_attention_decode(
+    torch.ops.sgl_kernel.lightning_attention_decode.default(
         q, k, v, past_kv, slope, output, new_kv
     )

diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py
index 9e0b11c2f..3db157156 100644
--- a/sgl-kernel/python/sgl_kernel/elementwise.py
+++ b/sgl-kernel/python/sgl_kernel/elementwise.py
@@ -14,14 +14,14 @@ def rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
     return out
 
 
 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
 
 
 def gemma_rmsnorm(
@@ -32,14 +32,16 @@ def gemma_rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(
+        out, input, weight, eps, get_cuda_stream()
+    )
     return out
 
 
 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
+    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, get_cuda_stream()
     )
 
@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
     return out
 
 
@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
     return out
 
 
@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
     return out
 
 
@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
 
-    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
+    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),
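The elementwise wrappers keep their Python signatures; only the dispatch target changes, so callers are unaffected. A usage sketch for `rmsnorm` as defined above, assuming the wrapper signature is `rmsnorm(input, weight, eps=1e-6, out=None)` (consistent with the body shown), that sgl-kernel is installed and re-exports the wrapper, and that a CUDA device is available:

import torch
from sgl_kernel import rmsnorm

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)

y = rmsnorm(x, w, eps=1e-6)  # `out` is allocated internally via torch.empty_like
rmsnorm(x, w, out=y)         # or write into a preallocated buffer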
diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index f63360722..b1ef34596 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -7,11 +7,11 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
 def awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.ByteTensor:
-    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
+    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
 
 
 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.int8_scaled_mm(
+    return torch.ops.sgl_kernel.int8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -22,7 +22,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
 
 
 def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
-    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -32,7 +32,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
 
 
 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.fp8_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -51,7 +51,7 @@ def _bmm_fp8_internal(
     B_scale: torch.Tensor,
 ) -> None:
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.bmm_fp8(
+    torch.ops.sgl_kernel.bmm_fp8.default(
         A,
         B,
         D,
@@ -91,7 +91,7 @@ def sgl_per_token_group_quant_fp8(
     fp8_min: float,
     fp8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
         input, output_q, output_s, group_size, eps, fp8_min, fp8_max
     )
 
@@ -105,7 +105,7 @@ def sgl_per_token_group_quant_int8(
     int8_min: float,
     int8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
         input, output_q, output_s, group_size, eps, int8_min, int8_max
     )
 
@@ -116,7 +116,9 @@ def sgl_per_tensor_quant_fp8(
     output_s: torch.Tensor,
     is_static: bool,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
+        input, output_q, output_s, is_static
+    )
 
 
 def cublas_grouped_gemm(
@@ -129,7 +131,7 @@ def cublas_grouped_gemm(
         len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
     ), "Inputs/weights/outputs should not be empty!"
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm(
+    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
         inputs,
         weights,
         outputs,
@@ -144,7 +146,7 @@ def sgl_per_token_quant_fp8(
     output_q: torch.Tensor,
     output_s: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
 
 
 def cutlass_scaled_fp4_mm(
@@ -158,7 +160,7 @@ def cutlass_scaled_fp4_mm(
     assert a.ndim == 2 and b.ndim == 2
     m, n = a.shape[0], b.shape[0]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
+    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
         out, a, b, block_scale_a, block_scale_b, alpha
     )
     return out
@@ -210,7 +212,7 @@ def scaled_fp4_quant(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
    )
 
-    torch.ops.sgl_kernels.scaled_fp4_quant(
+    torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
     )
     output_scale = output_scale.view(torch.float8_e4m3fn)
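Beyond the mechanical `.default` rewrite, the last two hunks above also correct the namespace from `torch.ops.sgl_kernels` (plural) to `torch.ops.sgl_kernel`, matching every other call site in the package; with the old spelling the lookup would fail at call time if no op is registered under that name. The mismatch can hide because `torch.ops` resolves namespaces and ops lazily, as this sketch shows:

import torch

# Namespace access succeeds eagerly; the failure only surfaces when the
# (unregistered) op is looked up inside it.
ns = torch.ops.sgl_kernels  # no error yet
try:
    ns.cutlass_scaled_fp4_mm  # lookup fails: nothing registered under "sgl_kernels"
except (AttributeError, RuntimeError):
    print("op lookup failed: no such operator is registered")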
diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py
index 02ce026e8..84c79fdb7 100644
--- a/sgl-kernel/python/sgl_kernel/moe.py
+++ b/sgl-kernel/python/sgl_kernel/moe.py
@@ -11,7 +11,7 @@ def moe_align_block_size(
     token_cnts_buffer,
     cumsum_buffer,
 ):
-    torch.ops.sgl_kernel.moe_align_block_size(
+    torch.ops.sgl_kernel.moe_align_block_size.default(
         topk_ids,
         num_experts,
         block_size,
@@ -29,6 +29,6 @@ def topk_softmax(
     token_expert_indices: torch.Tensor,
     gating_output: float,
 ) -> None:
-    torch.ops.sgl_kernel.topk_softmax(
+    torch.ops.sgl_kernel.topk_softmax.default(
         topk_weights, topk_ids, token_expert_indices, gating_output
     )

diff --git a/sgl-kernel/python/sgl_kernel/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py
index 2f57f1313..7c94e9eda 100644
--- a/sgl-kernel/python/sgl_kernel/sampling.py
+++ b/sgl-kernel/python/sgl_kernel/sampling.py
@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
     probs = probs.float()
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_k_renorm_probs(
+    torch.ops.sgl_kernel.top_k_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_k_arr,
@@ -40,7 +40,7 @@ def _top_p_renorm_probs_internal(
     probs = probs.float()
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_p_renorm_probs(
+    torch.ops.sgl_kernel.top_p_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_p_arr,
@@ -75,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
@@ -121,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
@@ -179,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
         maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-    torch.ops.sgl_kernel.min_p_sampling_from_probs(
+    torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
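For readers unfamiliar with the sampling ops being rebound here, the following pure-PyTorch reference shows what top-p (nucleus) renormalization computes. The semantics are inferred from the wrapper names; the CUDA kernel is the optimized path and may differ in edge cases such as ties:

import torch

def top_p_renorm_probs_reference(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # Keep the smallest set of highest-probability tokens whose cumulative
    # mass reaches top_p, zero the rest, and renormalize to sum to one.
    sorted_probs, idx = probs.sort(dim=-1, descending=True)
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    kept = torch.where(mass_before < top_p, sorted_probs, torch.zeros_like(sorted_probs))
    out = torch.zeros_like(probs).scatter_(-1, idx, kept)
    return out / out.sum(dim=-1, keepdim=True)

probs = torch.softmax(torch.randn(2, 32000), dim=-1)
renormed = top_p_renorm_probs_reference(probs, top_p=0.9)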
diff --git a/sgl-kernel/python/sgl_kernel/speculative.py b/sgl-kernel/python/sgl_kernel/speculative.py
index ebec2a5a9..6eee58394 100644
--- a/sgl-kernel/python/sgl_kernel/speculative.py
+++ b/sgl-kernel/python/sgl_kernel/speculative.py
@@ -17,7 +17,7 @@ def tree_speculative_sampling_target_only(
     threshold_acc: float = 1.0,
     deterministic: bool = True,
 ) -> None:
-    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
+    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
         predicts,
         accept_index,
         accept_token_num,
@@ -45,7 +45,7 @@ def verify_tree_greedy(
     retrive_next_sibling: torch.Tensor,
     target_predict: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.verify_tree_greedy(
+    torch.ops.sgl_kernel.verify_tree_greedy.default(
         predicts,
         accept_index,
         accept_token_num,
@@ -71,7 +71,7 @@ def build_tree_kernel_efficient(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernel.build_tree_kernel_efficient(
+    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
         parent_list,
         selected_index,
         verified_seq_len,
@@ -92,7 +92,7 @@ def segment_packbits(
     output_indptr: torch.Tensor,
     y: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.segment_packbits(
+    torch.ops.sgl_kernel.segment_packbits.default(
         x,
         input_indptr,
         output_indptr,
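A closing note on the likely motivation (the commit message says only "use default for torch.ops", so this is an inference): binding `.default` gives each wrapper a direct handle to a single overload instead of routing through packet resolution on every call, and that handle can even be cached once at import time. A hypothetical sketch of that style, with a made-up `demo2::scale` op so the snippet is self-contained:

import torch

# Register a toy op; real wrappers would bind an already-registered kernel.
lib = torch.library.Library("demo2", "DEF")
lib.define("scale(Tensor x, float s) -> Tensor")
lib.impl("scale", lambda x, s: x * s, "CPU")

_scale = torch.ops.demo2.scale.default  # resolved once, at import time

def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return _scale(x, s)  # direct OpOverload call; no per-call resolution

print(scale(torch.ones(2), 3.0))  # tensor([3., 3.])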