[ROCm] Fix rocm_mla_decode_rope.py returning random results (#3898)
This commit is contained in:
@@ -230,7 +230,7 @@ def _fwd_grouped_kernel_stage1_rope(
                     other=0.0,
                 )  # positional embedding part of keys
 
-                if USE_ROPE and start_n >= cur_batch_seq_len - BLOCK_N:
+                if (USE_ROPE and LAST_SPLIT) and start_n >= cur_batch_seq_len - BLOCK_N:
                     k_pe = tl.where(
                         offs_n[None, :] != (split_kv_end - 1),
                         k_pe,
Reference in New Issue
Block a user