Files

442 lines
15 KiB
Python
Raw Permalink Normal View History

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.triton_utils import tl, triton
# Tiled matmul with optional bias: output = x @ y (+ bias), accumulated in float32.
#
# Intended launch grid (see the caller in this file): (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)).
# Each program computes one BLOCK_M x BLOCK_N tile of `output`:
#   - program_id(0) selects the row tile, program_id(1) the column tile;
#   - K is reduced in BLOCK_K-sized chunks, visited in a fixed order (0, 1, 2, ...)
#     that depends only on K and BLOCK_K, not on M or on the tile's grid position
#     (presumably the point, given this file's batch-invariant origin -- TODO confirm);
#   - out-of-range rows/cols/K positions are masked and read as 0.0.
#
# NOTE(review): despite the name, this is not a persistent kernel -- a program
# handles exactly one tile and there is no loop over tiles; the grid scales with
# the output shape.
@triton.jit
def matmul_bias_persistent_kernel(
    # Input tensor pointers
    x_ptr,
    y_ptr,
    bias_ptr,
    output_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # Stride information
    stride_xm,
    stride_xk,  # Strides of x: [M, K]
    stride_yk,
    stride_yn,  # Strides of y: [K, N]
    stride_bias,  # Stride of bias: [N]
    stride_outm,
    stride_outn,  # Strides of output: [M, N]
    # Whether to use bias (constexpr, so the bias branch is resolved at compile time)
    has_bias: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)  # Row block ID
    pid_n = tl.program_id(1)  # Column block ID
    # Calculate the starting position of the current block in the matrix
    rm_start = pid_m * BLOCK_M
    rn_start = pid_n * BLOCK_N
    # Create index ranges
    rm = rm_start + tl.arange(0, BLOCK_M)  # Row index range [BLOCK_M]
    rn = rn_start + tl.arange(0, BLOCK_N)  # Column index range [BLOCK_N]
    rk = tl.arange(0, BLOCK_K)  # K dimension index range [BLOCK_K]
    # Initialize the float32 accumulator to 0 (precision is independent of input dtype)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Loop over the K dimension, processing BLOCK_K elements per iteration
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_start = k * BLOCK_K
        # Pointers to the [BLOCK_M, BLOCK_K] tile of x, addressed through its strides
        x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
        # Pointers to the [BLOCK_K, BLOCK_N] tile of y, addressed through its strides
        y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn
        # Create masks to prevent out-of-bounds access
        x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
        y_mask = ((rk[:, None] + k_start) < K) & (rn[None, :] < N)
        # Load tiles; masked-off elements read as 0.0 so they add nothing to the dot
        # product. Both operands are upcast to float32 before the multiply.
        x_chunk = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        y_chunk = tl.load(y_ptrs, mask=y_mask, other=0.0).to(tl.float32)
        # Accumulate this K chunk. allow_tf32=False requests a full-precision fp32 dot
        # (presumably for accuracy/reproducibility -- confirm the backend honours it).
        acc += tl.dot(x_chunk, y_chunk, allow_tf32=False)
    # Add bias if the has_bias flag is set
    if has_bias:
        # Load bias values for this column tile (broadcast to all rows)
        bias_ptrs = bias_ptr + rn * stride_bias
        bias_mask = rn < N
        bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
        # Add bias to accumulator (automatic broadcasting)
        acc += bias_vals[None, :]
    # Calculate output pointer positions
    out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
    out_mask = (rm[:, None] < M) & (rn[None, :] < N)
    # Cast the fp32 accumulator to the output's element type and store the in-bounds part
    tl.store(out_ptrs, acc.to(out_ptrs.dtype.element_ty), mask=out_mask)
