diff --git a/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py b/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py index 7a07efd93..8bc7d1e01 100644 --- a/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py @@ -11,8 +11,8 @@ from vllm import _custom_ops as ops from sglang.srt.utils import is_hip -is_hip_ = is_hip() -fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn def vllm_scaled_fp8_quant( diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py index 5cee72ebb..b01e5ceb2 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -8,8 +8,8 @@ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_ from sglang.srt.utils import is_hip -is_hip_ = is_hip() -fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn @triton.jit diff --git a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py index ed0bfc78b..ef50957e2 100644 --- a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py @@ -9,8 +9,8 @@ from vllm import _custom_ops as ops from sglang.srt.utils import is_hip -is_hip_ = is_hip() -fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn def vllm_per_token_quant_fp8( diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index 6da15a384..5f337c9da 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -61,11 +61,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + /* + * From csrc/speculative + */ m.def( "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " diff --git a/sgl-kernel/tests/test_per_tensor_quant_fp8.py b/sgl-kernel/tests/test_per_tensor_quant_fp8.py index 620fa2dba..0840f298f 100644 --- a/sgl-kernel/tests/test_per_tensor_quant_fp8.py +++ b/sgl-kernel/tests/test_per_tensor_quant_fp8.py @@ -7,8 +7,8 @@ from sgl_kernel import sgl_per_tensor_quant_fp8 from sglang.srt.utils import is_hip -is_hip_ = is_hip() -fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn def sglang_scaled_fp8_quant( diff --git a/sgl-kernel/tests/test_per_token_quant_fp8.py b/sgl-kernel/tests/test_per_token_quant_fp8.py index 00a80ca01..80efd06e7 100644 --- a/sgl-kernel/tests/test_per_token_quant_fp8.py +++ b/sgl-kernel/tests/test_per_token_quant_fp8.py @@ -7,8 +7,8 @@ from sgl_kernel import sgl_per_token_quant_fp8 from sglang.srt.utils import is_hip -is_hip_ = is_hip() -fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn def torch_per_token_quant_fp8(tensor, inv_scale):