From d060c797ed8fd22a1ecdf86c5e26f6eeaf7531f2 Mon Sep 17 00:00:00 2001 From: lhp-deep Date: Mon, 9 Feb 2026 16:43:27 +0800 Subject: [PATCH] [fix bug] fix tensor mismatch bug in sigmoid operator test case (#6619) ### What this PR does / why we need it? This PR fixes a bug in the `test_triton_fusion_ops` test case. The test compares a fused kernel (`fused_sigmoid_gating_delta_rule_update`) with a split implementation. Both paths use a recurrent state tensor. The bug was that the state tensor was being modified in-place by the fused kernel call, and this modified tensor was then reused for the split implementation path. This led to an incorrect comparison and test failure. This fix ensures that each path starts with an identical, clean initial state by creating separate tensors. It also changes the state initialization from `torch.randn` to `torch.ones` to make the test deterministic. ### Does this PR introduce _any_ user-facing change? No, this change only affects a test case and has no user-facing impact. ### How was this patch tested? The fix is applied directly to the test case. The CI passing for `test_fused_sigmoid_gating_delta_rule.py` will confirm that the fix is working as expected. 
- vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d7e17aaacd5ed1b4b4be6bcfef3a1b7cbc84fc9a Signed-off-by: lhp-deep --- .../triton/test_fused_sigmoid_gating_delta_rule.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_sigmoid_gating_delta_rule.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_sigmoid_gating_delta_rule.py index fd469ef3..abfbcc20 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_sigmoid_gating_delta_rule.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_sigmoid_gating_delta_rule.py @@ -16,7 +16,6 @@ def test_triton_fusion_ops(): b = torch.tensor( [[0.4277, 0.8906, 1.6875, 2.3750, 4.1562, 0.3809, 1.0625, 3.6719]]).bfloat16().npu() - ssm_state = torch.randn(1, 8, 128, 128, dtype=torch.bfloat16).npu() non_spec_state_indices_tensor = torch.tensor([2]).int().npu() non_spec_query_start_loc = torch.tensor([0, 1]).int().npu() a_log = torch.tensor([ @@ -25,6 +24,7 @@ def test_triton_fusion_ops(): dt_bias = torch.tensor( [-4.7812, -5.0938, -5.5000, 9.4375, 7.6250, -4.3750, -3.0938, 0.9688]).bfloat16().npu() + ssm_state1 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu() core_attn_out_non_spec_fused = fused_sigmoid_gating_delta_rule_update( A_log=a_log.contiguous(), @@ -34,7 +34,7 @@ def test_triton_fusion_ops(): v=v.contiguous(), a=a.contiguous(), b=b.contiguous(), - initial_state_source=ssm_state, + initial_state_source=ssm_state1, initial_state_indices=non_spec_state_indices_tensor, cu_seqlens=non_spec_query_start_loc, use_qk_l2norm_in_kernel=True, @@ -42,6 +42,7 @@ def test_triton_fusion_ops(): softplus_threshold=20.0, ) + ssm_state2 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu() g, beta = fused_gdn_gating(a_log, a, b, dt_bias) g_non_spec = g beta_non_spec = beta @@ -52,7 +53,7 @@ def test_triton_fusion_ops(): v=v, g=g_non_spec, 
beta=beta_non_spec, - initial_state=ssm_state, + initial_state=ssm_state2, inplace_final_state=True, cu_seqlens=non_spec_query_start_loc, ssm_state_indices=non_spec_state_indices_tensor,