support pangumoe w8a8c8 and docs (#1477)

### What this PR does / why we need it?
Support W8A8C8 quantization (8-bit weights, activations, and KV cache) for the Pangu Pro MoE model, and add the corresponding documentation.
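
For reference, a minimal usage sketch (not part of this PR): serving a Pangu Pro MoE checkpoint that has already been quantized offline to W8A8C8. The model path is a placeholder, and `quantization="ascend"` refers to vllm-ascend's quantization backend option.

```python
from vllm import LLM, SamplingParams

# Placeholder path: assumes a checkpoint already quantized offline to
# W8A8C8 (int8 weights, activations, and KV cache).
llm = LLM(
    model="/path/to/pangu-pro-moe-w8a8c8",
    quantization="ascend",   # vllm-ascend quantization backend
    tensor_parallel_size=4,  # example TP degree
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```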

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with the newly added test.

Signed-off-by: zhuyilin <809721801@qq.com>
Author: Zhu Yi Lin
Date: 2025-06-28 18:51:07 +08:00
Committed by: GitHub
Parent: c59d69d9e6
Commit: b308a7a258
8 changed files with 689 additions and 50 deletions


```diff
@@ -505,7 +505,7 @@ class PanguProMoESparseMoeBlock(nn.Module):
         # native FusedMoE. here we need to design a better FusedMoE
         # (maybe using AscendFusedMoE) to enable these different
         # communication schema.
-        final_hidden_states = self.experts.quant_method(
+        final_hidden_states = self.experts.quant_method.apply(
             layer=self.experts,
             x=hidden_states,
             router_logits=router_logits,
```
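
The hunk above fixes how the MoE block dispatches into the quantization method: `self.experts.quant_method` is a quantize-method object, not a callable, so the quantized kernel has to be invoked through its `apply()` method. A rough sketch of the interface shape assumed here (abbreviated names, not the actual vLLM class):

```python
# Hedged sketch of the call-site contract; the real method in vLLM's
# quantization layer takes additional parameters (top_k, renormalize, ...).
class FusedMoEQuantMethodSketch:

    def apply(self, layer, x, router_logits, **kwargs):
        """Route tokens using router_logits, run the quantized expert
        GEMMs on x, and return the combined hidden states."""
        ...
```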
```diff
@@ -937,6 +937,8 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        tp_size = get_tp_group().world_size
+        tp_rank = get_tp_group().rank_in_group
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
```
```diff
@@ -972,6 +974,51 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
             if "module" in name:
                 continue
+            if name.endswith('kv_cache_offset'):
+                continue
+            if name.endswith("k_proj.kv_cache_scale"):
+                remapped_kv_scale_name = name.replace(
+                    "k_proj.kv_cache_scale", "attn.key_antiquant_scale")
+                if remapped_kv_scale_name not in params_dict:
+                    logger.warning_once(
+                        "Found kv scale in the checkpoint "
+                        f"(e.g. {name}), but not found the expected "
+                        f"name in the model "
+                        f"(e.g. {remapped_kv_scale_name}). "
+                        "kv-scale is not loaded.")
+                    continue
+                else:
+                    name = remapped_kv_scale_name
+                    param = params_dict[name]
+                    loaded_weight = torch.tensor_split(loaded_weight,
+                                                       tp_size,
+                                                       dim=0)[tp_rank]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            if name.endswith("v_proj.kv_cache_scale"):
+                remapped_kv_scale_name = name.replace(
+                    "v_proj.kv_cache_scale", "attn.value_antiquant_scale")
+                if remapped_kv_scale_name not in params_dict:
+                    logger.warning_once(
+                        "Found kv scale in the checkpoint "
+                        f"(e.g. {name}), but not found the expected "
+                        f"name in the model "
+                        f"(e.g. {remapped_kv_scale_name}). "
+                        "kv-scale is not loaded.")
+                    continue
+                else:
+                    name = remapped_kv_scale_name
+                    param = params_dict[name]
+                    loaded_weight = torch.tensor_split(loaded_weight,
+                                                       tp_size,
+                                                       dim=0)[tp_rank]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
```
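
The added block remaps checkpoint KV-scale names onto the model's attention parameters (`k_proj.kv_cache_scale` → `attn.key_antiquant_scale`, `v_proj.kv_cache_scale` → `attn.value_antiquant_scale`), warns once and skips when the target parameter is absent, and shards the scale across TP ranks before loading. A toy illustration of just the remap-or-skip step, with made-up parameter names:

```python
# Toy model state: only the key antiquant scale exists in params_dict.
params_dict = {"layers.0.attn.key_antiquant_scale": "param-object"}

for name in ("layers.0.k_proj.kv_cache_scale",
             "layers.0.v_proj.kv_cache_scale"):
    remapped = (name
                .replace("k_proj.kv_cache_scale", "attn.key_antiquant_scale")
                .replace("v_proj.kv_cache_scale", "attn.value_antiquant_scale"))
    if remapped in params_dict:
        print(f"load {name} -> {remapped}")
    else:
        print(f"skip {name}: no matching model parameter")
```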