[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #12) (#6177)

### What this PR does / why we need it?
Applies `ruff format` to the files listed below, as batch #12 of converting the `vllm_ascend/` tree to the ruff code style. The changes are formatting and typing-style only (e.g. `Optional[X]` → `X | None`); no behavioral change is intended.

**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
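
For a sense of what the conversion does, here is a schematic before/after distilled from the `solve_tril.py` excerpt further down. The shape values and the `solve_tril_stub` name are illustrative placeholders, not taken from the diff:

```python
import torch

B, T, H, BT = 1, 64, 4, 16  # hypothetical shapes, for illustration only
output_dtype = torch.float

# Before (aligned-continuation style), as in the old solve_tril.py:
#
#   Ad = torch.empty(B,
#                    T,
#                    H,
#                    16,
#                    device=A.device,
#                    dtype=torch.float if BT != 16 else output_dtype)
#
# After: `ruff format` collapses the call onto a single line (long lines
# are tolerated because these files carry `# ruff: noqa: E501`):
Ad = torch.empty(
    B, T, H, 16, device="cpu", dtype=torch.float if BT != 16 else output_dtype
)  # the real code uses device=A.device

# The batch also modernizes annotations to PEP 604 unions, which is why
# `from typing import Optional` disappears from the imports:
def solve_tril_stub(cu_seqlens: torch.Tensor | None = None) -> None: ...
```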

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main: d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
Author: SILONG ZENG
Date: 2026-01-23 14:59:19 +08:00 (committed by GitHub)
Parent: 193acc2c19
Commit: 78af0c30a3
25 changed files with 760 additions and 996 deletions

Diff excerpt (`vllm_ascend/ops/triton/fla/solve_tril.py`):

```diff
--- a/vllm_ascend/ops/triton/fla/solve_tril.py
+++ b/vllm_ascend/ops/triton/fla/solve_tril.py
@@ -8,7 +8,6 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
 # mypy: ignore-errors
-from typing import Optional
 
 import torch
 from vllm.triton_utils import tl, triton
@@ -66,9 +65,13 @@ def solve_tril_16x16_kernel(
         offs_cols_in_block = tl.arange(0, 16)
         # 2 Calculate the pointer of each element
-        ptr_A_subrec16 = (A + row_start_o * H * BT + col_start_o +
-                          offs_rows_in_block[:, None] * H * BT +
-                          offs_cols_in_block[None, :])
+        ptr_A_subrec16 = (
+            A
+            + row_start_o * H * BT
+            + col_start_o
+            + offs_rows_in_block[:, None] * H * BT
+            + offs_cols_in_block[None, :]
+        )
         # 3 Create a mask to prevent out-of-bounds access
         global_rows = row_start_o + offs_rows_in_block[:, None]
@@ -76,14 +79,14 @@
         load_mask = (global_rows < T) & (global_cols < BT)
         # 4 Use mask to safely load data
-        b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask,
-                               other=0.0).to(tl.float32)
+        b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
         b_A = tl.insert_slice(
             ful=b_A,
             sub=b_A_subrec16[None, :, :],  # (1, 16, 16)
             offsets=[blkid, 0, 0],
             sizes=[1, 16, 16],
-            strides=[1, 1, 1])
+            strides=[1, 1, 1],
+        )
 
     local_ori_A = tl.trans(b_A, (1, 0, 2))
     local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))
@@ -97,9 +100,7 @@ def solve_tril_16x16_kernel(
     # for loop to update N_BLOCKS row vector
     for i in range(1, 16):
-        nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0),
-                                        (1, 16 * N_BLOCKS),
-                                        (16 * N_BLOCKS, 1))
+        nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
         b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
         dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
@@ -107,34 +108,27 @@ def solve_tril_16x16_kernel(
         b_a = b_a + dot_product
         b_a_new_expanded = b_a[:, None, :]
-        b_A = tl.insert_slice(ful=b_A,
-                              sub=b_a_new_expanded,
-                              offsets=[0, i, 0],
-                              sizes=[N_BLOCKS, 1, 16],
-                              strides=[1, 1, 1])
+        b_A = tl.insert_slice(
+            ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
+        )
 
-    on_diagonal = (rows == cols)
+    on_diagonal = rows == cols
     b_A = tl.where(on_diagonal, b_A + 1.0, b_A)
     b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))
-    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0),
-                             (N_BLOCKS * 16, 16), (1, 0))
+    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0), (N_BLOCKS * 16, 16), (1, 0))
     # 1 Create in-block offset
     offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
     offs_cols_to_store = tl.arange(0, 16)
     # 2 Calculate the pointer of each element
-    p_Ai = (Ad + base_t * H * 16 + 0 +
-            offs_rows_to_store[:, None] * H * 16 +
-            offs_cols_to_store[None, :])
+    p_Ai = Ad + base_t * H * 16 + 0 + offs_rows_to_store[:, None] * H * 16 + offs_cols_to_store[None, :]
     # 3 Create a mask to prevent out-of-bounds access, only check rows
     global_store_rows = base_t + offs_rows_to_store[:, None]
     store_mask = global_store_rows < T
     # 4 use mask to save data safely
-    tl.store(p_Ai,
-             b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
-             mask=store_mask)
+    tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=store_mask)
 
 
 @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@@ -169,18 +163,12 @@ def merge_16x16_to_32x32_inverse_kernel(
     Ad += (bos * H + i_h) * 16
     Ai += (bos * H + i_h) * 32
-    p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
-                               (16, 16), (1, 0))
-    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0),
-                                (16, 16), (1, 0))
-    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0),
-                                (16, 16), (1, 0))
-    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0),
-                                (16, 16), (1, 0))
-    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16),
-                                (16, 16), (1, 0))
-    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
-                                (16, 16), (1, 0))
+    p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
+    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0))
+    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
+    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0))
+    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
+    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
 
     A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
     Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
@@ -313,26 +301,20 @@ def merge_16x16_to_64x64_inverse_kernel(
     offs_n = tl.arange(0, 32)
     mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
     ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
-    tl.store(ptr_Ai,
-             Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
-             mask=mask_store)
+    tl.store(ptr_Ai, Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
     # store Ai_22_32 to (i_t * 64 + 32, 32)
     offs_m = i_t * 64 + 32 + tl.arange(0, 32)
     offs_n = 32 + tl.arange(0, 32)
     mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
     ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
-    tl.store(ptr_Ai,
-             Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
-             mask=mask_store)
+    tl.store(ptr_Ai, Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
     # store Ai_21_32 to (i_t * 64 + 32, 32)
     offs_n = tl.arange(0, 32)
     mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
     ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
-    tl.store(ptr_Ai,
-             Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
-             mask=mask_store)
+    tl.store(ptr_Ai, Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
 
     # zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63)
     offs_m = i_t * 64 + tl.arange(0, 32)
@@ -345,7 +327,7 @@
 def solve_tril(
     A: torch.Tensor,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    cu_seqlens: torch.Tensor | None = None,
     output_dtype: torch.dtype = torch.float,
 ) -> torch.Tensor:
     """
@@ -367,19 +349,12 @@ def solve_tril(
     assert A.shape[-1] in [16, 32, 64]
     B, T, H, BT = A.shape
-    Ad = torch.empty(B,
-                     T,
-                     H,
-                     16,
-                     device=A.device,
-                     dtype=torch.float if BT != 16 else output_dtype)
+    Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
 
     LARGE_BLOCK_T = 608 * 2
-    chunk_indices = (prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
-                     if cu_seqlens is not None else None)
-    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(
-        T, LARGE_BLOCK_T)
+    chunk_indices = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) if cu_seqlens is not None else None
+    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)
 
     solve_tril_16x16_kernel[NT, B * H](
         A=A,
@@ -398,10 +373,8 @@ def solve_tril(
         return Ad
 
     Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
-    merge_fn = (merge_16x16_to_32x32_inverse_kernel
-                if BT == 32 else merge_16x16_to_64x64_inverse_kernel)
-    chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
-                     if cu_seqlens is not None else None)
+    merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
+    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
     merge_fn[NT, B * H](
```
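
A quick conformance check for any of these files is ruff's standard check mode, e.g. `ruff format --check vllm_ascend/ops/triton/fla/solve_tril.py`, which exits non-zero if the formatter would still rewrite the file.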