From 13397e9cb75721f7365cfc4c22738320aac21d54 Mon Sep 17 00:00:00 2001
From: Shaoxu Cheng <2906339855@qq.com>
Date: Mon, 23 Mar 2026 20:26:39 +0800
Subject: [PATCH] [310p] Add a PyTorch implementation of the GDN gating
 operator on 310P (#7430)

### What this PR does / why we need it?
RFC #7394

Add a PyTorch implementation of the GDN gating operator on 310P.

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
UT

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4497431df654e46fb1fb5e64bf8611e762ae5d87

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
---
 .../triton/test_fused_gdn_gating.py           | 51 +++++++++++++++
 vllm_ascend/_310p/ops/fla/__init__.py         |  3 +
 vllm_ascend/_310p/ops/fla/fused_gdn_gating.py | 62 +++++++++++++++++++
 3 files changed, 116 insertions(+)
 create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_gdn_gating.py
 create mode 100644 vllm_ascend/_310p/ops/fla/__init__.py
 create mode 100644 vllm_ascend/_310p/ops/fla/fused_gdn_gating.py

diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_gdn_gating.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_gdn_gating.py
new file mode 100644
index 00000000..c42a1355
--- /dev/null
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_gdn_gating.py
@@ -0,0 +1,51 @@
+import torch
+
+from vllm_ascend._310p.ops.fla.fused_gdn_gating import fused_gdn_gating_pytorch
+from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
+from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+
+
+def test_fused_gdn_gating_310p_parity_precision():
+    init_device_properties_triton()
+    torch.manual_seed(0)
+    device = "npu"
+
+    num_tokens = 37
+    num_heads = 8
+
+    A_log = torch.randn(num_heads, dtype=torch.float16, device=device)
+    dt_bias = torch.randn(num_heads, dtype=torch.float16, device=device)
+    a = torch.randn(num_tokens, num_heads, dtype=torch.float16, device=device)
+    b = torch.randn(num_tokens, num_heads, dtype=torch.float16, device=device)
+
+    triton_g, triton_beta = fused_gdn_gating_patch(
+        A_log=A_log,
+        a=a,
+        b=b,
+        dt_bias=dt_bias,
+        beta=1.0,
+        threshold=20.0,
+    )
+    ref_g, ref_beta = fused_gdn_gating_pytorch(
+        A_log=A_log,
+        a=a,
+        b=b,
+        dt_bias=dt_bias,
+        beta=1.0,
+        threshold=20.0,
+    )
+
+    torch.testing.assert_close(
+        triton_g.to(torch.float32).cpu(),
+        ref_g.to(torch.float32).cpu(),
+        rtol=1e-2,
+        atol=1e-2,
+        equal_nan=True,
+    )
+    torch.testing.assert_close(
+        triton_beta.to(torch.float32).cpu(),
+        ref_beta.to(torch.float32).cpu(),
+        rtol=1e-2,
+        atol=1e-2,
+        equal_nan=True,
+    )
diff --git a/vllm_ascend/_310p/ops/fla/__init__.py b/vllm_ascend/_310p/ops/fla/__init__.py
new file mode 100644
index 00000000..1104d3ff
--- /dev/null
+++ b/vllm_ascend/_310p/ops/fla/__init__.py
@@ -0,0 +1,3 @@
+from .fused_gdn_gating import fused_gdn_gating_pytorch
+
+__all__ = ["fused_gdn_gating_pytorch"]
diff --git a/vllm_ascend/_310p/ops/fla/fused_gdn_gating.py b/vllm_ascend/_310p/ops/fla/fused_gdn_gating.py
new file mode 100644
index 00000000..8442648c
--- /dev/null
+++ b/vllm_ascend/_310p/ops/fla/fused_gdn_gating.py
@@ -0,0 +1,62 @@
+import torch
+
+
+def fused_gdn_gating_pytorch(
+    A_log: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    dt_bias: torch.Tensor,
+    beta: float = 1.0,
+    threshold: float = 20.0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    PyTorch implementation of fused_gdn_gating.
+    This is a fallback implementation for 310P without Triton support.
+
+    Args:
+        A_log: Log of A parameter, shape [num_heads]
+        a: a parameter, shape [batch, num_heads]
+        b: b parameter, shape [batch, num_heads]
+        dt_bias: dt bias, shape [num_heads]
+        beta: softplus beta parameter
+        threshold: softplus threshold parameter
+
+    Returns:
+        g: gating parameter, shape [1, batch, num_heads]
+        beta_output: sigmoid(b), shape [1, batch, num_heads]
+    """
+    batch, num_heads = a.shape
+    del num_heads
+    # Keep nonlinear gating math in fp32 for stability.
+    compute_dtype = torch.float32
+    A_log_f = A_log.to(compute_dtype)
+    a_f = a.to(compute_dtype)
+    b_f = b.to(compute_dtype)
+    dt_bias_f = dt_bias.to(compute_dtype)
+
+    # Expand A_log and dt_bias to match a shape.
+    A_log_expanded = A_log_f.unsqueeze(0).expand(batch, -1)
+    dt_bias_expanded = dt_bias_f.unsqueeze(0).expand(batch, -1)
+
+    # Compute x = a + dt_bias.
+    x = a_f + dt_bias_expanded
+
+    # Compute softplus(x).
+    beta_x = beta * x
+    softplus_x = torch.where(
+        beta_x <= threshold,
+        (1.0 / beta) * torch.log1p(torch.exp(beta_x)),
+        x,
+    )
+
+    # Compute g = -exp(A_log) * softplus(x).
+    g = -torch.exp(A_log_expanded) * softplus_x
+
+    # Add sequence dimension.
+    g = g.unsqueeze(0)
+
+    # Match Triton kernel: sigmoid in fp32, then cast to input b dtype.
+    beta_output = torch.sigmoid(b_f).to(b.dtype)
+    beta_output = beta_output.unsqueeze(0)
+
+    return g, beta_output
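
For reviewers, a minimal sketch of how the new fallback is exercised end to end. Shapes mirror the unit test above; running on CPU instead of `npu` is an assumption for illustration only, since `fused_gdn_gating_pytorch` is plain device-agnostic PyTorch:

```python
import torch

from vllm_ascend._310p.ops.fla.fused_gdn_gating import fused_gdn_gating_pytorch

# Small smoke run; shapes follow the unit test, but the device is swapped
# to CPU (an assumption for illustration; the real path runs on a 310P NPU).
num_tokens, num_heads = 4, 8
A_log = torch.randn(num_heads)
dt_bias = torch.randn(num_heads)
a = torch.randn(num_tokens, num_heads)
b = torch.randn(num_tokens, num_heads)

g, beta_out = fused_gdn_gating_pytorch(A_log, a, b, dt_bias, beta=1.0, threshold=20.0)

# Both outputs gain a leading sequence dimension, per the docstring.
assert g.shape == (1, num_tokens, num_heads)
assert beta_out.shape == (1, num_tokens, num_heads)
```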