xc-llm-ascend/vllm_ascend/_310p/ops/fla/fused_gdn_gating.py
Shaoxu Cheng 13397e9cb7 [310p] Add a PyTorch implementation of the GDN gating operator on 310P (#7430)
### What this PR does / why we need it?
RFC #7394
Add a PyTorch implementation of the GDN gating operator on 310P.

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
UT

- vLLM version: v0.17.0
- vLLM main: 4497431df6

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
2026-03-23 20:26:39 +08:00


import torch


def fused_gdn_gating_pytorch(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    PyTorch implementation of fused_gdn_gating.

    This is a fallback implementation for 310P without Triton support.

    Args:
        A_log: Log of A parameter, shape [num_heads]
        a: a parameter, shape [batch, num_heads]
        b: b parameter, shape [batch, num_heads]
        dt_bias: dt bias, shape [num_heads]
        beta: softplus beta parameter
        threshold: softplus threshold parameter

    Returns:
        g: gating parameter, shape [1, batch, num_heads]
        beta_output: sigmoid(b), shape [1, batch, num_heads]
    """
    batch, num_heads = a.shape
    del num_heads

    # Keep nonlinear gating math in fp32 for stability.
    compute_dtype = torch.float32
    A_log_f = A_log.to(compute_dtype)
    a_f = a.to(compute_dtype)
    b_f = b.to(compute_dtype)
    dt_bias_f = dt_bias.to(compute_dtype)

    # Expand A_log and dt_bias to match the shape of a.
    A_log_expanded = A_log_f.unsqueeze(0).expand(batch, -1)
    dt_bias_expanded = dt_bias_f.unsqueeze(0).expand(batch, -1)

    # Compute x = a + dt_bias.
    x = a_f + dt_bias_expanded

    # Compute softplus(x) = (1 / beta) * log1p(exp(beta * x)),
    # reverting to the identity above the threshold to avoid overflow.
    beta_x = beta * x
    softplus_x = torch.where(
        beta_x <= threshold,
        (1.0 / beta) * torch.log1p(torch.exp(beta_x)),
        x,
    )

    # Compute g = -exp(A_log) * softplus(x).
    g = -torch.exp(A_log_expanded) * softplus_x

    # Add the leading sequence dimension.
    g = g.unsqueeze(0)

    # Match the Triton kernel: sigmoid in fp32, then cast back to the dtype of b.
    beta_output = torch.sigmoid(b_f).to(b.dtype)
    beta_output = beta_output.unsqueeze(0)

    return g, beta_output
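
For reference, a minimal usage sketch is shown below. The tensor shapes follow the docstring above; the concrete sizes, dtypes, and the cross-check against torch.nn.functional.softplus are illustrative assumptions and are not part of the original operator file.

import torch
import torch.nn.functional as F

# Hypothetical usage sketch (not part of the operator file): random inputs
# with the shapes documented above, plus an illustrative reference check.
batch, num_heads = 4, 8
A_log = torch.randn(num_heads)
a = torch.randn(batch, num_heads, dtype=torch.float16)
b = torch.randn(batch, num_heads, dtype=torch.float16)
dt_bias = torch.randn(num_heads)

g, beta_output = fused_gdn_gating_pytorch(A_log, a, b, dt_bias)
assert g.shape == (1, batch, num_heads)
assert beta_output.shape == (1, batch, num_heads)

# g should agree with -exp(A_log) * softplus(a + dt_bias) computed in fp32.
ref_g = -torch.exp(A_log) * F.softplus(a.float() + dt_bias, beta=1.0, threshold=20.0)
torch.testing.assert_close(g.squeeze(0), ref_g)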