support pangumoe w8a8c8 and docs (#1477)
### What this PR does / why we need it?
Support Pangu MoE w8a8c8 quantization.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added tests.

Signed-off-by: zhuyilin <809721801@qq.com>
```diff
@@ -505,7 +505,7 @@ class PanguProMoESparseMoeBlock(nn.Module):
         # native FusedMoE. here we need to design a better FusedMoE
         # (maybe using AscendFusedMoE) to enable these different
         # communication schema.
-        final_hidden_states = self.experts.quant_method(
+        final_hidden_states = self.experts.quant_method.apply(
             layer=self.experts,
             x=hidden_states,
             router_logits=router_logits,
```
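The one-line change above is needed because `quant_method` on the experts layer is a quantization-method object, not a callable module: its entry point is an `apply()` method. A minimal sketch of the difference, assuming a toy stand-in class (`ToyQuantMethod` is illustrative, not the vLLM type):

```python
# Hypothetical stand-in for the experts' quant method object; the real type
# comes from vLLM's quantization config and exposes apply(layer, x, ...).
class ToyQuantMethod:

    def apply(self, layer, x, router_logits):
        # A real implementation would run the fused, quantized MoE kernel.
        return x


method = ToyQuantMethod()
print(method.apply(layer=None, x=[1.0], router_logits=None))  # new call: works

try:
    method(layer=None, x=[1.0], router_logits=None)  # old call from the diff
except TypeError as err:
    print(f"direct call fails: {err}")  # the object is not callable
```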
```diff
@@ -937,6 +937,8 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        tp_size = get_tp_group().world_size
+        tp_rank = get_tp_group().rank_in_group
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
```
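`tp_size` and `tp_rank` are fetched once up front so the kv-scale handling added in the next hunk can shard each scale tensor across tensor-parallel ranks. A standalone sketch of that slicing, with illustrative values standing in for `get_tp_group().world_size` and `.rank_in_group`:

```python
import torch

# Each tensor-parallel rank keeps only its own slice of a per-channel scale.
tp_size, tp_rank = 4, 1
scale = torch.arange(8, dtype=torch.float32)  # e.g. one kv antiquant scale
shard = torch.tensor_split(scale, tp_size, dim=0)[tp_rank]
print(shard)  # tensor([2., 3.]) -- rank 1's contiguous slice
```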
```diff
@@ -972,6 +974,51 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
             if "module" in name:
                 continue
 
+            if name.endswith('kv_cache_offset'):
+                continue
+
+            if name.endswith("k_proj.kv_cache_scale"):
+                remapped_kv_scale_name = name.replace(
+                    "k_proj.kv_cache_scale", "attn.key_antiquant_scale")
+                if remapped_kv_scale_name not in params_dict:
+                    logger.warning_once(
+                        "Found kv scale in the checkpoint "
+                        f"(e.g. {name}), but not found the expected "
+                        f"name in the model "
+                        f"(e.g. {remapped_kv_scale_name}). "
+                        "kv-scale is not loaded.")
+                    continue
+                else:
+                    name = remapped_kv_scale_name
+                    param = params_dict[name]
+                    loaded_weight = torch.tensor_split(loaded_weight,
+                                                       tp_size,
+                                                       dim=0)[tp_rank]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+
```
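The hunk continues below with a mirrored branch for the value scale (`v_proj.kv_cache_scale` → `attn.value_antiquant_scale`). The remap-or-warn logic itself is compact; here is a self-contained sketch, where the helper name `remap_kv_scale` and the toy `params_dict` are illustrative rather than part of the model code:

```python
from typing import Optional


def remap_kv_scale(name: str, params_dict: dict) -> Optional[str]:
    """Map a checkpoint kv-scale name to its model-side parameter name."""
    remapped = name.replace("k_proj.kv_cache_scale",
                            "attn.key_antiquant_scale")
    if remapped not in params_dict:
        # The model code calls logger.warning_once here instead of printing.
        print(f"Found kv scale {name} but no param {remapped}; not loaded.")
        return None
    return remapped


params = {"model.layers.0.self_attn.attn.key_antiquant_scale": None}
print(remap_kv_scale("model.layers.0.self_attn.k_proj.kv_cache_scale", params))
# -> model.layers.0.self_attn.attn.key_antiquant_scale
```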
if name.endswith("v_proj.kv_cache_scale"):
|
||||
remapped_kv_scale_name = name.replace(
|
||||
"v_proj.kv_cache_scale", "attn.value_antiquant_scale")
|
||||
if remapped_kv_scale_name not in params_dict:
|
||||
logger.warning_once(
|
||||
"Found kv scale in the checkpoint "
|
||||
f"(e.g. {name}), but not found the expected "
|
||||
f"name in the model "
|
||||
f"(e.g. {remapped_kv_scale_name}). "
|
||||
"kv-scale is not loaded.")
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
param = params_dict[name]
|
||||
loaded_weight = torch.tensor_split(loaded_weight,
|
||||
tp_size,
|
||||
dim=0)[tp_rank]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
|
||||
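After the new kv-scale branches, control falls through to the pre-existing stacked-parameter loop. For context, a hedged sketch of how such a mapping folds per-projection checkpoint names into a fused parameter; only the `q_proj` tuple appears in the hunk above, so the `k`/`v` tuples and the sample name are assumptions:

```python
# (param_name, shard_name, shard_id): a checkpoint key like "...q_proj.weight"
# is loaded into the fused "...qkv_proj.weight" under shard id "q".
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),  # assumed by analogy with q_proj
    ("qkv_proj", "v_proj", "v"),  # assumed by analogy with q_proj
]

name = "model.layers.0.self_attn.q_proj.weight"  # hypothetical checkpoint key
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name not in name:
        continue  # skip non-stacked layers
    print(name.replace(weight_name, param_name), shard_id)
    # -> model.layers.0.self_attn.qkv_proj.weight q
    break
```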