[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #12) (#6177)

### What this PR does / why we need it?
This batch mechanically reformats the files listed below with `ruff format`; there are no functional changes.

**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
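
The edits are purely mechanical re-wrapping by the formatter. A representative before/after, adapted from `split_qkv_rmsnorm_rope.py` in the diff below (the small driver values here are illustrative only, not taken from the code):

```python
import torch

# Illustrative sizes; the real values come from the model configuration.
batch_size, q_hidden_size = 4, 128
hidden = torch.zeros(batch_size, q_hidden_size, dtype=torch.bfloat16)

# Before: hanging-indent wrapping used by the previous formatter.
q_output = torch.empty(batch_size,
                       q_hidden_size,
                       device=hidden.device,
                       dtype=hidden.dtype)

# After: ruff format keeps the call on one line while it fits the line-length limit.
q_output = torch.empty(batch_size, q_hidden_size, device=hidden.device, dtype=hidden.dtype)
```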

### Does this PR introduce _any_ user-facing change?

No. This is a formatting-only change.

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main: d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
Author: SILONG ZENG
Date: 2026-01-23 14:59:19 +08:00
Committed by: GitHub
Parent: 193acc2c19
Commit: 78af0c30a3
25 changed files with 760 additions and 996 deletions

Diff excerpt: vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from typing import Optional
 import torch
 import triton # type: ignore
@@ -61,22 +60,19 @@ def split_qkv_rmsnorm_rope_kernel(
 for row_idx in tl.range(row_pid, batch_size, row_step):
 col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
 valid_mask = col_indices < q_hidden_size
-input_values = (tl.load(input_ptr + input_offset + col_indices,
-mask=valid_mask,
-other=0.0).to(tl.float32).reshape(
-Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
+input_values = (
+tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
+.to(tl.float32)
+.reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
+)
 squares = input_values * input_values
 variances = tl.sum(squares, axis=1) / HEAD_DIM
-reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
-Q_BLOCK_SIZE // HEAD_DIM, 1)
-normalized_values = (input_values * reciprocal_std
-) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
+reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1)
+normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
 if BIAS:
-normalized_values = (normalized_values * weight_values +
-bias_values).to(tl.bfloat16)
+normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
 else:
-normalized_values = (normalized_values * weight_values).to(
-tl.bfloat16)
+normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
 sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
 sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
@@ -93,8 +89,7 @@ def split_qkv_rmsnorm_rope_kernel(
 sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
 strides=(1, 1),
 )
-cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
-dtype=tl.bfloat16)
+cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
 cat_x = tl.insert_slice(
 cat_x,
 -x2,
@@ -127,22 +122,19 @@ def split_qkv_rmsnorm_rope_kernel(
 for row_idx in tl.range(row_pid, batch_size, row_step):
 col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
 valid_mask = col_indices < kv_hidden_size
-input_values = (tl.load(input_ptr + input_offset + col_indices,
-mask=valid_mask,
-other=0.0).to(tl.float32).reshape(
-KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
+input_values = (
+tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
+.to(tl.float32)
+.reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
+)
 squares = input_values * input_values
 variances = tl.sum(squares, axis=1) / HEAD_DIM
-reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
-KV_BLOCK_SIZE // HEAD_DIM, 1)
-normalized_values = (input_values * reciprocal_std
-) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
+reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1)
+normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
 if BIAS:
-normalized_values = (normalized_values * weight_values +
-bias_values).to(tl.bfloat16)
+normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
 else:
-normalized_values = (normalized_values * weight_values).to(
-tl.bfloat16)
+normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
 sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
 sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
 cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
@@ -158,8 +150,7 @@ def split_qkv_rmsnorm_rope_kernel(
 sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
 strides=(1, 1),
 )
-cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
-dtype=tl.bfloat16)
+cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
 cat_x = tl.insert_slice(
 cat_x,
 -x2,
@@ -189,12 +180,8 @@ def split_qkv_rmsnorm_rope_kernel(
 for _ in tl.range(row_pid, batch_size, row_step):
 col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
 valid_mask = col_indices < kv_hidden_size
-input_values = tl.load(input_ptr + input_offset + col_indices,
-mask=valid_mask,
-other=0.0)
-tl.store(v_ptr + output_offset + col_indices,
-input_values,
-mask=valid_mask)
+input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
+tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask)
 input_offset += input_offset_step
 output_offset += output_offset_step
@@ -209,27 +196,18 @@ def split_qkv_rmsnorm_rope_impl(
 kv_hidden_size: int,
 head_dim: int,
 eps: float,
-q_bias: Optional[torch.Tensor] = None,
-k_bias: Optional[torch.Tensor] = None,
+q_bias: torch.Tensor | None = None,
+k_bias: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
-assert KV_BLOCK_SIZE == head_dim
+assert head_dim == KV_BLOCK_SIZE
 assert q_hidden_size % kv_hidden_size == 0
 Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
 batch_size = input.shape[0]
 total_hidden_size = q_hidden_size + kv_hidden_size * 2
-q_output = torch.empty(batch_size,
-q_hidden_size,
-device=input.device,
-dtype=input.dtype)
-k_output = torch.empty(batch_size,
-kv_hidden_size,
-device=input.device,
-dtype=input.dtype)
-v_output = torch.empty(batch_size,
-kv_hidden_size,
-device=input.device,
-dtype=input.dtype)
+q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype)
+k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
+v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
 n_cols = kv_hidden_size // KV_BLOCK_SIZE
 num_vectorcore = get_vectorcore_num()
 assert num_vectorcore % n_cols == 0
@@ -271,8 +249,8 @@ def split_qkv_rmsnorm_rope_impl_fake(
 kv_hidden_size: int,
 head_dim: int,
 eps: float,
-q_bias: Optional[torch.Tensor] = None,
-k_bias: Optional[torch.Tensor] = None,
+q_bias: torch.Tensor | None = None,
+k_bias: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 # Fake implementation for shape inference during Dynamo/AOT tracing.
 # Note: sin and cos are not used in shape computation, but must be present in signature.
@@ -298,8 +276,10 @@ def split_qkv_rmsnorm_rope_impl_fake(
 return q_output, k_output, v_output
-direct_register_custom_op(op_name="qkv_rmsnorm_rope",
-op_func=split_qkv_rmsnorm_rope_impl,
-fake_impl=split_qkv_rmsnorm_rope_impl_fake,
-mutates_args=[],
-dispatch_key="PrivateUse1")
+direct_register_custom_op(
+op_name="qkv_rmsnorm_rope",
+op_func=split_qkv_rmsnorm_rope_impl,
+fake_impl=split_qkv_rmsnorm_rope_impl_fake,
+mutates_args=[],
+dispatch_key="PrivateUse1",
+)
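
Besides pure re-wrapping, the diff also drops `from typing import Optional` and rewrites the optional parameters as `torch.Tensor | None`. The two spellings are equivalent; the PEP 604 union form is native syntax on Python 3.10+ and needs no typing import. A minimal, standalone sketch (the `scale` helper is hypothetical and not part of the kernel code above):

```python
import torch


def scale(x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
    # `torch.Tensor | None` is equivalent to `Optional[torch.Tensor]` but needs
    # no import from typing (PEP 604, Python 3.10+).
    return x + bias if bias is not None else x


print(scale(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])
print(scale(torch.ones(2)))                 # tensor([1., 1.])
```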