[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #12) (#6177)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-23 14:59:19 +08:00
committed by GitHub
parent 193acc2c19
commit 78af0c30a3
25 changed files with 760 additions and 996 deletions

View File

@@ -14,9 +14,10 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from vllm.triton_utils import tl, triton
import torch
from typing import Tuple
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@@ -48,16 +49,16 @@ def _triton_rope(
This triton kernel applies rotary embedding on q and k.
It supports rope_dim != head_dim scenario.
It supports both neox style and non-neox style rope computation.
Input tensor layout assumptions:
q size: (num_tokens, num_q_heads, head_dim)
q stride: (num_q_heads * head_dim, head_dim, 1)
k size: (num_tokens, num_kv_heads, head_dim)
k stride: (num_kv_heads * head_dim, head_dim, 1)
cos/sin size: (num_tokens, rope_dim/2)
cos/sin stride: (rope_dim/2, 1)
Different compute pattern of IS_NEOX_STYLE:
if IS_NEOX_STYLE:
@@ -88,10 +89,8 @@ def _triton_rope(
cos_offsets = tl.arange(0, pad_rope_dim // 2)
cos_mask = cos_offsets < (rope_dim // 2)
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask,
other=0).to(tl.float32)
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask,
other=0).to(tl.float32)
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
# ####################################################################
# Load the left and right half of q and k for the current
@@ -99,28 +98,20 @@ def _triton_rope(
# ####################################################################
# left half of the head
if IS_NEOX_STYLE:
first_half_q_offsets = tl.arange(
0, pad_n_qh)[:, None] * hd + tl.arange(
0, pad_rope_dim // 2)[None, :]
first_half_k_offsets = tl.arange(
0, pad_n_kh)[:, None] * hd + tl.arange(
0, pad_rope_dim // 2)[None, :]
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
else:
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (
2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (
2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets,
mask=first_q_mask,
other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets,
mask=first_k_mask,
other=0).to(sin_row.dtype)
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
)
q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
# right half of the head
if IS_NEOX_STYLE:
@@ -131,41 +122,29 @@ def _triton_rope(
second_half_k_offsets = first_half_k_offsets + 1
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets,
mask=second_q_mask,
other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets,
mask=second_k_mask,
other=0).to(sin_row.dtype)
q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_start_ptr + first_half_q_offsets,
new_q_tile_1,
mask=first_q_mask)
tl.store(q_start_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_start_ptr + second_half_q_offsets,
new_q_tile_2,
mask=second_q_mask)
tl.store(q_start_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_start_ptr + first_half_k_offsets,
new_k_tile_1,
mask=first_k_mask)
tl.store(k_start_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_start_ptr + second_half_k_offsets,
new_k_tile_2,
mask=second_k_mask)
tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
def rope_forward_triton(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rope_dim: int = -1,
is_neox_style: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rope_dim: int = -1,
is_neox_style: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
if not q.is_contiguous():
q = q.contiguous()
if not k.is_contiguous():
@@ -187,7 +166,7 @@ def rope_forward_triton(
num_vectorcore = get_vectorcore_num()
n_row = min(num_tokens, num_vectorcore)
_triton_rope[(n_row, )](
_triton_rope[(n_row,)](
q,
q.stride(0),
k,