[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #12) (#6177)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-23 14:59:19 +08:00
committed by GitHub
parent 193acc2c19
commit 78af0c30a3
25 changed files with 760 additions and 996 deletions

View File

@@ -15,8 +15,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
rindex = tl.arange(0, N)[None, :]
@@ -24,8 +23,7 @@ def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
xmask = row_idx < M
xs = tl.load(X + (rindex + N * row_idx), mask=xmask,
other=0.0).to(tl.float32)
xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
square = xs * xs
square_sum = tl.sum(square, 1)[:, None]
rsqrt = tl.rsqrt(square_sum + eps)
@@ -33,9 +31,7 @@ def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: torch.dtype | None = None):
def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None):
x_shape_og = x.shape
x = x.reshape(-1, x.shape[-1])
# allocate output
@@ -56,7 +52,7 @@ def l2norm_fwd(x: torch.Tensor,
num_core = get_vectorcore_num()
main_bs = triton.cdiv(T, num_core)
num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
grid = (num_core, )
grid = (num_core,)
l2norm_fwd_kernel2_loop[grid](
X=x,
Y=y,