[fix bug] fix tensor mismatch bug in sigmoid operator test case (#6619)
### What this PR does / why we need it?
This PR fixes a bug in the `test_triton_fusion_ops` test case. The test
compares a fused kernel (`fused_sigmoid_gating_delta_rule_update`) with
a split implementation. Both paths use a recurrent state tensor.
The bug was that the fused kernel call modified the state tensor in place,
and the modified tensor was then reused by the split implementation. The
two paths therefore started from different states, so the comparison was
invalid and the test failed.
This fix ensures that each path starts with an identical, clean initial
state by creating separate tensors. It also changes the state
initialization from `torch.randn` to `torch.ones` to make the test
deterministic.
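The failure mode is a classic aliasing bug: a routine that updates its state argument in place leaves that state "dirty" for any later caller that reuses the same object. The sketch below is illustrative only — plain Python stand-ins for the real fused and split kernels (`fused_update` and `split_update` are hypothetical names) — but it shows why each path must receive its own identical copy of the initial state:

```python
def fused_update(state, delta):
    # Stand-in for the fused kernel: mutates `state` in place,
    # just as the real kernel mutates the recurrent state tensor.
    for i in range(len(state)):
        state[i] += delta
    return list(state)

def split_update(state, delta):
    # Stand-in for the split implementation: reads `state` as-is.
    return [s + delta for s in state]

# Buggy pattern: both paths share one state object.
shared = [1.0, 1.0]
fused_out = fused_update(shared, 0.5)   # mutates `shared` to [1.5, 1.5]
split_out = split_update(shared, 0.5)   # sees the already-mutated state
assert fused_out != split_out           # comparison is now meaningless

# Fixed pattern: identical but independent initial states.
state1 = [1.0, 1.0]
state2 = [1.0, 1.0]
assert fused_update(state1, 0.5) == split_update(state2, 0.5)
```

The fix in this PR mirrors the second pattern: `ssm_state1` and `ssm_state2` are allocated separately, and `torch.ones` (instead of `torch.randn`) guarantees they are identical across runs.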
### Does this PR introduce _any_ user-facing change?
No, this change only affects a test case and has no user-facing impact.
### How was this patch tested?
The fix is applied directly to the test case. CI passing for
`test_fused_sigmoid_gating_delta_rule.py` confirms that the fix works as
expected.
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
```diff
@@ -16,7 +16,6 @@ def test_triton_fusion_ops():
     b = torch.tensor(
         [[0.4277, 0.8906, 1.6875, 2.3750, 4.1562, 0.3809, 1.0625,
           3.6719]]).bfloat16().npu()
-    ssm_state = torch.randn(1, 8, 128, 128, dtype=torch.bfloat16).npu()
     non_spec_state_indices_tensor = torch.tensor([2]).int().npu()
     non_spec_query_start_loc = torch.tensor([0, 1]).int().npu()
     a_log = torch.tensor([
@@ -25,6 +24,7 @@ def test_triton_fusion_ops():
     dt_bias = torch.tensor(
         [-4.7812, -5.0938, -5.5000, 9.4375, 7.6250, -4.3750, -3.0938,
          0.9688]).bfloat16().npu()
+    ssm_state1 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu()

     core_attn_out_non_spec_fused = fused_sigmoid_gating_delta_rule_update(
         A_log=a_log.contiguous(),
@@ -34,7 +34,7 @@ def test_triton_fusion_ops():
         v=v.contiguous(),
         a=a.contiguous(),
         b=b.contiguous(),
-        initial_state_source=ssm_state,
+        initial_state_source=ssm_state1,
         initial_state_indices=non_spec_state_indices_tensor,
         cu_seqlens=non_spec_query_start_loc,
         use_qk_l2norm_in_kernel=True,
@@ -42,6 +42,7 @@ def test_triton_fusion_ops():
         softplus_threshold=20.0,
     )

+    ssm_state2 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu()
     g, beta = fused_gdn_gating(a_log, a, b, dt_bias)
     g_non_spec = g
     beta_non_spec = beta
@@ -52,7 +53,7 @@ def test_triton_fusion_ops():
         v=v,
         g=g_non_spec,
         beta=beta_non_spec,
-        initial_state=ssm_state,
+        initial_state=ssm_state2,
         inplace_final_state=True,
         cu_seqlens=non_spec_query_start_loc,
         ssm_state_indices=non_spec_state_indices_tensor,
```