From 2ed96c7a8a2039191d30a523b0983209c3e91dd1 Mon Sep 17 00:00:00 2001 From: lukec <118525388+sleepcoo@users.noreply.github.com> Date: Wed, 23 Apr 2025 01:36:23 +0800 Subject: [PATCH] fix flashmla bug (#5272) --- .../srt/layers/attention/flashmla_backend.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index 85fe4a2fb..1513c1c71 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend): self.num_q_heads = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() ) - self.num_kv_heads = model_runner.model_config.get_num_kv_heads( - get_attention_tp_size() - ) self.req_to_token = model_runner.req_to_token_pool.req_to_token self.num_local_heads = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() @@ -111,8 +108,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ) mla_metadata, num_splits = get_mla_metadata( forward_batch.seq_lens.to(torch.int32), - Q_LEN * self.num_q_heads // self.num_kv_heads, - self.num_kv_heads, + Q_LEN * self.num_q_heads, + 1, ) self.forward_metadata = FlashMLADecodeMetadata( mla_metadata, @@ -141,8 +138,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend): self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata( torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device), - Q_LEN * self.num_q_heads // self.num_kv_heads, - self.num_kv_heads, + Q_LEN * self.num_q_heads, + 1, ) self.cuda_graph_kv_indices = cuda_graph_kv_indices @@ -171,8 +168,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ) mla_metadata, num_splits = get_mla_metadata( seq_lens.to(torch.int32), - Q_LEN * self.num_q_heads // self.num_kv_heads, - self.num_kv_heads, + Q_LEN * self.num_q_heads, + 1, ) self.cuda_graph_mla_metadata.copy_(mla_metadata) self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) @@ -221,8 +218,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ) mla_metadata, num_splits = get_mla_metadata( seq_lens.to(torch.int32), - Q_LEN * self.num_q_heads // self.num_kv_heads, - self.num_kv_heads, + Q_LEN * self.num_q_heads, + 1, ) self.cuda_graph_mla_metadata.copy_(mla_metadata) self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)