From 8b681d7724a6e0368bab556f6ac08fa5c312ffb9 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Feb 2025 03:05:30 +0200 Subject: [PATCH] [Rocm] Fix to the rocm_mla_decode_rope.py returning random result (#3898) --- .../srt/layers/attention/triton_ops/rocm_mla_decode_rope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py b/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py index 218244501..4c3e63968 100644 --- a/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +++ b/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py @@ -230,7 +230,7 @@ def _fwd_grouped_kernel_stage1_rope( other=0.0, ) # positional embedding part of keys - if USE_ROPE and start_n >= cur_batch_seq_len - BLOCK_N: + if (USE_ROPE and LAST_SPLIT) and start_n >= cur_batch_seq_len - BLOCK_N: k_pe = tl.where( offs_n[None, :] != (split_kv_end - 1), k_pe,