[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #12) (#6177)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-23 14:59:19 +08:00
committed by GitHub
parent 193acc2c19
commit 78af0c30a3
25 changed files with 760 additions and 996 deletions

View File

@@ -8,7 +8,6 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
@@ -16,11 +15,13 @@ from vllm.triton_utils import tl, triton
from .utils import prepare_chunk_indices, safe_exp
@triton.heuristics({
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
'USE_G': lambda args: args['g_cumsum'] is not None,
})
@triton.jit(do_not_specialize=['T'])
@triton.heuristics(
{
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"USE_G": lambda args: args["g_cumsum"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
beta, # [H, B, T]
@@ -44,10 +45,11 @@ def chunk_scaled_dot_kkt_fwd_kernel(
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t_i * 2).to(
tl.int32), tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t_i * 2).to(tl.int32),
tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
@@ -55,39 +57,37 @@ def chunk_scaled_dot_kkt_fwd_kernel(
o_t = tl.arange(0, BT)
o_t_fp32 = o_t.to(tl.float32)
p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T, ), (1, ),
(i_t * BT, ), (BT, ), (0, ))
b_beta = tl.load(p_beta, boundary_check=(0, ))
p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K,
(T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
(BT, BK), (1, 0))
p_k = tl.make_block_ptr(
k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_A += tl.dot(b_k, tl.trans(b_k))
if USE_G:
p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T, ),
(1, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A *= safe_exp(b_g_diff)
b_A *= b_beta[:, None]
b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0)
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
(i_t * BT, 0), (BT, BT), (1, 0))
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
beta: torch.Tensor,
g_cumsum: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
k: torch.Tensor,
beta: torch.Tensor,
g_cumsum: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""
Compute beta * K * K^T.
@@ -117,8 +117,7 @@ def chunk_scaled_dot_kkt_fwd(
BT = chunk_size
if cu_seqlens is not None:
cu_seqlens = cu_seqlens.cpu()
chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
if cu_seqlens is not None else None)
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
chunk_indices = chunk_indices.npu()
cu_seqlens = cu_seqlens.npu()
else: