[Bugifx] fix quant_apply_mlp w1_scale type error & fix getting num_local_expert (#4632)
### What this PR does / why we need it?
Fix bugs introduced by
bc67696a02
1. fix getting num_local_experet error in vllm_adaptor
2. fix w1_scale type error in
moe_mlp.quant_apply_mlp.npu_dequant_swiglu_quant in w4a8 quantized
scenario
- vLLM version: v0.12.0
---------
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: 欧派果奶我还要 <47294568+845473182@users.noreply.github.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -107,8 +107,8 @@ class VllmEplbAdaptor(EplbAdaptor):
|
||||
self.buffer_tensor_list[buffer_id].append(buffer_tensor)
|
||||
|
||||
def init_expert_param_per_layer(self):
|
||||
num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) + \
|
||||
".mlp.experts." + self.expert_weight_names[0]].data.shape[0]
|
||||
key = f"model.layers.{self.num_dense_layers}.mlp.experts.{self.expert_weight_names[0]}"
|
||||
num_local_expert = len(self.param_dict[key])
|
||||
for moe_layer_id in range(self.num_moe_layers):
|
||||
layer_idx = self.num_dense_layers + moe_layer_id
|
||||
self.expert_param_per_layer[layer_idx] = list()
|
||||
|
||||
@@ -129,7 +129,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
weight_scale=w1_scale[0],
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
|
||||
@@ -289,7 +289,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
]
|
||||
layer.w13_weight_scale_fp32_list = [
|
||||
weight.clone()
|
||||
for weight in layer.w13_weight_scale.data.unbind(dim=0)
|
||||
for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
|
||||
]
|
||||
layer.w2_weight_scale_list = [
|
||||
weight.clone()
|
||||
|
||||
Reference in New Issue
Block a user