def matmul_persistent(x, y, bias=None):
    """
    Compute ``x @ y + bias`` (bias optional) with a Triton tiled-matmul kernel.

    Both inputs are made contiguous before launch. The kernel accumulates in
    float32 and casts to the output dtype on store. Despite the name, the launch
    uses one program per BLOCK_M x BLOCK_N output tile rather than a fixed-size
    persistent grid.

    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [K, N]
        bias: torch.Tensor, shape [N] or None

    Returns:
        output: torch.Tensor, shape [M, N], with the same dtype and device as x.
        If M or N is 0, the empty output is returned without launching the kernel.

    Raises:
        AssertionError: if x or y is not 2D, x.shape[1] != y.shape[0], or bias is
            not a 1D tensor of length N. These are ``assert`` statements, so the
            checks are skipped when Python runs with ``-O``.
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
    # Convert tensors to contiguous memory layout.
    # This prevents transposed tensors from causing incorrect stride() values,
    # which would lead to miscalculated data transfer volumes in subsequent operations.
    x = x.contiguous()
    y = y.contiguous()
    M, K = x.shape
    _, N = y.shape
    # Validate bias shape (if not None)
    if bias is not None:
        assert bias.dim() == 1, "bias must be a 1D tensor"
        assert y.shape[1] == bias.shape[0], (
            f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
        )
    # Allocate output tensor (same data type as x)
    output = torch.empty((M, N), dtype=x.dtype, device=x.device)
    # An empty result needs no computation. Returning here also avoids launching
    # the kernel with a zero-sized grid dimension (cdiv(0, BLOCK) == 0).
    if M == 0 or N == 0:
        return output
    # Fixed block sizes. BLOCK_K sets how the K reduction is chunked (and hence the
    # floating-point summation order), so keep these constant if results must be
    # reproducible across calls.
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
    # Grid size: one program per BLOCK_M x BLOCK_N output tile
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    # Handle case when bias is None
    if bias is None:
        # Create a dummy bias tensor (will not be used as has_bias=False)
        dummy_bias = torch.empty(0, dtype=x.dtype, device=x.device)
        has_bias = False
        bias_stride = 0
        bias_to_pass = dummy_bias
    else:
        has_bias = True
        bias_stride = bias.stride(0)
        bias_to_pass = bias
    # Launch kernel
    matmul_bias_persistent_kernel[grid](
        x,
        y,
        bias_to_pass,
        output,  # Input/Output tensors
        M,
        N,
        K,  # Matrix dimensions
        x.stride(0),
        x.stride(1),  # Strides of x
        y.stride(0),
        y.stride(1),  # Strides of y
        bias_stride,  # Stride of bias (0 if bias is None)
        output.stride(0),
        output.stride(1),  # Strides of output
        has_bias,  # Flag indicating whether to use bias
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
    return output
@triton.jit
def linear_persistent_kernel(
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
a_ptr, # Pointer to tensor a, shape [M, K]
b_ptr, # Pointer to tensor b, shape [N, K]
c_ptr, # Pointer to output tensor c, shape [M, N]
M, # Number of rows in tensor a
N, # Number of rows in tensor b (number of columns in output c)
K, # Number of columns in both tensor a and tensor b
stride_am, # Stride of tensor a along dimension M (typically K)
stride_ak, # Stride of tensor a along dimension K (typically 1)
stride_bn, # Stride of tensor b along dimension N (typically K)
stride_bk, # Stride of tensor b along dimension K (typically 1)
stride_cm, # Stride of tensor c along dimension M (typically N)
stride_cn, # Stride of tensor c along dimension N (typically 1)
BLOCK_M: tl.constexpr, # Block size for M dimension
BLOCK_N: tl.constexpr, # Block size for N dimension
BLOCK_K: tl.constexpr, # Block size for K dimension
NUM_BLOCKS_M: tl.constexpr, # New: Number of blocks in M dimension
NUM_BLOCKS_N: tl.constexpr, # New: Number of blocks in N dimension
GRID_SIZE: tl.constexpr, # New: Fixed 1D grid size
):
# Get current program's 1D index (1D grid)
pid = tl.program_id(0)
total_blocks = NUM_BLOCKS_M * NUM_BLOCKS_N # Total number of output blocks
# Loop over multiple blocks assigned to the current program
for block_index in range(pid, total_blocks, GRID_SIZE):
# Convert 1D block index to 2D coordinates (m_block, n_block)
m_block = block_index // NUM_BLOCKS_N
n_block = block_index % NUM_BLOCKS_N
# Calculate starting indices of the current output block
start_m = m_block * BLOCK_M
start_n = n_block * BLOCK_N
# Create row and column index ranges within the current block
m_indices = start_m + tl.arange(0, BLOCK_M)
n_indices = start_n + tl.arange(0, BLOCK_N)
# Create masks to handle boundaries
m_mask = m_indices < M
n_mask = n_indices < N
# Initialize accumulator to 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Loop over K dimension with step size BLOCK_K
for k_offset in range(0, K, BLOCK_K):
k_indices = k_offset + tl.arange(0, BLOCK_K)
k_mask = k_indices < K
# Load block of tensor a: shape [BLOCK_M, BLOCK_K]
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
# Load block of tensor b: shape [BLOCK_N, BLOCK_K]
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)
# Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
b_vals_transposed = tl.trans(b_vals)
# Compute matrix multiplication: a_vals × b_vals_transposed
product = tl.dot(a_vals, b_vals_transposed)
acc += product
# Store result to output tensor c
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
def linear_persistent(x, y):
    """
    Implement matrix multiplication with Triton: x @ y^T

    The kernel is launched on a fixed-size 1D grid whose size is half of the
    device's ``num_vectorcore`` property; the shape-dependent heuristics below
    only choose the BLOCK_M / BLOCK_N / BLOCK_K tile sizes.

    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [N, K]

    Returns:
        output: torch.Tensor, shape [M, N], same dtype and device as ``x``.
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
    M, K = x.shape
    N, _ = y.shape
    # Allocate output tensor (same data type as x)
    output = torch.zeros((M, N), dtype=x.dtype, device=x.device)
    # Nothing to launch: an empty output has no elements, and with K == 0 every
    # entry is an empty sum, i.e. exactly the zeros allocated above.
    if M == 0 or N == 0 or K == 0:
        return output

    grid_size = (
        triton.runtime.driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
    )
    # Define block sizes (can be adjusted based on hardware)
    BLOCK_K = 256
    if x.dtype == torch.float32:
        BLOCK_K = BLOCK_K // 2
    # Clamp to 1 so the cdiv(..., grid_size_div4) calls below cannot divide by
    # zero on devices with fewer than 8 vector cores.
    grid_size_div4 = max(grid_size // 4, 1)
    if M < 256:
        BLOCK_M = M
        if grid_size * 128 <= N:
            if M <= 128:
                BLOCK_N = 256
            else:
                BLOCK_N = 128
        elif grid_size * 32 >= N:
            if M > N:
                BLOCK_M = triton.cdiv(M, grid_size_div4)
                BLOCK_N = triton.cdiv(N, 4)
            else:
                BLOCK_M = triton.cdiv(M, 4)
                BLOCK_N = triton.cdiv(N, grid_size_div4)
        else:
            BLOCK_N = triton.next_power_of_2(triton.cdiv(N, grid_size))
    elif M < 1024:
        if M < N:
            BLOCK_M = 256
            nums_m = triton.cdiv(M, BLOCK_M)
            nums_n = grid_size // nums_m
            if 128 * nums_n <= N:
                BLOCK_N = 128
            else:
                BLOCK_N = min(triton.next_power_of_2(triton.cdiv(N, nums_n)), 128)
        else:
            BLOCK_M = min(triton.cdiv(M, grid_size_div4), 256)
            BLOCK_N = min(triton.cdiv(N, 4), 128)
    else:
        if M > N:
            BLOCK_M, BLOCK_N = 256, 128
            nums_m = triton.cdiv(M, BLOCK_M)
            nums_n = triton.cdiv(N, BLOCK_N)
            # Too few tiles to occupy the grid: shrink the tiles.
            if nums_m * nums_n < grid_size:
                BLOCK_M = triton.cdiv(M, grid_size_div4)
                BLOCK_N = triton.cdiv(N, 4)
        else:
            BLOCK_M, BLOCK_N = 128, 256
    # Calculate number of blocks per dimension (ceil division)
    num_blocks_m = triton.cdiv(M, BLOCK_M)
    num_blocks_n = triton.cdiv(N, BLOCK_N)
    # Set fixed 1D grid size
    grid = (grid_size,)
    # Launch kernel
    linear_persistent_kernel[grid](
        a_ptr=x,
        b_ptr=y,
        c_ptr=output,
        M=M,
        N=N,
        K=K,
        stride_am=x.stride(0),
        stride_ak=x.stride(1),
        stride_bn=y.stride(0),
        stride_bk=y.stride(1),
        stride_cm=output.stride(0),
        stride_cn=output.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        NUM_BLOCKS_M=num_blocks_m,  # Number of blocks in M dimension
        NUM_BLOCKS_N=num_blocks_n,  # Number of blocks in N dimension
        GRID_SIZE=grid_size,  # Fixed grid size
    )
    return output
def mm_batch_invariant(a, b):
    """Batch-invariant stand-in for ``torch.mm``: return the 2-D product ``a @ b``.

    The computation is delegated entirely to ``matmul_persistent``.
    """
    product = matmul_persistent(a, b)
    return product
def bmm_batch_invariant(a, b, *, out=None):
    """
    Batch-invariant replacement for ``torch.bmm``: (B, M, K) x (B, K, N) -> (B, M, N).

    Every batch entry is multiplied by its own ``matmul_persistent`` call and the
    per-entry results are stacked along dim 0.

    Args:
        a: 3-D tensor of shape (B, M, K).
        b: 3-D tensor of shape (B, K, N).
        out: optional tensor; when given, the result is copied into it and
            ``out`` itself is returned.

    Returns:
        Tensor of shape (B, M, N) (or ``out``).

    Raises:
        ValueError: if either input is not 3-D, or the batch sizes differ.
    """
    if a.ndim != 3 or b.ndim != 3:
        raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")
    if a.shape[0] != b.shape[0]:
        # torch.bmm does not broadcast the batch dim; without this check extra
        # entries of b would be silently ignored (or a[i]/b[i] would IndexError).
        raise ValueError(f"bmm_batch_invariant expects equal batch sizes, got shapes {a.shape} and {b.shape}")
    if a.shape[0] == 0:
        # torch.stack rejects an empty list, so build the empty result directly
        # (dtype/device taken from a).
        result = torch.empty((0, a.shape[1], b.shape[2]), dtype=a.dtype, device=a.device)
    else:
        result = torch.stack([matmul_persistent(a_i, b_i) for a_i, b_i in zip(a, b)], dim=0)
    if out is not None:
        out.copy_(result)
        return out
    return result
def addmm_batch_invariant(bias, a, b):
    """Batch-invariant stand-in for ``torch.addmm(input, mat1, mat2)``.

    The additive term comes first, mirroring ``torch.addmm``'s argument order,
    and is forwarded as ``matmul_persistent``'s ``bias`` argument. The
    ``beta`` / ``alpha`` scaling factors of ``torch.addmm`` are not accepted.
    """
    fused = matmul_persistent(a, b, bias=bias)
    return fused
def matmul_batch_invariant(a, b, *, out=None):
    """
    Batch-invariant replacement for ``torch.matmul`` on a fixed set of ranks.

    Supported rank combinations:
      * 2D x 2D: (M, K) @ (K, N) -> (M, N)
      * 3D x 3D: delegated to ``bmm_batch_invariant``
      * 3D x 2D: (B, S, K) @ (K, N) -> (B, S, N), flattened into one 2D matmul
      * 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N), ``a`` expanded over the batch
      * 4D x 4D: (X, Y, S, K) @ (X', Y', K, T) -> (bX, bY, S, T), where the two
        leading dims are broadcast as in ``torch.matmul``

    Args:
        a, b: input tensors.
        out: optional tensor; when given, the result is copied into it and
            ``out`` itself is returned.

    Raises:
        ValueError: for unsupported rank combinations.
        RuntimeError: (4D) if the leading dims cannot be broadcast together.
    """
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
    elif a.ndim == 3 and b.ndim == 3:
        # Handle batched case like bmm
        return bmm_batch_invariant(a, b, out=out)
    elif a.ndim == 3 and b.ndim == 2:
        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out):
        # flatten to 2D, do one mm, reshape back.
        batch, seq, hidden = a.shape
        result_2d = matmul_persistent(a.reshape(batch * seq, hidden), b)
        # Explicit sizes instead of -1: reshape cannot infer -1 when the
        # tensor is empty (batch * seq == 0).
        result = result_2d.reshape(batch, seq, b.shape[1])
    elif a.ndim == 2 and b.ndim == 3:
        # (M, K) @ (B, K, N) -> (B, M, N): expand `a` over the batch and
        # reuse the batched path.
        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
        return bmm_batch_invariant(a_expanded, b, out=out)
    elif a.ndim == 4 and b.ndim == 4:
        # e.g. attention tensors [batch, heads, seq, dim]. Broadcast the two
        # leading dims (as torch.matmul does) instead of assuming they match,
        # then fold them into one batch dim for bmm.
        batch_shape = torch.broadcast_shapes(a.shape[:2], b.shape[:2])
        num_batches = batch_shape.numel()
        seq_a, dim_a = a.shape[2:]
        dim_b, seq_b = b.shape[2:]
        a_3d = a.expand(*batch_shape, seq_a, dim_a).reshape(num_batches, seq_a, dim_a)
        b_3d = b.expand(*batch_shape, dim_b, seq_b).reshape(num_batches, dim_b, seq_b)
        result = bmm_batch_invariant(a_3d, b_3d).reshape(*batch_shape, seq_a, seq_b)
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
            f"3D x 2D, 2D x 3D, and 4D x 4D, "
            f"got shapes {a.shape} and {b.shape}"
        )
    if out is not None:
        out.copy_(result)
        return out
    return result
