[bugfix][accuracy] Fix ds indexer accuracy problem caused by k rope (#7341)
### What this PR does / why we need it? The rotary algorithm in deepseek indexer should be neox-style instead of gptj style. PR #4641 fix this accuracy bug in original pytorch version. But PR #5701 accidentally removed the fixed code line and reverted the implementation back to the problematic version. This PR fixes it. Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -884,7 +884,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||||
|
|
||||||
k_li_pe = k_li_pe.unsqueeze(2)
|
k_li_pe = k_li_pe.unsqueeze(2)
|
||||||
k_li_pe = torch_npu.npu_interleave_rope(k_li_pe, cos, sin)
|
k_li_pe = torch_npu.npu_rotary_mul(k_li_pe, cos, sin)
|
||||||
k_li_pe = k_li_pe.squeeze(2)
|
k_li_pe = k_li_pe.squeeze(2)
|
||||||
|
|
||||||
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
|
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
|
||||||
|
|||||||
Reference in New Issue
Block a user