[Model] GLM5 adaptation (#6642)

### What this PR does / why we need it?
GLM5 adaptation:
1. Use `torch_npu.npu_lightning_indexer` for GLM5.
2. Bypass the Eagle proposer's padded-drafter-batch path when fullgraph mode is enabled, as a temporary workaround for known bugs.
3. Add a quantization config (packed-modules mapping) for GLM5.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Tested by CI.
- vLLM main:
978a37c823

---------

Signed-off-by: yydyzr <liuyuncong1@huawei.com>
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: shenchuxiaofugui <1311027364@qq.com>


@@ -431,6 +431,11 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.weights_proj = self.indexer.weights_proj
         self.k_norm = self.indexer.k_norm
         self.cp_size = 1
+        self.is_rope_neox_style = True
+        self.use_torch_npu_lightning_indexer = False
+        if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]:
+            self.is_rope_neox_style = False
+            self.use_torch_npu_lightning_indexer = True
         self.enable_dsa_cp = enable_dsa_cp()
         self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
@@ -973,7 +978,9 @@ class AscendSFAImpl(MLAAttentionImpl):
             cos = cos.view(-1, self.qk_rope_head_dim)
             sin = sin.view(-1, self.qk_rope_head_dim)
-            q, k = rope_forward_triton(q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=True)
+            q, k = rope_forward_triton(
+                q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
+            )
         else:
             k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1)
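
Background on the new `is_rope_neox_style` flag: GLM5's `glm_moe_dsa` layers use the interleaved (GPT-J style) rotary layout rather than the NeoX half-split layout that was previously hard-coded. A rough, self-contained sketch of the difference the flag selects (this is not the actual `rope_forward_triton` kernel, whose signature and cos/sin layout may differ):

```python
import torch

def apply_rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         is_neox_style: bool) -> torch.Tensor:
    # x: [..., rope_dim]; cos/sin: [..., rope_dim // 2] (illustrative shapes only).
    if is_neox_style:
        # NeoX/half-split layout: channel i rotates with channel i + rope_dim // 2.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # Interleaved (GPT-J) layout: even/odd channel pairs rotate together.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2)
```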
@@ -1036,18 +1043,35 @@ class AscendSFAImpl(MLAAttentionImpl):
             key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
             block_table = attn_metadata.sfa_cp_metadata.block_table_cp
-        topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
-            query=q,
-            key=key,
-            weights=weights,
-            actual_seq_lengths_query=actual_seq_lengths_query,
-            actual_seq_lengths_key=actual_seq_lengths_key,
-            block_table=block_table,
-            layout_query="TND",
-            layout_key="PA_BSND",
-            sparse_count=2048,
-            sparse_mode=3,
-        )
+        # DSV3.2 currently has graph compilation issues when using torch_npu.npu_lightning_indexer,
+        # so the two branches are kept temporarily.
+        # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed once this is fixed.
+        if self.use_torch_npu_lightning_indexer:
+            topk_indices, _ = torch_npu.npu_lightning_indexer(
+                query=q,
+                key=key,
+                weights=weights,
+                actual_seq_lengths_query=actual_seq_lengths_query,
+                actual_seq_lengths_key=actual_seq_lengths_key,
+                block_table=block_table,
+                layout_query="TND",
+                layout_key="PA_BSND",
+                sparse_count=2048,
+                sparse_mode=3,
+            )
+        else:
+            topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
+                query=q,
+                key=key,
+                weights=weights,
+                actual_seq_lengths_query=actual_seq_lengths_query,
+                actual_seq_lengths_key=actual_seq_lengths_key,
+                block_table=block_table,
+                layout_query="TND",
+                layout_key="PA_BSND",
+                sparse_count=2048,
+                sparse_mode=3,
+            )
         return topk_indices

     def _init_o_proj_tp_full_params(self):
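
Note for reviewers unfamiliar with the indexer: both branches compute, for every query token, the top `sparse_count` key positions by a per-head weighted similarity score (the DSA-style "lightning indexer"). A very rough dense reference of that selection, ignoring the paged-KV layout (`PA_BSND`), per-sequence lengths, and `sparse_mode` masking that the fused NPU kernels handle; the function name and exact score formulation are illustrative assumptions, not the kernel implementation:

```python
import torch

def lightning_indexer_dense_reference(q: torch.Tensor, key: torch.Tensor,
                                      weights: torch.Tensor, sparse_count: int) -> torch.Tensor:
    # Hypothetical dense reference, not the fused kernel:
    #   q:       [T, H, D]  indexer queries (T tokens, H indexer heads)
    #   key:     [S, D]     indexer keys of a single sequence
    #   weights: [T, H]     per-token, per-head mixing weights
    scores = torch.einsum("thd,sd->ths", q, key).relu()   # per-head query/key similarity
    scores = torch.einsum("ths,th->ts", scores, weights)  # weighted sum over indexer heads
    k = min(sparse_count, key.shape[0])
    return scores.topk(k, dim=-1).indices                 # [T, k] selected key positions per query
```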


@@ -96,6 +96,11 @@ packed_modules_model_mapping: dict[str, dict[str, list[str]]] = {
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"glm_moe_dsa": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT.
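
This new entry mirrors the existing DeepSeek mappings: it tells the quantization loader which original shard names make up each fused module in GLM5. A hypothetical sketch of how such a packed-modules mapping is typically consumed (the helper name, suffix-matching logic, and quant-description format are illustrative, not the actual vllm-ascend loader):

```python
def remap_packed_modules(quant_desc: dict[str, str],
                         packed_mapping: dict[str, list[str]]) -> dict[str, str]:
    # Hypothetical helper: copy each per-shard entry (e.g. "...mlp.gate_proj": "W8A8")
    # onto the fused module name (e.g. "...mlp.gate_up_proj") that vLLM actually builds.
    # Simplified suffix matching; a real loader would resolve overlapping suffixes explicitly.
    remapped = dict(quant_desc)
    for fused_name, shard_names in packed_mapping.items():
        for layer_name, dtype in quant_desc.items():
            for shard in shard_names:
                if layer_name.endswith(shard):
                    remapped[layer_name[: -len(shard)] + fused_name] = dtype
    return remapped
```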


@@ -36,7 +36,14 @@ class MtpProposer(EagleProposer):
         dummy_compute_logits=lambda hidden_states: None,
         is_profile=False,
     ) -> None:
-        if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch:
+        # Both GLM and DS currently hit issues when fullgraph mode is enabled and the
+        # EagleProposer path runs, so this fullgraph check temporarily bypasses the problem.
+        # TODO: remove this conditional check once the bug is fixed.
+        if (
+            self.pcp_size * self.dcp_size == 1
+            and not self.speculative_config.disable_padded_drafter_batch
+            and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
+        ):
             super().dummy_run(
                 num_tokens,
                 with_prefill,
@@ -166,7 +173,14 @@ class MtpProposer(EagleProposer):
         scheduler_output: SchedulerOutput = None,
         num_scheduled_tokens: int = 0,
     ) -> torch.Tensor:
-        if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch:
+        # Both GLM and DS currently hit issues when fullgraph mode is enabled and the
+        # EagleProposer path runs, so this fullgraph check temporarily bypasses the problem.
+        # TODO: remove this conditional check once the bug is fixed.
+        if (
+            self.pcp_size * self.dcp_size == 1
+            and not self.speculative_config.disable_padded_drafter_batch
+            and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
+        ):
             draft_token_ids = super()._propose(
                 target_token_ids,
                 target_positions,
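
Usage note: until the EagleProposer/fullgraph bug is fixed, the new guard silently skips the padded-drafter fast path whenever full graph capture is active. If that path is needed together with MTP speculative decoding, one possible workaround is to keep graph capture piecewise instead of full. A minimal sketch, assuming vLLM's standard `compilation_config`/`cudagraph_mode` knob (the model path is a placeholder and accepted values depend on the vLLM version):

```python
from vllm import LLM

# Illustrative only: keep graph capture piecewise so the padded-drafter path stays enabled.
llm = LLM(
    model="/path/to/glm5",
    compilation_config={"cudagraph_mode": "PIECEWISE"},  # avoid full-graph capture
)
```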