fix flashmla bug (#5272)
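FlashMLA's `get_mla_metadata` takes the number of query heads per KV head and the KV head count. Under MLA the latent KV cache acts as a single head, so dividing by the TP-split `num_kv_heads` and passing it as the KV head count produced incorrect tile-scheduler metadata; every call site now passes `Q_LEN * self.num_q_heads` and `1`, and the now-unused `self.num_kv_heads` field is dropped. A minimal standalone sketch of the corrected call, assuming the `get_mla_metadata` signature published by the FlashMLA project (the tensor values and head count below are illustrative, not taken from this commit):

    import torch
    from flash_mla import get_mla_metadata

    Q_LEN = 1           # one new query token per request during decode
    num_q_heads = 128   # attention heads on this TP rank (example value)
    seq_lens = torch.tensor([512, 1024], dtype=torch.int32, device="cuda")

    # MLA keeps a single latent KV head, so the KV head count is 1 and the
    # per-KV-head query count is Q_LEN * num_q_heads (no division).
    mla_metadata, num_splits = get_mla_metadata(
        seq_lens,
        Q_LEN * num_q_heads,
        1,
    )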
@@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
             mla_metadata, num_splits = get_mla_metadata(
                 forward_batch.seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
             )
             self.forward_metadata = FlashMLADecodeMetadata(
                 mla_metadata,
@@ -141,8 +138,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
 
         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
 
@@ -171,8 +168,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
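For context, a hedged end-to-end sketch of how this metadata is consumed, following the decode usage example in the FlashMLA README (`flash_mla_with_kvcache` and its argument order come from that README, not from this diff; all shapes and sizes below are assumptions for illustration):

    import torch
    from flash_mla import get_mla_metadata, flash_mla_with_kvcache

    batch, s_q, h_q, h_kv = 2, 1, 128, 1        # h_kv = 1: single latent KV head
    d, dv, block_size, num_blocks = 576, 512, 64, 64

    q = torch.randn(batch, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
    kv_cache = torch.randn(num_blocks, block_size, h_kv, d, dtype=torch.bfloat16, device="cuda")
    block_table = torch.arange(num_blocks, dtype=torch.int32, device="cuda").view(batch, -1)
    cache_seqlens = torch.tensor([512, 1024], dtype=torch.int32, device="cuda")

    # Same argument pattern as the fixed calls above: the KV head count is 1,
    # so s_q * h_q // h_kv is just the full per-rank query head count.
    mla_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

    # The scheduler metadata and split counts drive the split-KV decode kernel.
    out, lse = flash_mla_with_kvcache(
        q, kv_cache, block_table, cache_seqlens, dv,
        mla_metadata, num_splits, causal=True,
    )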