From 9c3e95d98be95ce08d157d2331c091e28f24120b Mon Sep 17 00:00:00 2001 From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:32:51 -0700 Subject: [PATCH] [AMD] Expand test coverage for AMD CI and enable apply_token_bitmask_inplace_cuda in sgl-kernel (#8268) --- .github/workflows/pr-test-amd.yml | 1 + .../srt/constrained/xgrammar_backend.py | 16 ++++++++--- sgl-kernel/csrc/common_extension_rocm.cc | 6 ++++ .../apply_token_bitmask_inplace_cuda.cu | 15 +++++++++- sgl-kernel/setup_rocm.py | 1 + test/srt/run_suite.py | 28 ++++++++++++++++++- 6 files changed, 61 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index 9756356bb..cb08ec534 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -322,6 +322,7 @@ jobs: docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_align.py docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_softmax.py docker exec -w /sglang-checkout/sgl-kernel/tests/speculative ci_sglang python3 -m pytest test_eagle_utils.py + docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_apply_token_bitmask_inplace.py pr-test-amd-finish: if: always() diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 92e171662..6118aa22b 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -32,10 +32,15 @@ from sglang.srt.constrained.base_grammar_backend import ( BaseGrammarBackend, BaseGrammarObject, ) -from sglang.srt.constrained.triton_ops.bitmask_ops import ( - apply_token_bitmask_inplace_triton, -) +from sglang.srt.utils import is_hip +_is_hip = is_hip() +if _is_hip: + from sgl_kernel import apply_token_bitmask_inplace_cuda +else: + from sglang.srt.constrained.triton_ops.bitmask_ops import ( + apply_token_bitmask_inplace_triton, + ) logger = logging.getLogger(__name__) @@ -94,7 +99,10 @@ class XGrammarGrammar(BaseGrammarObject): def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: if logits.device.type == "cuda": - apply_token_bitmask_inplace_triton(logits, vocab_mask) + if _is_hip: + apply_token_bitmask_inplace_cuda(logits, vocab_mask) + else: + apply_token_bitmask_inplace_triton(logits, vocab_mask) elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu: self.apply_vocab_mask_cpu(logits, vocab_mask) else: diff --git a/sgl-kernel/csrc/common_extension_rocm.cc b/sgl-kernel/csrc/common_extension_rocm.cc index aaf474fb2..a97f17336 100644 --- a/sgl-kernel/csrc/common_extension_rocm.cc +++ b/sgl-kernel/csrc/common_extension_rocm.cc @@ -114,6 +114,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> " "()"); m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); + + /* + * From XGrammar + */ + m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()"); + m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu b/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu index a5d954e7f..84b678551 100644 --- a/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu +++ b/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu @@ -25,19 +25,24 @@ #include #include -#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 + +#if !defined(USE_ROCM) && (!defined(CUDA_VERSION) || CUDA_VERSION < 12040) void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional indices = at::nullopt) { TORCH_CHECK(false, "CUDA version must be >= 12.4 for ApplyTokenBitmaskInplace"); } #else #ifndef CUDART_INF_FP16 +#ifndef USE_ROCM #define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U) #endif +#endif #ifndef CUDART_INF_BF16 +#ifndef USE_ROCM #define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) #endif +#endif constexpr int32_t BITS_PER_BLOCK = 32; constexpr int32_t THREADS_PER_THREAD_BLOCK = 256; @@ -49,12 +54,20 @@ __device__ T NegativeInfinity() { template <> __device__ __half NegativeInfinity<__half>() { +#ifdef USE_ROCM + return __float2half(-INFINITY); +#else return -CUDART_INF_FP16; +#endif } template <> __device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() { +#ifdef USE_ROCM + return __nv_bfloat16(-INFINITY); +#else return -CUDART_INF_BF16; +#endif } template diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index a919d8f3b..ac61e4df9 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -48,6 +48,7 @@ sources = [ "csrc/moe/moe_topk_softmax_kernels.cu", "csrc/speculative/eagle_utils.cu", "csrc/common_extension_rocm.cc", + "csrc/grammar/apply_token_bitmask_inplace_cuda.cu", ] cxx_flags = ["-O3"] diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 276d2866d..bea31af00 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -158,40 +158,66 @@ suites = { # Add AMD tests suite_amd = { "per-commit-amd": [ + TestFile("lora/test_lora.py", 200), + TestFile("lora/test_lora_eviction.py", 200), TestFile("lora/test_lora_backend.py", 99), TestFile("lora/test_multi_lora_backend.py", 60), TestFile("lora/test_lora_cuda_graph.py", 250), + TestFile("lora/test_lora_qwen3.py", 97), + TestFile("models/test_embedding_models.py", 73), + TestFile("models/test_compressed_tensors_models.py", 42), TestFile("models/test_qwen_models.py", 82), TestFile("models/test_reward_models.py", 132), + TestFile("models/test_transformers_models.py", 320), + TestFile("openai_server/basic/test_protocol.py", 10), + TestFile("openai_server/basic/test_serving_chat.py", 10), + TestFile("openai_server/basic/test_serving_completions.py", 10), + TestFile("openai_server/basic/test_serving_embedding.py", 10), TestFile("openai_server/basic/test_openai_embedding.py", 141), + TestFile("openai_server/basic/test_openai_server.py", 149), TestFile("openai_server/features/test_enable_thinking.py", 70), + TestFile("openai_server/features/test_json_constrained.py", 98), + TestFile("openai_server/features/test_json_mode.py", 90), + TestFile("openai_server/features/test_openai_server_ebnf.py", 95), + # TestFile("openai_server/features/test_openai_server_hidden_states.py", 240), TestFile("openai_server/features/test_reasoning_content.py", 89), + TestFile("openai_server/function_call/test_openai_function_calling.py", 60), + TestFile("openai_server/function_call/test_tool_choice.py", 226), TestFile("openai_server/validation/test_large_max_new_tokens.py", 41), + TestFile("openai_server/validation/test_matched_stop.py", 60), + TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85), TestFile("openai_server/validation/test_request_length_validation.py", 31), TestFile("quant/test_block_int8.py", 22), TestFile("quant/test_awq_dequant.py", 2), TestFile("rl/test_update_weights_from_disk.py", 114), + # TestFile("rl/test_update_weights_from_tensor.py", 48), TestFile("test_abort.py", 51), TestFile("test_create_kvindices.py", 2), TestFile("test_chunked_prefill.py", 313), + TestFile("test_ebnf_constrained.py", 108), TestFile("test_eval_fp8_accuracy.py", 303), TestFile("test_function_call_parser.py", 10), TestFile("test_fused_moe.py", 30), TestFile("test_input_embeddings.py", 38), + TestFile("test_io_struct.py", 8), + TestFile("test_jinja_template_utils.py", 1), + TestFile("test_metrics.py", 32), TestFile("test_mla.py", 242), TestFile("test_mla_deepseek_v3.py", 221), - TestFile("test_metrics.py", 32), TestFile("test_no_chunked_prefill.py", 108), # TestFile("test_no_overlap_scheduler.py", 234), # Disabled temporarily and track in #7703 TestFile("test_penalty.py", 41), TestFile("test_page_size.py", 60), TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_radix_attention.py", 105), + TestFile("test_regex_constrained.py", 64), TestFile("test_retract_decode.py", 54), TestFile("test_reasoning_parser.py", 5), TestFile("test_rope_rocm.py", 3), TestFile("test_server_args.py", 1), TestFile("test_skip_tokenizer_init.py", 117), + TestFile("test_srt_engine.py", 261), + TestFile("test_srt_endpoint.py", 130), TestFile("test_torch_compile.py", 76), TestFile("test_torch_compile_moe.py", 172), TestFile("test_torch_native_attention_backend.py", 123),