xc-llm-ascend/vllm_ascend/ops/fla.py

300 lines
10 KiB
Python

[New model] Qwen3-next support (#2917) ### What this PR does / why we need it? Add Qwen3-next support. ### Does this PR introduce _any_ user-facing change? Yes, users can use Qwen3 next. Related doc: https://github.com/vllm-project/vllm-ascend/pull/2916 the tutorial will be ready in [here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html) ### How was this patch tested? Doc CI passed Related: https://github.com/vllm-project/vllm-ascend/issues/2884 Co-Authored-By: Angazenn <supperccell@163.com> Co-Authored-By: zzzzwwjj <1183291235@qq.com> Co-Authored-By: MengqingCao <cmq0113@163.com> Co-Authored-By: linfeng-yuan <1102311262@qq.com> Co-Authored-By: hust17yixuan <303660421@qq.com> Co-Authored-By: SunnyLee219 <3294305115@qq.com> Co-Authored-By: maoxx241 <maoxx241@umn.edu> - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/b834b4cbf1d5094affdf231df2be86920610d83e --------- Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Angazenn <supperccell@163.com> Signed-off-by: Your Name <you@example.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: hust17yixuan <303660421@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: Angazenn <supperccell@163.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: hust17yixuan <303660421@qq.com>
2025-09-16 01:17:42 +08:00
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
# mypy: ignore-errors
import torch
import torch.nn.functional as F
[BugFix][mian] Fixed a triton kernel bug of layer_norm_fwd_kernel for Qwen3-next (#3549) ### What this PR does / why we need it? Fixes triton kernel **layer_norm_fwd_kernel**, descripted by https://github.com/vllm-project/vllm-ascend/issues/3548 ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? The environment is same with this issue, https://github.com/vllm-project/vllm-ascend/issues/3548. Starts a vllm server with: ```shell vllm serve /home/model/Qwen3-Next-80B-A3B-Instruct --port 22 --host 0.0.0.0 --served-model-name qwen3_next_mtp_0 --tensor-parallel-size 4 --max-model-len 32000 --gpu-memory-utilization 0.7 --enforce-eager ``` The, we start an aisbench clinet like: ```shell ais_bench --models vllm_api_general_chat --datasets ceval_gen_0_shot_cot_chat_prompt --dump-eval-details ``` Whose config is: ```python # a big batch_size and a large max_out_len dict( abbr='vllm-api-general-chat', attr='service', batch_size=512, generation_kwargs=dict(temperature=0.7, top_k=20, top_p=0.8), host_ip='xxx.xxx.xxx.xxx', host_port=8881, max_out_len=30000, model='qwen3_next_mtp_0', path='', pred_postprocessor=dict( type= 'ais_bench.benchmark.utils.model_postprocessors.extract_non_reasoning_content' ), request_rate=0, retry=2, trust_remote_code=False, type='ais_bench.benchmark.models.VLLMCustomAPIChat'), ``` **Results:** ```text ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 
tokens/s, Avg generation throughput: 71.4 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 49.6 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 86.1%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 59.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.2 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:45:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:45:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 62.4 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 90.8%, Prefix cache hit rate: 0.0% ``` We can see when we sent a bunch of requests and the **KV cache usage reaches 100.0%**. We won't get a **coreDim=xxx can't be greater than UINT16_MAX.** Exception. 
```text ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:17:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.6 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:17:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.3 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:17:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.6 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:18:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.9 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:18:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 22.7 tokens/s, Running: 2 reqs, Waiting: 6 reqs, GPU KV cache usage: 81.9%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO: 141.61.39.105:48568 - "POST /v1/chat/completions HTTP/1.1" 200 OK ^[[1;36m(APIServer pid=615544)^[[0;0m INFO: 141.61.39.105:48580 - "POST /v1/chat/completions HTTP/1.1" 200 OK ``` And after a few minutes, these two requests have been done. 
```text ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:18:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:18:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:18:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.3 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:18:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:19:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:19:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 41.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO: 141.61.39.105:48712 - "POST /v1/chat/completions HTTP/1.1" 200 OK ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:19:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.8 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:19:35 [loggers.py:127] Engine 000: Avg prompt 
throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0% ``` Finally, all requests are done. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: drslark <slarksblood@qq.com>
2025-10-21 20:20:57 +08:00
from vllm.triton_utils import tl, triton

MAX_CORES = 65535


@triton.heuristics({
    "HAS_BIAS": lambda args: args["B"] is not None,
    "HAS_Z": lambda args: args["Z"] is not None,
})
@triton.jit
def layer_norm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X_base
    N,  # number of columns in X_base
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    N_CORES: tl.constexpr,
):
    # Map the program id to the row of X_base and Y_base it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)

    BLOCK_ROWS = M if M < N_CORES else N_CORES
    n_iters = M // BLOCK_ROWS
    remain = M % BLOCK_ROWS
    if row < remain:
        n_iters = n_iters + 1

    for i in tl.range(n_iters):
        X_base = X + (i * BLOCK_ROWS *
                      stride_x_row) + row * stride_x_row + group * N
        Y_base = Y + (i * BLOCK_ROWS *
                      stride_y_row) + row * stride_y_row + group * N
        if HAS_Z:
            Z_base = Z + (i * BLOCK_ROWS *
                          stride_z_row) + row * stride_z_row + group * N
        if not IS_RMS_NORM:
            Mean_base = Mean + (i * BLOCK_ROWS) + group * M
        Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
        W_base = W + group * N
        if HAS_BIAS:
            B_base = B + group * N
        # Compute mean and variance
        cols = tl.arange(0, BLOCK_N)
        x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32)
        if HAS_Z and not NORM_BEFORE_GATE:
            z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
            x *= z * tl.sigmoid(z)
        if not IS_RMS_NORM:
            mean = tl.sum(x, axis=0) / N
            tl.store(Mean_base + row, mean)
            xbar = tl.where(cols < N, x - mean, 0.)
            var = tl.sum(xbar * xbar, axis=0) / N
        else:
            xbar = tl.where(cols < N, x, 0.)
            var = tl.sum(xbar * xbar, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        tl.store(Rstd_base + row, rstd)
        # Normalize and apply linear transformation
        mask = cols < N
        w = tl.load(W_base + cols, mask=mask).to(tl.float32)
        if HAS_BIAS:
            b = tl.load(B_base + cols, mask=mask).to(tl.float32)
        x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        y = x_hat * w + b if HAS_BIAS else x_hat * w
        if HAS_Z and NORM_BEFORE_GATE:
            z = tl.load(Z_base + cols, mask=mask).to(tl.float32)
            y *= z * tl.sigmoid(z)
        # Write output
        tl.store(Y_base + cols, y, mask=mask)
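The per-row arithmetic of `layer_norm_fwd_kernel` can be sketched in plain Python for checking results by hand. This is a minimal sketch under stated assumptions, not code from vllm-ascend: `silu` and `gated_norm_row` are hypothetical names, the inputs are ordinary Python lists holding one row of one group, and the float32 casts and column masking done by the kernel are left out.

```python
import math


def silu(v):
    # SiLU gate v * sigmoid(v); corresponds to `z * tl.sigmoid(z)` in the kernel.
    # Not protected against overflow for very negative v (sketch only).
    return v / (1.0 + math.exp(-v))


def gated_norm_row(x, w, b=None, z=None, eps=1e-6,
                   norm_before_gate=True, is_rms_norm=False):
    """Hypothetical reference for one row of one group (N = len(x))."""
    n = len(x)
    x = [float(v) for v in x]
    # Gate first when NORM_BEFORE_GATE is false.
    if z is not None and not norm_before_gate:
        x = [xv * silu(zv) for xv, zv in zip(x, z)]
    if is_rms_norm:
        mean = 0.0  # RMS norm does not subtract the mean
        var = sum(v * v for v in x) / n
    else:
        mean = sum(x) / n
        var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)
    # Normalize, then apply the linear transformation.
    y = [(v - mean) * rstd * wv for v, wv in zip(x, w)]
    if b is not None:
        y = [yv + bv for yv, bv in zip(y, b)]
    # Gate last when NORM_BEFORE_GATE is true.
    if z is not None and norm_before_gate:
        y = [yv * silu(zv) for yv, zv in zip(y, z)]
    return y, rstd


if __name__ == "__main__":
    y, rstd = gated_norm_row([3.0, 4.0], [1.0, 1.0], eps=0.0, is_rms_norm=True)
    print(y, rstd)
```

With unit weights and `eps=0`, the RMS-norm output has a mean square of 1, which gives a quick sanity check against the kernel output for a single row.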
def _layer_norm_fwd(
x,
weight,
bias,
eps,
z=None,
out=None,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
M, N = x.shape
if group_size is None:
group_size = N
assert N % group_size == 0
ngroups = N // group_size
assert x.stride(-1) == 1
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N, )
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N, )
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
if not is_rms_norm else None)
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
[BugFix][mian] Fixed a triton kernel bug of layer_norm_fwd_kernel for Qwen3-next (#3549) ### What this PR does / why we need it? Fixes triton kernel **layer_norm_fwd_kernel**, descripted by https://github.com/vllm-project/vllm-ascend/issues/3548 ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? The environment is same with this issue, https://github.com/vllm-project/vllm-ascend/issues/3548. Starts a vllm server with: ```shell vllm serve /home/model/Qwen3-Next-80B-A3B-Instruct --port 22 --host 0.0.0.0 --served-model-name qwen3_next_mtp_0 --tensor-parallel-size 4 --max-model-len 32000 --gpu-memory-utilization 0.7 --enforce-eager ``` The, we start an aisbench clinet like: ```shell ais_bench --models vllm_api_general_chat --datasets ceval_gen_0_shot_cot_chat_prompt --dump-eval-details ``` Whose config is: ```python # a big batch_size and a large max_out_len dict( abbr='vllm-api-general-chat', attr='service', batch_size=512, generation_kwargs=dict(temperature=0.7, top_k=20, top_p=0.8), host_ip='xxx.xxx.xxx.xxx', host_port=8881, max_out_len=30000, model='qwen3_next_mtp_0', path='', pred_postprocessor=dict( type= 'ais_bench.benchmark.utils.model_postprocessors.extract_non_reasoning_content' ), request_rate=0, retry=2, trust_remote_code=False, type='ais_bench.benchmark.models.VLLMCustomAPIChat'), ``` **Results:** ```text ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 
tokens/s, Avg generation throughput: 71.4 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 49.6 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 86.1%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 59.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.2 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:45:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:45:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 62.4 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 90.8%, Prefix cache hit rate: 0.0% ``` We can see when we sent a bunch of requests and the **KV cache usage reaches 100.0%**. We won't get a **coreDim=xxx can't be greater than UINT16_MAX.** Exception. 
```text ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:17:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.6 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:17:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.3 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:17:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.6 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:18:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.9 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 02:18:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 22.7 tokens/s, Running: 2 reqs, Waiting: 6 reqs, GPU KV cache usage: 81.9%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO: 141.61.39.105:48568 - "POST /v1/chat/completions HTTP/1.1" 200 OK ^[[1;36m(APIServer pid=615544)^[[0;0m INFO: 141.61.39.105:48580 - "POST /v1/chat/completions HTTP/1.1" 200 OK ``` And after a few minutes, these two requests have been done. 
```text ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:18:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:18:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:18:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.3 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:18:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:19:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:19:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 41.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO: 141.61.39.105:48712 - "POST /v1/chat/completions HTTP/1.1" 200 OK ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:19:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.8 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 03:19:35 [loggers.py:127] Engine 000: Avg prompt 
throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0% ``` Finally, all requests are done. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: drslark <slarksblood@qq.com>
2025-10-21 20:20:57 +08:00
grid = (M if M < MAX_CORES else MAX_CORES, ngroups)
[New model] Qwen3-next support (#2917) ### What this PR does / why we need it? Add Qwen3-next support. ### Does this PR introduce _any_ user-facing change? Yes, users can use Qwen3 next. Related doc: https://github.com/vllm-project/vllm-ascend/pull/2916 the tutorial will be ready in [here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html) ### How was this patch tested? Doc CI passed Related: https://github.com/vllm-project/vllm-ascend/issues/2884 Co-Authored-By: Angazenn <supperccell@163.com> Co-Authored-By: zzzzwwjj <1183291235@qq.com> Co-Authored-By: MengqingCao <cmq0113@163.com> Co-Authored-By: linfeng-yuan <1102311262@qq.com> Co-Authored-By: hust17yixuan <303660421@qq.com> Co-Authored-By: SunnyLee219 <3294305115@qq.com> Co-Authored-By: maoxx241 <maoxx241@umn.edu> - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/b834b4cbf1d5094affdf231df2be86920610d83e --------- Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Angazenn <supperccell@163.com> Signed-off-by: Your Name <you@example.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: hust17yixuan <303660421@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: Angazenn <supperccell@163.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: hust17yixuan <303660421@qq.com>
2025-09-16 01:17:42 +08:00
with torch.npu.device(x.device.index):
layer_norm_fwd_kernel[grid](
[New model] Qwen3-next support (#2917) ### What this PR does / why we need it? Add Qwen3-next support. ### Does this PR introduce _any_ user-facing change? Yes, users can use Qwen3 next. Related doc: https://github.com/vllm-project/vllm-ascend/pull/2916 the tutorial will be ready in [here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html) ### How was this patch tested? Doc CI passed Related: https://github.com/vllm-project/vllm-ascend/issues/2884 Co-Authored-By: Angazenn <supperccell@163.com> Co-Authored-By: zzzzwwjj <1183291235@qq.com> Co-Authored-By: MengqingCao <cmq0113@163.com> Co-Authored-By: linfeng-yuan <1102311262@qq.com> Co-Authored-By: hust17yixuan <303660421@qq.com> Co-Authored-By: SunnyLee219 <3294305115@qq.com> Co-Authored-By: maoxx241 <maoxx241@umn.edu> - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/b834b4cbf1d5094affdf231df2be86920610d83e --------- Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Angazenn <supperccell@163.com> Signed-off-by: Your Name <you@example.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: hust17yixuan <303660421@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: Angazenn <supperccell@163.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: hust17yixuan <303660421@qq.com>
2025-09-16 01:17:42 +08:00
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
[BugFix][mian] Fixed a triton kernel bug of layer_norm_fwd_kernel for Qwen3-next (#3549) ### What this PR does / why we need it? Fixes triton kernel **layer_norm_fwd_kernel**, descripted by https://github.com/vllm-project/vllm-ascend/issues/3548 ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? The environment is same with this issue, https://github.com/vllm-project/vllm-ascend/issues/3548. Starts a vllm server with: ```shell vllm serve /home/model/Qwen3-Next-80B-A3B-Instruct --port 22 --host 0.0.0.0 --served-model-name qwen3_next_mtp_0 --tensor-parallel-size 4 --max-model-len 32000 --gpu-memory-utilization 0.7 --enforce-eager ``` The, we start an aisbench clinet like: ```shell ais_bench --models vllm_api_general_chat --datasets ceval_gen_0_shot_cot_chat_prompt --dump-eval-details ``` Whose config is: ```python # a big batch_size and a large max_out_len dict( abbr='vllm-api-general-chat', attr='service', batch_size=512, generation_kwargs=dict(temperature=0.7, top_k=20, top_p=0.8), host_ip='xxx.xxx.xxx.xxx', host_port=8881, max_out_len=30000, model='qwen3_next_mtp_0', path='', pred_postprocessor=dict( type= 'ais_bench.benchmark.utils.model_postprocessors.extract_non_reasoning_content' ), request_rate=0, retry=2, trust_remote_code=False, type='ais_bench.benchmark.models.VLLMCustomAPIChat'), ``` **Results:** ```text ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 
tokens/s, Avg generation throughput: 71.4 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 49.6 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 86.1%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 59.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:44:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.2 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:45:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0% ^[[1;36m(APIServer pid=615544)^[[0;0m INFO 10-21 01:45:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 62.4 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 90.8%, Prefix cache hit rate: 0.0% ``` We can see when we sent a bunch of requests and the **KV cache usage reaches 100.0%**. We won't get a **coreDim=xxx can't be greater than UINT16_MAX.** Exception. 
N_CORES=MAX_CORES,
num_warps=num_warps,
)
return out, mean, rstd


class LayerNormFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        return y.reshape(x_shape_og)


def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = F.normalize(query, p=2, dim=-1)
        key = F.normalize(key, p=2, dim=-1)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]

    # NOTE: after the transpose above, dim 1 is the head axis and dim 2 is the
    # token axis, so `sequence_length` below holds the head count and
    # `num_heads` holds the token count that gets padded and chunked.
    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Pad the token axis up to a multiple of chunk_size; query and key are
    # additionally repeated twice along the head axis (dim 1).
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
    key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    tot_heads = num_heads + pad_size
    scale = 1 / (query.shape[-1]**0.5)
    query = query * scale

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # Upper triangle including the diagonal: used to keep only the strictly
    # lower-triangular (past-token) part inside each chunk.
    mask = torch.triu(torch.ones(chunk_size,
                                 chunk_size,
                                 dtype=torch.bool,
                                 device=query.device),
                      diagonal=0)

    # chunk decay
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) -
                   g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -(
        (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # Row-by-row forward substitution. With A the strictly lower-triangular
    # decayed (k_beta @ key^T), attn equals (I + A)^-1 once the identity is
    # added below.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
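    # The zero state takes its head count from `value` (dim 1) so that it
    # matches query/key after repeat_interleave above; `sequence_length` was
    # read from `key` before that repeat.
    last_recurrent_state = (torch.zeros(batch_size, value.shape[1],
                                        k_head_dim, v_head_dim).to(value) if
                            initial_state is None else initial_state.to(value))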
    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size,
                                 chunk_size,
                                 dtype=torch.bool,
                                 device=query.device),
                      diagonal=1)

    # for each chunk
    for i in range(0, tot_heads // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) *
                decay_mask[:, :, i]).masked_fill_(mask, 0)
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        # Carry the state to the next chunk: decay the old state over the
        # whole chunk and add this chunk's decayed key/value contributions.
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp() +
            (k_i *
             (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
                 -1, -2) @ v_new)

    if not output_final_state:
        last_recurrent_state = None
    # Merge the chunks back into one token axis and drop the padded tokens.
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
                                          core_attn_out.shape[1], -1,
                                          core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :num_heads]
    core_attn_out = core_attn_out.transpose(1,
                                            2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state