[310P]fused recurrent gated delta rule pytorch core and ut (#7398)
### What this PR does / why we need it?
RFC https://github.com/vllm-project/vllm-ascend/issues/7394
Add a PyTorch implementation of the fused recurrent gated delta rule on
310P.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
UT
- vLLM version: v0.17.0
- vLLM main:
4497431df6
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -0,0 +1,79 @@
|
|||||||
|
import torch
|
||||||
|
from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule
|
||||||
|
|
||||||
|
from vllm_ascend._310p.ops.fla.fused_recurrent_gated_delta_rule import fused_recurrent_gated_delta_rule_pytorch
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||||
|
|
||||||
|
|
||||||
|
def test_fused_recurrent_gated_delta_rule_310p_parity_precision():
    """Check parity between the Ascend triton kernel and the 310P PyTorch fallback.

    Both implementations receive identical random inputs (varlen layout: one
    packed batch row holding two sequences of 4 and 5 tokens, grouped value
    heads) and their outputs/final states must agree within fp16 tolerances.
    """
    init_device_properties_triton()
    torch.manual_seed(0)
    device = "npu"

    # Problem sizes: 2 qk heads shared across 4 value heads (GVA grouping).
    bsz = 1
    total_tokens = 9
    num_qk_heads = 2
    num_v_heads = 4
    kdim = 64
    vdim = 48

    q = torch.randn(bsz, total_tokens, num_qk_heads, kdim, dtype=torch.float16, device=device)
    k = torch.randn(bsz, total_tokens, num_qk_heads, kdim, dtype=torch.float16, device=device)
    v = torch.randn(bsz, total_tokens, num_v_heads, vdim, dtype=torch.float16, device=device)
    g = torch.randn(bsz, total_tokens, num_v_heads, dtype=torch.float32, device=device)
    beta = torch.sigmoid(
        torch.randn(bsz, total_tokens, num_v_heads, dtype=torch.float32, device=device)
    ).to(torch.float16)

    initial_state = torch.randn(2, num_v_heads, kdim, vdim, dtype=torch.float16, device=device)
    cu_seqlens = torch.tensor([0, 4, 9], dtype=torch.long, device=device)
    # For inplace_final_state=True, Ascend triton kernel expects explicit per-token state indices.
    # seq0 (len=4) -> state 0, seq1 (len=5) -> state 1.
    ssm_state_indices = torch.tensor(
        [
            [0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1],
        ],
        dtype=torch.long,
        device=device,
    )

    # Both calls share everything except a private clone of the state buffer,
    # since inplace_final_state=True mutates it.
    shared_kwargs = dict(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        inplace_final_state=True,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )
    triton_out, triton_state = fused_recurrent_gated_delta_rule(
        initial_state=initial_state.clone(), **shared_kwargs
    )
    ref_out, ref_state = fused_recurrent_gated_delta_rule_pytorch(
        initial_state=initial_state.clone(), **shared_kwargs
    )

    # Compare on CPU in fp32 so tolerances are applied uniformly.
    for actual, expected in ((triton_out, ref_out), (triton_state, ref_state)):
        torch.testing.assert_close(
            actual.to(torch.float32).cpu(),
            expected.to(torch.float32).cpu(),
            rtol=1e-2,
            atol=1e-2,
            equal_nan=True,
        )
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
from .fused_gdn_gating import fused_gdn_gating_pytorch
|
from .fused_gdn_gating import fused_gdn_gating_pytorch
|
||||||
|
from .fused_recurrent_gated_delta_rule import fused_recurrent_gated_delta_rule_pytorch
|
||||||
|
|
||||||
__all__ = ["fused_gdn_gating_pytorch"]
|
__all__ = [
|
||||||
|
"fused_gdn_gating_pytorch",
|
||||||
|
"fused_recurrent_gated_delta_rule_pytorch",
|
||||||
|
]
|
||||||
|
|||||||
225
vllm_ascend/_310p/ops/fla/fused_recurrent_gated_delta_rule.py
Normal file
225
vllm_ascend/_310p/ops/fla/fused_recurrent_gated_delta_rule.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_l2norm(x: torch.Tensor, enabled: bool) -> torch.Tensor:
|
||||||
|
if not enabled:
|
||||||
|
return x
|
||||||
|
return x / (torch.sqrt(torch.sum(x * x, dim=-1, keepdim=True)) + 1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_to_hv(x: torch.Tensor, hv: int) -> torch.Tensor:
|
||||||
|
"""Expand [H, ...] to [HV, ...] for grouped-value-attention semantics."""
|
||||||
|
h = x.shape[0]
|
||||||
|
if h == hv:
|
||||||
|
return x
|
||||||
|
if hv % h != 0:
|
||||||
|
raise ValueError(f"Cannot expand head dim from {h} to {hv}.")
|
||||||
|
return x.repeat_interleave(hv // h, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_num_states(
|
||||||
|
default_n: int,
|
||||||
|
initial_state: torch.Tensor | None,
|
||||||
|
ssm_state_indices: torch.Tensor | None,
|
||||||
|
) -> int:
|
||||||
|
if initial_state is not None:
|
||||||
|
return initial_state.shape[0]
|
||||||
|
if ssm_state_indices is None:
|
||||||
|
return default_n
|
||||||
|
nonneg = ssm_state_indices[ssm_state_indices >= 0]
|
||||||
|
if nonneg.numel() == 0:
|
||||||
|
return default_n
|
||||||
|
return int(nonneg.max().item()) + 1
|
||||||
|
|
||||||
|
|
||||||
|
def _state_index(
|
||||||
|
seq_idx: int,
|
||||||
|
tok_idx: int,
|
||||||
|
ssm_state_indices: torch.Tensor | None,
|
||||||
|
) -> int:
|
||||||
|
if ssm_state_indices is None:
|
||||||
|
return seq_idx
|
||||||
|
if ssm_state_indices.ndim == 1:
|
||||||
|
return int(ssm_state_indices[seq_idx].item())
|
||||||
|
return int(ssm_state_indices[seq_idx, tok_idx].item())
|
||||||
|
|
||||||
|
|
||||||
|
def _run_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor | None,
    beta: torch.Tensor | None,
    states: torch.Tensor,
    scale: float,
    cu_seqlens: torch.Tensor | None,
    ssm_state_indices: torch.Tensor | None,
    num_accepted_tokens: torch.Tensor | None,
    use_initial_state: bool,
    use_qk_l2norm_in_kernel: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference PyTorch recurrence for GDN delta rule.

    Shapes follow fla.ops conventions:
      q,k: [B, T, H, K]
      v: [B, T, HV, V]
      g,beta: [B, T, HV] (beta may also be [B, T, HV, V])
      states: [N_state, HV, K, V]

    Per token: decay the running state by exp(g), form the delta-rule error
    v - h @ k, scale it by beta, add the rank-1 update, and read the output
    with q. The running state is carried in fp32 and written back into
    ``states`` (in place) at the slot resolved via ``ssm_state_indices``.

    Returns (out, states) where ``out`` has the shape/dtype of ``v``.
    """
    B, T, _, Kdim = k.shape
    HV = v.shape[2]
    Vdim = v.shape[-1]

    # Varlen (cu_seqlens) mode packs all sequences into a single batch row.
    if cu_seqlens is not None and B != 1:
        raise ValueError("Variable-length mode expects batch size B=1.")

    out = torch.zeros_like(v)

    # Build (seq_idx, start, end) token ranges: one per batch row in dense
    # mode, or one per cu_seqlens segment in varlen mode.
    if cu_seqlens is None:
        seq_ranges = [(i, 0, T) for i in range(B)]
    else:
        n_seq = len(cu_seqlens) - 1
        seq_ranges = [
            (
                i,
                int(cu_seqlens[i].item()),
                int(cu_seqlens[i + 1].item()),
            )
            for i in range(n_seq)
        ]

    for seq_idx, start, end in seq_ranges:
        seq_len = end - start
        if seq_len <= 0:
            continue

        # Speculative decoding: only process the accepted prefix of the tokens.
        accepted = None
        if num_accepted_tokens is not None:
            accepted = int(num_accepted_tokens[seq_idx].item())
            seq_len = min(seq_len, accepted)
            if seq_len <= 0:
                continue

        if use_initial_state:
            if ssm_state_indices is None:
                init_state_idx = seq_idx
            else:
                # With acceptance counts, the initial state is read from the
                # slot of the last accepted token; otherwise from token 0.
                init_tok = (accepted - 1) if accepted is not None else 0
                init_state_idx = _state_index(seq_idx, init_tok, ssm_state_indices)
            if init_state_idx < 0:
                # Match triton behavior for invalid PAD_SLOT_ID in continuous batching.
                continue
            if init_state_idx >= states.shape[0]:
                raise IndexError(f"state_idx {init_state_idx} out of range for states size {states.shape[0]}")
            # Running state is kept transposed as [HV, V, K] in fp32 so that
            # k/q contractions reduce over the trailing K axis.
            h_t = states[init_state_idx].transpose(-1, -2).to(torch.float32)
        else:
            h_t = torch.zeros(HV, Vdim, Kdim, dtype=torch.float32, device=q.device)

        for rel_t in range(seq_len):
            tok = start + rel_t

            # Dense mode indexes by batch row; varlen mode always uses row 0.
            if cu_seqlens is None:
                q_t = q[seq_idx, tok]
                k_t = k[seq_idx, tok]
                v_t = v[seq_idx, tok]
                g_t = g[seq_idx, tok] if g is not None else None
                beta_t = beta[seq_idx, tok] if beta is not None else None
            else:
                q_t = q[0, tok]
                k_t = k[0, tok]
                v_t = v[0, tok]
                g_t = g[0, tok] if g is not None else None
                beta_t = beta[0, tok] if beta is not None else None

            # Match Triton kernel math: load to fp32 first, then apply l2norm.
            q_t = q_t.to(torch.float32)
            k_t = k_t.to(torch.float32)
            q_t = _maybe_l2norm(q_t, use_qk_l2norm_in_kernel)
            k_t = _maybe_l2norm(k_t, use_qk_l2norm_in_kernel)
            v_t = v_t.to(torch.float32)
            q_t = q_t * scale

            # Broadcast the H qk heads up to the HV value heads (GVA).
            q_hv = _expand_to_hv(q_t, HV)
            k_hv = _expand_to_hv(k_t, HV)

            # Gated decay: h <- exp(g) * h, with g broadcast per value head.
            if g_t is not None:
                g_t = g_t.to(torch.float32)
                if g_t.ndim == 0:
                    g_t = g_t.expand(HV)
                elif g_t.shape[0] != HV:
                    g_t = _expand_to_hv(g_t.unsqueeze(-1), HV).squeeze(-1)
                h_t = h_t * torch.exp(g_t).view(HV, 1, 1)

            # Delta-rule error: subtract the state's current prediction h @ k.
            v_t = v_t - torch.sum(h_t * k_hv.unsqueeze(-2), dim=-1)

            # Write strength beta: per-head ([HV]) or per-element ([HV, V]).
            if beta_t is not None:
                beta_t = beta_t.to(torch.float32)
                if beta_t.ndim == 1:
                    if beta_t.shape[0] != HV:
                        beta_t = _expand_to_hv(beta_t.unsqueeze(-1), HV).squeeze(-1)
                    v_t = v_t * beta_t.view(HV, 1)
                else:
                    if beta_t.shape[0] != HV:
                        beta_t = _expand_to_hv(beta_t, HV)
                    v_t = v_t * beta_t

            # Rank-1 state update, then read out with q: o = h @ q.
            h_t = h_t + v_t.unsqueeze(-1) * k_hv.unsqueeze(-2)
            o_t = torch.sum(h_t * q_hv.unsqueeze(-2), dim=-1)

            if cu_seqlens is None:
                out[seq_idx, tok] = o_t.to(out.dtype)
            else:
                out[0, tok] = o_t.to(out.dtype)

            # Persist the state every step at this token's slot (negative slot
            # = padding, skipped); transpose back to the [K, V] storage layout.
            state_idx = _state_index(seq_idx, rel_t, ssm_state_indices)
            if state_idx >= 0:
                if state_idx >= states.shape[0]:
                    raise IndexError(f"state_idx {state_idx} out of range for states size {states.shape[0]}")
                states[state_idx] = h_t.transpose(-1, -2).to(states.dtype)

    return out, states
|
||||||
|
|
||||||
|
|
||||||
|
def fused_recurrent_gated_delta_rule_pytorch(
    q,
    k,
    v,
    g,
    beta,
    initial_state=None,
    inplace_final_state=False,
    cu_seqlens=None,
    ssm_state_indices=None,
    num_accepted_tokens=None,
    use_qk_l2norm_in_kernel=False,
):
    """PyTorch fallback for fused_recurrent_gated_delta_rule."""
    batch = k.shape[0]
    key_dim = k.shape[-1]
    num_v_heads = v.shape[2]
    val_dim = v.shape[-1]
    # Number of sequences: batch rows in dense mode, segments in varlen mode.
    default_n = batch if cu_seqlens is None else len(cu_seqlens) - 1

    slot_count = _infer_num_states(default_n, initial_state, ssm_state_indices)
    if initial_state is None:
        # No caller buffer: start from zeros sized to the referenced slots.
        states = torch.zeros(
            slot_count, num_v_heads, key_dim, val_dim, dtype=q.dtype, device=q.device
        )
    else:
        # Reuse the caller's storage only when in-place finalization is requested.
        states = initial_state if inplace_final_state else initial_state.clone()

    return _run_recurrent_gated_delta_rule(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        states=states,
        scale=key_dim**-0.5,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        use_initial_state=initial_state is not None,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
    )
|
||||||
Reference in New Issue
Block a user