### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -8,23 +8,24 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp
|
||||
|
||||
_CONDITIONS = ("seq7168", )
|
||||
_CONDITIONS = ("seq7168",)
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
@@ -85,28 +86,20 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
if USE_INITIAL_STATE:
|
||||
h0_ptr = h0 + i_nh * K * V
|
||||
ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1
|
||||
b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1,
|
||||
other=0.0).to(tl.float32)
|
||||
b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1, other=0.0).to(tl.float32)
|
||||
|
||||
ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1
|
||||
b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2,
|
||||
other=0.0).to(tl.float32)
|
||||
b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2, other=0.0).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
h_base = h + (boh + i_t) * H * K * V + i_h * K * V
|
||||
|
||||
p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_h1_bv1,
|
||||
b_h1_bv1.to(p_h1_bv1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
|
||||
tl.store(p_h1_bv1, b_h1_bv1.to(p_h1_bv1.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_h1_bv2,
|
||||
b_h1_bv2.to(p_h1_bv2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
|
||||
tl.store(p_h1_bv2, b_h1_bv2.to(p_h1_bv2.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None]
|
||||
offs_k_wv = tl.arange(0, 128)[None, :]
|
||||
@@ -117,8 +110,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
b_w = tl.load(ptr_w, mask=mask_w, other=0.0)
|
||||
|
||||
k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K
|
||||
p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT),
|
||||
(128, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT), (128, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
|
||||
v_new_base = v_new + bos * H * V + i_h * V
|
||||
@@ -144,12 +136,8 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
|
||||
(i_t * BT, v_start1), (BT, 64),
|
||||
(1, 0))
|
||||
tl.store(p_v_new1,
|
||||
b_v_new1.to(p_v_new1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start1), (BT, 64), (1, 0))
|
||||
tl.store(p_v_new1, b_v_new1.to(p_v_new1.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
b_v_new1 = b_v_new1 * b_g[:, None]
|
||||
@@ -165,12 +153,8 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
|
||||
(i_t * BT, v_start2), (BT, 64),
|
||||
(1, 0))
|
||||
tl.store(p_v_new2,
|
||||
b_v_new2.to(p_v_new2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start2), (BT, 64), (1, 0))
|
||||
tl.store(p_v_new2, b_v_new2.to(p_v_new2.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
b_v_new2 = b_v_new2 * b_g[:, None]
|
||||
@@ -183,29 +167,23 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
if STORE_FINAL_STATE:
|
||||
ht_ptr = ht + i_nh * K * V
|
||||
|
||||
p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv1,
|
||||
b_h1_bv1.to(p_ht1_bv1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv1, b_h1_bv1.to(p_ht1_bv1.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv2,
|
||||
b_h1_bv2.to(p_ht1_bv2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv2, b_h1_bv2.to(p_ht1_bv2.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
g: torch.Tensor | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||
@@ -213,8 +191,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None else None)
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
@@ -227,8 +204,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
final_state = (k.new_empty(N, H, K, V, dtype=torch.float32)
|
||||
if output_final_state else None)
|
||||
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
g = g.transpose(1, 2).contiguous()
|
||||
|
||||
Reference in New Issue
Block a user