[310p] Add a PyTorch implementation of the GDN gating operator on 310P (#7430)
### What this PR does / why we need it?
RFC #7394
Add a PyTorch implementation of the GDN gating operator on 310P.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Unit test (UT) that checks the PyTorch implementation against the Triton kernel output.
- vLLM version: v0.17.0
- vLLM main:
4497431df6
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
@@ -0,0 +1,51 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm_ascend._310p.ops.fla.fused_gdn_gating import fused_gdn_gating_pytorch
|
||||||
|
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||||
|
|
||||||
|
|
||||||
|
def test_fused_gdn_gating_310p_parity_precision():
    """Check that the PyTorch GDN gating matches the Triton-path implementation.

    Both implementations receive the same seeded float16 inputs created on
    the "npu" device; each pair of outputs is compared in float32 on the
    CPU with rtol = atol = 1e-2.
    """
    init_device_properties_triton()
    torch.manual_seed(0)

    token_count, head_count = 37, 8
    fp16_on_npu = {"dtype": torch.float16, "device": "npu"}

    # Draw order (A_log, dt_bias, a, b) is kept fixed so the seeded values
    # stay the same from run to run.
    gating_inputs = {
        "A_log": torch.randn(head_count, **fp16_on_npu),
        "dt_bias": torch.randn(head_count, **fp16_on_npu),
        "a": torch.randn(token_count, head_count, **fp16_on_npu),
        "b": torch.randn(token_count, head_count, **fp16_on_npu),
    }
    softplus_args = {"beta": 1.0, "threshold": 20.0}

    triton_g, triton_beta = fused_gdn_gating_patch(**gating_inputs, **softplus_args)
    ref_g, ref_beta = fused_gdn_gating_pytorch(**gating_inputs, **softplus_args)

    # Compare g first, then beta, after upcasting and moving to the host.
    for actual, expected in ((triton_g, ref_g), (triton_beta, ref_beta)):
        torch.testing.assert_close(
            actual.to(torch.float32).cpu(),
            expected.to(torch.float32).cpu(),
            rtol=1e-2,
            atol=1e-2,
            equal_nan=True,
        )
|
||||||
3
vllm_ascend/_310p/ops/fla/__init__.py
Normal file
3
vllm_ascend/_310p/ops/fla/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .fused_gdn_gating import fused_gdn_gating_pytorch
|
||||||
|
|
||||||
|
__all__ = ["fused_gdn_gating_pytorch"]
|
||||||
62
vllm_ascend/_310p/ops/fla/fused_gdn_gating.py
Normal file
62
vllm_ascend/_310p/ops/fla/fused_gdn_gating.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def fused_gdn_gating_pytorch(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the GDN gating terms with plain PyTorch ops.

    Fallback for 310P, where the Triton kernel is not available.

    Evaluates ``g = -exp(A_log) * softplus(a + dt_bias)`` and
    ``beta_output = sigmoid(b)``. All math runs in float32.

    Args:
        A_log: Log of the A parameter, shape [num_heads].
        a: a parameter, shape [batch, num_heads].
        b: b parameter, shape [batch, num_heads].
        dt_bias: dt bias, shape [num_heads].
        beta: softplus beta (slope) parameter.
        threshold: softplus threshold; where ``beta * x`` exceeds it the
            softplus is replaced by ``x`` itself.

    Returns:
        g: gating parameter in float32, shape [1, batch, num_heads].
        beta_output: sigmoid(b) cast to ``b.dtype``,
            shape [1, batch, num_heads].
    """
    fp32 = torch.float32
    # Two-way unpacking rejects a non-2-D ``a``; only the row count is used.
    rows, _ = a.shape

    def _per_head_fp32(vec: torch.Tensor) -> torch.Tensor:
        # Upcast a per-head vector and view it as one row per batch entry.
        return vec.to(fp32).unsqueeze(0).expand(rows, -1)

    log_decay = _per_head_fp32(A_log)
    gate_logits = b.to(fp32)
    pre_activation = a.to(fp32) + _per_head_fp32(dt_bias)

    # Softplus with a linear region above the threshold. The log1p/exp
    # branch is evaluated everywhere; torch.where selects per element.
    scaled = beta * pre_activation
    inv_beta = 1.0 / beta
    smooth = inv_beta * torch.log1p(torch.exp(scaled))
    softplus = torch.where(scaled <= threshold, smooth, pre_activation)

    # g stays in float32; a leading sequence dimension is prepended.
    g = (-torch.exp(log_decay) * softplus).unsqueeze(0)

    # Match the Triton kernel: sigmoid in fp32, then cast to b's dtype.
    beta_output = torch.sigmoid(gate_logits).to(b.dtype).unsqueeze(0)

    return g, beta_output
|
||||||
Reference in New Issue
Block a user