[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #12) (#6177)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-23 14:59:19 +08:00
committed by GitHub
parent 193acc2c19
commit 78af0c30a3
25 changed files with 760 additions and 996 deletions

View File

@@ -26,8 +26,7 @@ def _swiglu_quant_kernel(
else:
gl_offsets = tl.arange(0, NUM_EXPERTS_ALGIN)
gl_mask = gl_offsets < NUM_EXPERTS
group_list = tl.load(group_list_ptr + gl_offsets, gl_mask,
other=0).to(tl.int32)
group_list = tl.load(group_list_ptr + gl_offsets, gl_mask, other=0).to(tl.int32)
total_rows = tl.sum(group_list)
block_size = (total_rows - 1) // NUM_CORES + 1
@@ -41,14 +40,8 @@ def _swiglu_quant_kernel(
# swiglu
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
cur_x = tl.load(x_ptr + x_offsets)
x1 = tl.extract_slice(cur_x,
offsets=(0, ),
sizes=(HALF_COLS, ),
strides=(1, ))
x2 = tl.extract_slice(cur_x,
offsets=(HALF_COLS, ),
sizes=(HALF_COLS, ),
strides=(1, ))
x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
out = x1 * tl.sigmoid(x1) * x2
# quant
@@ -57,20 +50,13 @@ def _swiglu_quant_kernel(
# store scale
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
tmp_out = tl.extract_slice(out,
offsets=(col_blk_idx, ),
sizes=(COL_BLOCK_SIZE, ),
strides=(1, ))
tmp_out = (tmp_out.to(tl.float32) / scale).to(
x_ptr.dtype.element_ty)
tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")
o_offsets = (row_idx * HALF_COLS + col_blk_idx +
tl.arange(0, COL_BLOCK_SIZE))
o_offsets = row_idx * HALF_COLS + col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)
mask = (col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)) < HALF_COLS
tl.store(out_ptr + o_offsets,
tmp_out.to(out_ptr.dtype.element_ty),
mask=mask)
tl.store(out_ptr + o_offsets, tmp_out.to(out_ptr.dtype.element_ty), mask=mask)
else:
# store out
o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS)
@@ -80,12 +66,11 @@ def _swiglu_quant_kernel(
def swiglu_quant(x, group_list, group_list_type, need_quant=True):
# group_list_type must be 0 cusum or 1 count
if group_list_type not in [0, 1]:
raise ValueError(
f"group_list_type must be 0 or 1, but got {group_list_type}")
raise ValueError(f"group_list_type must be 0 or 1, but got {group_list_type}")
s, h = x.shape
out_dtype = torch.int8 if need_quant else x.dtype
out = torch.empty((s, h // 2), dtype=out_dtype, device=x.device)
scale = torch.empty((s, ), dtype=torch.float32, device=x.device)
scale = torch.empty((s,), dtype=torch.float32, device=x.device)
num_experts = group_list.shape[0]
# ub must be 32-byte aligned on npu
if group_list.dtype == torch.int64:
@@ -93,12 +78,10 @@ def swiglu_quant(x, group_list, group_list_type, need_quant=True):
elif group_list.dtype == torch.int32:
num_experts_algin = (num_experts + 15) // 16 * 16
else:
raise ValueError(
f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}"
)
raise ValueError(f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}")
num_vectorcore = get_vectorcore_num()
_swiglu_quant_kernel[(num_vectorcore, )](
_swiglu_quant_kernel[(num_vectorcore,)](
x,
group_list,
out,