### What this PR does / why we need it?
RFC #7394
Add a PyTorch implementation of the GDN gating operator on 310P.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Unit tests (UT).
- vLLM version: v0.17.0
- vLLM main: 4497431df6
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
import torch
|
|
|
|
|
|
def fused_gdn_gating_pytorch(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the GDN gating terms in plain PyTorch.

    This is a fallback implementation of ``fused_gdn_gating`` for 310P
    without Triton support. It evaluates::

        g           = -exp(A_log) * softplus(a + dt_bias)
        beta_output = sigmoid(b)

    where ``softplus`` switches to the identity once ``beta * x`` exceeds
    ``threshold``.

    Args:
        A_log: Log of the A parameter, shape [num_heads].
        a: a parameter, shape [batch, num_heads].
        b: b parameter, shape [batch, num_heads].
        dt_bias: dt bias, shape [num_heads].
        beta: softplus beta parameter. Must be non-zero (it is used as a
            divisor).
        threshold: softplus threshold parameter; elements with
            ``beta * x > threshold`` take the linear branch.

    Returns:
        A ``(g, beta_output)`` tuple:
            g: gating parameter in float32, shape [1, batch, num_heads].
            beta_output: ``sigmoid(b)`` cast back to ``b.dtype``,
                shape [1, batch, num_heads].

    Raises:
        ValueError: If ``a`` is not a 2-D tensor.
    """
    if a.dim() != 2:
        raise ValueError(
            f"a must be 2-D [batch, num_heads], got shape {tuple(a.shape)}"
        )

    # Keep nonlinear gating math in fp32 for stability.
    compute_dtype = torch.float32

    # dt_bias ([num_heads]) broadcasts over the batch dimension of a.
    x = a.to(compute_dtype) + dt_bias.to(compute_dtype)

    # Thresholded softplus. torch.where evaluates both branches for every
    # element, so clamp the exp() argument: lanes that end up on the linear
    # branch would otherwise overflow to inf and yield NaN gradients, even
    # though their forward value is discarded.
    beta_x = beta * x
    softplus_x = torch.where(
        beta_x <= threshold,
        (1.0 / beta) * torch.log1p(torch.exp(torch.clamp(beta_x, max=threshold))),
        x,
    )

    # g = -exp(A_log) * softplus(x); A_log ([num_heads]) broadcasts over
    # batch. unsqueeze(0) adds the leading sequence dimension.
    g = (-torch.exp(A_log.to(compute_dtype)) * softplus_x).unsqueeze(0)

    # Match Triton kernel: sigmoid in fp32, then cast to input b dtype.
    beta_output = torch.sigmoid(b.to(compute_dtype)).to(b.dtype).unsqueeze(0)

    return g, beta_output
|