def linear_batch_invariant(input_, weight, bias=None):
    """
    Batch-invariant replacement for ``torch.nn.functional.linear``.

    Computes ``input_ @ weight.T`` with ``linear_persistent`` and adds ``bias``
    if given. ``linear_persistent`` only accepts 2-D operands, so inputs of any
    other rank >= 1 (e.g. (batch, seq, K) or a single (K,) vector) are
    flattened to (rows, K) and the result is reshaped back to
    ``input_.shape[:-1] + (N,)``.

    Args:
        input_: tensor of shape (..., K).
        weight: tensor of shape (N, K).
        bias: optional tensor added to the (..., N) result with broadcasting.

    Returns:
        Tensor of shape (..., N).
    """
    if input_.dim() in (0, 2):
        # 2-D is linear_persistent's native layout; 0-D is passed through so
        # its own shape assertion reports the error.
        output = linear_persistent(input_, weight)
    else:
        lead_shape = input_.shape[:-1]
        # Explicit row count instead of -1 so empty leading dims still reshape.
        rows = lead_shape.numel()
        output_2d = linear_persistent(input_.reshape(rows, input_.shape[-1]), weight)
        output = output_2d.reshape(*lead_shape, output_2d.shape[-1])
    if bias is not None:
        output = output + bias
    return output