From 534f32d27c0cc48731fbdae3701fbb6c3bb4332a Mon Sep 17 00:00:00 2001
From: drslark <96540755+drslark@users.noreply.github.com>
Date: Tue, 21 Oct 2025 20:20:57 +0800
Subject: [PATCH] [BugFix][main] Fixed a triton kernel bug of
 layer_norm_fwd_kernel for Qwen3-next (#3549)

### What this PR does / why we need it?
Fixes the triton kernel **layer_norm_fwd_kernel**, as described in
https://github.com/vllm-project/vllm-ascend/issues/3548

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
The environment is the same as in the issue
https://github.com/vllm-project/vllm-ascend/issues/3548. Start a vllm server with:

```shell
vllm serve /home/model/Qwen3-Next-80B-A3B-Instruct --port 22 --host 0.0.0.0 --served-model-name qwen3_next_mtp_0 --tensor-parallel-size 4 --max-model-len 32000 --gpu-memory-utilization 0.7 --enforce-eager
```

Then, start an aisbench client like:

```shell
ais_bench --models vllm_api_general_chat --datasets ceval_gen_0_shot_cot_chat_prompt --dump-eval-details
```

The client config is:

```python
# a big batch_size and a large max_out_len
dict(
    abbr='vllm-api-general-chat',
    attr='service',
    batch_size=512,
    generation_kwargs=dict(temperature=0.7, top_k=20, top_p=0.8),
    host_ip='xxx.xxx.xxx.xxx',
    host_port=8881,
    max_out_len=30000,
    model='qwen3_next_mtp_0',
    path='',
    pred_postprocessor=dict(
        type=
        'ais_bench.benchmark.utils.model_postprocessors.extract_non_reasoning_content'
    ),
    request_rate=0,
    retry=2,
    trust_remote_code=False,
    type='ais_bench.benchmark.models.VLLMCustomAPIChat'),
```

**Results:**

```text
(APIServer pid=615544) INFO 10-21 01:44:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 71.4 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 49.6 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 86.1%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 59.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.2 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:45:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:45:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 62.4 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 90.8%, Prefix cache hit rate: 0.0%
```
We can see that even when we send a bunch of requests and the **KV cache usage reaches 100.0%**, we no longer get a **coreDim=xxx can't be greater than UINT16_MAX.** exception.

```text
(APIServer pid=615544) INFO 10-21 02:17:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.6 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 02:17:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.3 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 02:17:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.6 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 02:18:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.9 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 02:18:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 22.7 tokens/s, Running: 2 reqs, Waiting: 6 reqs, GPU KV cache usage: 81.9%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO: 141.61.39.105:48568 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=615544) INFO: 141.61.39.105:48580 - "POST /v1/chat/completions HTTP/1.1" 200 OK
```

And after a few minutes, these two requests finished.
```text
(APIServer pid=615544) INFO 10-21 03:18:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:18:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:18:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.3 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:18:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:19:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:19:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 41.2%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO: 141.61.39.105:48712 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=615544) INFO 10-21 03:19:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.8 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:19:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
```

Finally, all requests are done.
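For context on the fix itself: as the linked issue describes, the launch used to fail with **coreDim=xxx can't be greater than UINT16_MAX.** because the old code launched one Triton program per row with `grid = (M, ngroups)`, and `M` can exceed 65535 under heavy load. The patch below caps the first grid dimension at `MAX_CORES = 65535` and makes each program loop over several rows instead. A minimal sketch of the resulting row mapping (plain Python, illustrative only; `rows_for_program` is a hypothetical helper, not part of the patch):

```python
MAX_CORES = 65535  # largest first grid dimension accepted (UINT16_MAX)


def rows_for_program(row: int, M: int) -> list[int]:
    """Global rows handled by grid program `row`, mirroring the kernel's
    BLOCK_ROWS / n_iters / remain arithmetic."""
    block_rows = M if M < MAX_CORES else MAX_CORES
    n_iters = M // block_rows
    if row < M % block_rows:  # the first `remain` programs take one extra row
        n_iters += 1
    return [i * block_rows + row for i in range(n_iters)]
```

Every index in `[0, M)` is produced by exactly one program, so capping the grid changes the launch shape but not the computed result.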
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: drslark <96540755+drslark@users.noreply.github.com>
---
 vllm_ascend/ops/fla.py | 89 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 85 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/ops/fla.py b/vllm_ascend/ops/fla.py
index b200c67..7903900 100644
--- a/vllm_ascend/ops/fla.py
+++ b/vllm_ascend/ops/fla.py
@@ -8,9 +8,89 @@
 import torch
 import torch.nn.functional as F
-import triton
-from vllm.model_executor.layers.fla.ops.layernorm_guard import \
-    layer_norm_fwd_kernel
+from vllm.triton_utils import tl, triton
+
+MAX_CORES = 65535
+
+
+@triton.heuristics({
+    "HAS_BIAS": lambda args: args["B"] is not None,
+    "HAS_Z": lambda args: args["Z"] is not None,
+})
+@triton.jit
+def layer_norm_fwd_kernel(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    Z,  # pointer to the other branch
+    Mean,  # pointer to the mean
+    Rstd,  # pointer to the 1/std
+    stride_x_row,  # how much to increase the pointer when moving by 1 row
+    stride_y_row,
+    stride_z_row,
+    M,  # number of rows in X_base
+    N,  # number of columns in X_base
+    eps,  # epsilon to avoid division by zero
+    BLOCK_N: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_Z: tl.constexpr,
+    NORM_BEFORE_GATE: tl.constexpr,
+    IS_RMS_NORM: tl.constexpr,
+    N_CORES: tl.constexpr,
+):
+    # Map the program id to the row of X_base and Y_base it should compute.
+    row = tl.program_id(0)
+    group = tl.program_id(1)
+
+    BLOCK_ROWS = M if M < N_CORES else N_CORES
+    n_iters = M // BLOCK_ROWS
+    remain = M % BLOCK_ROWS
+    if row < remain:
+        n_iters = n_iters + 1
+
+    for i in tl.range(n_iters):
+        X_base = X + (i * BLOCK_ROWS *
+                      stride_x_row) + row * stride_x_row + group * N
+        Y_base = Y + (i * BLOCK_ROWS *
+                      stride_y_row) + row * stride_y_row + group * N
+        if HAS_Z:
+            Z_base = Z + (i * BLOCK_ROWS *
+                          stride_z_row) + row * stride_z_row + group * N
+        if not IS_RMS_NORM:
+            Mean_base = Mean + (i * BLOCK_ROWS) + group * M
+        Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
+        W_base = W + group * N
+        if HAS_BIAS:
+            B_base = B + group * N
+        # Compute mean and variance
+        cols = tl.arange(0, BLOCK_N)
+        x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32)
+        if HAS_Z and not NORM_BEFORE_GATE:
+            z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
+            x *= z * tl.sigmoid(z)
+        if not IS_RMS_NORM:
+            mean = tl.sum(x, axis=0) / N
+            tl.store(Mean_base + row, mean)
+            xbar = tl.where(cols < N, x - mean, 0.)
+            var = tl.sum(xbar * xbar, axis=0) / N
+        else:
+            xbar = tl.where(cols < N, x, 0.)
+            var = tl.sum(xbar * xbar, axis=0) / N
+        rstd = 1 / tl.sqrt(var + eps)
+        tl.store(Rstd_base + row, rstd)
+        # Normalize and apply linear transformation
+        mask = cols < N
+        w = tl.load(W_base + cols, mask=mask).to(tl.float32)
+        if HAS_BIAS:
+            b = tl.load(B_base + cols, mask=mask).to(tl.float32)
+        x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+        y = x_hat * w + b if HAS_BIAS else x_hat * w
+        if HAS_Z and NORM_BEFORE_GATE:
+            z = tl.load(Z_base + cols, mask=mask).to(tl.float32)
+            y *= z * tl.sigmoid(z)
+        # Write output
+        tl.store(Y_base + cols, y, mask=mask)


 def _layer_norm_fwd(
@@ -55,7 +135,7 @@ def _layer_norm_fwd(
         "This layer norm doesn't support feature dim >= 64KB.")
     # heuristics for number of warps
     num_warps = min(max(BLOCK_N // 256, 1), 8)
-    grid = (M, ngroups)
+    grid = (M if M < MAX_CORES else MAX_CORES, ngroups)
     with torch.npu.device(x.device.index):
         layer_norm_fwd_kernel[grid](
             x,
@@ -74,6 +154,7 @@ def _layer_norm_fwd(
             BLOCK_N=BLOCK_N,
             NORM_BEFORE_GATE=norm_before_gate,
             IS_RMS_NORM=is_rms_norm,
+            N_CORES=MAX_CORES,
             num_warps=num_warps,
         )
     return out, mean, rstd
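As a quick sanity check of that partitioning (a standalone Python snippet written for this description, not part of the patch), we can replay the kernel's `BLOCK_ROWS` / `n_iters` / `remain` arithmetic on the host and confirm that a capped grid still visits every row exactly once:

```python
def check_row_coverage(M: int, n_cores: int = 65535) -> None:
    # Host-side replay of the kernel's row partitioning.
    block_rows = M if M < n_cores else n_cores  # == first grid dimension
    seen = set()
    for row in range(block_rows):  # one iteration per launched program
        n_iters = M // block_rows
        if row < M % block_rows:
            n_iters += 1
        for i in range(n_iters):
            global_row = i * block_rows + row
            assert global_row not in seen  # no row computed twice
            seen.add(global_row)
    assert seen == set(range(M))  # no row skipped


check_row_coverage(70000)  # M > UINT16_MAX, the case that used to crash
```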