diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 1579209a3..b76a1cb15 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -547,7 +547,7 @@ class NixlKVManager(CommonKVManager): notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size - if decode_tp_size == self.attn_tp_size: + if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): kv_xfer_handle = self.send_kvcache( req.agent_name, kv_indices,