[BugFix][DS 3.2] Fix ds indexer accuracy problem caused by rope. (#4641)

### What this PR does / why we need it?
The rotary algorithm in deepseek indexer should be neox-style instead of
gptj style. PR #4413 fix this accuracy bug with new triton kernel. This
PR fixes original pytorch version.

### Does this PR introduce _any_ user-facing change?
None

### How was this patch tested?
CI passed with existing test.


- vLLM version: 86e178f7c4d8c3b0eaf3c8e3f810a83f63b90e24
- vLLM main:
86e178f7c4

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2026-01-09 14:11:44 +08:00
committed by GitHub
parent 98c788a65a
commit ee2ed573f1

View File

@@ -882,7 +882,7 @@ class AscendSFAImpl(MLAAttentionImpl):
dim=-1) # [b,s,64,64+64]
q_pe = q_pe.unsqueeze(2)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q)
q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
@@ -892,7 +892,7 @@ class AscendSFAImpl(MLAAttentionImpl):
dim=-1) # [b,s,64+64]
k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
k_pe = torch_npu.npu_rotary_mul(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]