[310P]fused recurrent gated delta rule pytorch core and ut (#7398)

### What this PR does / why we need it?
RFC https://github.com/vllm-project/vllm-ascend/issues/7394
Add a PyTorch implementation of the fused recurrent gated delta rule on
310P.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
UT
- vLLM version: v0.17.0
- vLLM main:
4497431df6

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
Shaoxu Cheng
2026-03-25 08:53:14 +08:00
committed by GitHub
parent 54879467c4
commit 3f4087a8f0
3 changed files with 309 additions and 1 deletions

View File

@@ -0,0 +1,79 @@
import torch
from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule
from vllm_ascend._310p.ops.fla.fused_recurrent_gated_delta_rule import fused_recurrent_gated_delta_rule_pytorch
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
def test_fused_recurrent_gated_delta_rule_310p_parity_precision():
    """Check that the 310P PyTorch reference implementation matches the
    Triton fused recurrent gated delta rule kernel within fp16 tolerances."""
    init_device_properties_triton()
    torch.manual_seed(0)
    device = "npu"

    # Two variable-length sequences (4 and 5 tokens) packed into one batch row.
    batch_size = 1
    total_tokens = 9
    num_qk_heads = 2
    num_v_heads = 4
    head_k_dim = 64
    head_v_dim = 48

    q = torch.randn(batch_size, total_tokens, num_qk_heads, head_k_dim,
                    dtype=torch.float16, device=device)
    k = torch.randn(batch_size, total_tokens, num_qk_heads, head_k_dim,
                    dtype=torch.float16, device=device)
    v = torch.randn(batch_size, total_tokens, num_v_heads, head_v_dim,
                    dtype=torch.float16, device=device)
    g = torch.randn(batch_size, total_tokens, num_v_heads,
                    dtype=torch.float32, device=device)
    # Gating strengths in (0, 1), cast down to the kernel's fp16 input dtype.
    beta_logits = torch.randn(batch_size, total_tokens, num_v_heads,
                              dtype=torch.float32, device=device)
    beta = torch.sigmoid(beta_logits).to(torch.float16)

    # One recurrent state per sequence (2 sequences total).
    initial_state = torch.randn(2, num_v_heads, head_k_dim, head_v_dim,
                                dtype=torch.float16, device=device)
    cu_seqlens = torch.tensor([0, 4, 9], dtype=torch.long, device=device)

    # For inplace_final_state=True, Ascend triton kernel expects explicit per-token state indices.
    # seq0 (len=4) -> state 0, seq1 (len=5) -> state 1.
    ssm_state_indices = torch.tensor(
        [
            [0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1],
        ],
        dtype=torch.long,
        device=device,
    )

    # Each path gets its own clone of the initial state since both update it in place.
    common_kwargs = dict(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        inplace_final_state=True,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )
    triton_out, triton_state = fused_recurrent_gated_delta_rule(
        initial_state=initial_state.clone(), **common_kwargs
    )
    ref_out, ref_state = fused_recurrent_gated_delta_rule_pytorch(
        initial_state=initial_state.clone(), **common_kwargs
    )

    # Compare in fp32 on CPU; fp16 accumulation justifies the 1e-2 tolerances.
    for actual, expected in ((triton_out, ref_out), (triton_state, ref_state)):
        torch.testing.assert_close(
            actual.to(torch.float32).cpu(),
            expected.to(torch.float32).cpu(),
            rtol=1e-2,
            atol=1e-2,
            equal_nan=True,
        )