diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml index fe5c8fad1..03c1784f0 100644 --- a/.github/workflows/pr-test-npu.yml +++ b/.github/workflows/pr-test-npu.yml @@ -117,7 +117,7 @@ jobs: curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl - name: Run test - timeout-minutes: 60 + timeout-minutes: 120 env: SGLANG_USE_MODELSCOPE: true SGLANG_IS_IN_CI: true diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 7a393748b..495beb009 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -384,19 +384,83 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): dispatch_output: StandardDispatchOutput, ) -> CombineInput: - from sglang.srt.layers.moe.fused_moe_native import moe_forward_native + import torch_npu + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output + topk_weights, topk_ids, _ = dispatch_output.topk_output - output = moe_forward_native( - layer, - x, - topk_output, - self.moe_runner_config, + original_dtype = x.dtype + num_tokens = x.shape[0] + topk_weights = topk_weights.to(x.dtype) + topk_ids = topk_ids.to(torch.int32) + num_experts = layer.num_experts + top_k = layer.top_k + row_idx_len = num_tokens * top_k + row_idx = ( + torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device) + .view(top_k, -1) + .permute(1, 0) + .contiguous() ) - return StandardCombineInput(hidden_states=output) + + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + x, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens + ) + ) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts + ) + + expert_tokens = expert_tokens.to(torch.int64) + if layer.w13_weight.shape[-1] == layer.hidden_size: + w13 = layer.w13_weight.transpose(1, 2) + w2 = layer.w2_weight.transpose(1, 2) + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w13], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + output_dtype=original_dtype, + )[0] + + # act_fn: + if self.moe_runner_config.activation == "silu": + hidden_states = torch_npu.npu_swiglu(hidden_states) + else: + from sglang.srt.layers.activation import GeluAndMul + + hidden_states = GeluAndMul()(hidden_states) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + output_dtype=original_dtype, + )[0] + + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + + return StandardCombineInput(hidden_states=final_hidden_states) def forward_tpu(self, *args, **kwargs) -> CombineInput: raise NotImplementedError("The TPU backend currently does not support MoE.") diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index f8fad8bcb..e95247041 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -17,7 +17,11 @@ from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import is_layer_skipped -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import is_npu, set_weight_attrs + +_is_npu = is_npu() +if not _is_npu: + from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe if TYPE_CHECKING: from sglang.srt.layers.moe import MoeRunnerConfig diff --git a/test/srt/ascend/test_ascend_tp4_bf16.py b/test/srt/ascend/test_ascend_tp4_bf16.py new file mode 100644 index 000000000..bb7d90e4f --- /dev/null +++ b/test/srt/ascend/test_ascend_tp4_bf16.py @@ -0,0 +1,101 @@ +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + run_bench_offline_throughput, +) + +TEST_MODEL_MATRIX = { + "Qwen/Qwen3-30B-A3B-Instruct-2507": { + "accuracy": 0.90, + "latency": 180, + "output_throughput": 20, + }, +} + + +class TestAscendTp4Bf16(CustomTestCase): + + @classmethod + def setUpClass(cls): + cls.models = TEST_MODEL_MATRIX.keys() + cls.base_url = DEFAULT_URL_FOR_TEST + cls.url = urlparse(DEFAULT_URL_FOR_TEST) + cls.common_args = [ + "--trust-remote-code", + "--mem-fraction-static", + 0.7, + "--max-running-requests", + 32, + "--attention-backend", + "ascend", + "--cuda-graph-max-bs", + 32, + "--tp-size", + 4, + ] + + def test_a_gsm8k(self): + for model in self.models: + with self.subTest(model=model): + print(f"##=== Testing accuracy: {model} ===##") + + process = popen_launch_server( + model, + self.base_url, + timeout=1800, + other_args=[ + *self.common_args, + ], + ) + + try: + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=128, + host=f"http://{self.url.hostname}", + port=int(self.url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual( + metrics["accuracy"], + TEST_MODEL_MATRIX[model]["accuracy"], + ) + finally: + kill_process_tree(process.pid) + + def test_b_throughput(self): + for model in self.models: + with self.subTest(model=model): + print(f"##=== Testing throughput: {model} ===##") + + output_throughput = run_bench_offline_throughput( + model, + [ + *self.common_args, + ], + ) + + print(f"##=== {model} throughput: {output_throughput} ===##") + + if is_in_ci(): + self.assertGreater( + output_throughput, + TEST_MODEL_MATRIX[model]["output_throughput"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 10dd064d3..a918a6339 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -294,6 +294,7 @@ suite_ascend = { ], "per-commit-4-ascend-npu": [ TestFile("ascend/test_ascend_mla_w8a8int8.py", 400), + TestFile("ascend/test_ascend_tp4_bf16.py", 400), ], }