[EPLB][Ops] Integerate grouped_matmul_swiglu_quant_weight_nz_tensor_list operator into dynamic EPLB (#4216)
### What this PR does / why we need it? Integerate grouped_matmul_swiglu_quant_weight_nz_tensor_list into dynamic EPLB to support list-type parameters This PR also modify the logic of loading model in dynamic-eplb scenario. The operator is based on this pr: https://github.com/vllm-project/vllm-ascend/pull/3804 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ``` vllm serve /home/weight/DeepSeek-V3.1_w8a8mix_mtp \ --max_num_seqs 8 \ --max-model-len 8192 \ --max-num-batched-tokens 16384 \ --tensor-parallel-size 8 \ --data-parallel-size 2 \ --enable-expert-parallel \ --served-model-name ds_r1 \ --enable-auto-tool-choice \ --tool-call-parser hermes \ --no-enable-prefix-caching \ --port 8999 \ --quantization "ascend" \ --gpu-memory-utilization 0.85 \ --trust-remote-code \ --compilation_config '{"cudagraph_capture_sizes":[1,2,4,8,16,32]}' \ --additional-config='{"dynamic_eplb":true, "num_iterations_eplb_update":100, "num_wait_worker_iterations":100}' ``` input&output: 2k 2k This PR: <img width="1318" height="695" alt="fusion" src="https://github.com/user-attachments/assets/f8657813-0c02-42f4-8396-d99e730f48cd" /> Baseline: <img width="1323" height="690" alt="baseline" src="https://github.com/user-attachments/assets/e1323a78-af26-4523-820c-e20e5642a38e" /> - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: 白永斌 <baiyongbin3@h-partners.com> Signed-off-by: 欧派果奶我还要 <845473182@qq.com> Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
This commit is contained in:
@@ -226,8 +226,8 @@ class TestMoECommMethod(TestBase):
|
|||||||
w2 = w2.contiguous()
|
w2 = w2.contiguous()
|
||||||
|
|
||||||
result = comm_impl.fused_experts(hidden_states=hidden_states,
|
result = comm_impl.fused_experts(hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=[w1],
|
||||||
w2=w2,
|
w2=[w2],
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
activation="silu")
|
activation="silu")
|
||||||
|
|||||||
@@ -44,11 +44,22 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
self.init_redundancy_expert = get_ascend_config(
|
self.init_redundancy_expert = get_ascend_config(
|
||||||
).init_redundancy_expert
|
).init_redundancy_expert
|
||||||
|
|
||||||
|
for i in range(self.num_dense_layers,
|
||||||
|
self.model.config.num_hidden_layers):
|
||||||
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \
|
||||||
|
self.model.model.layers[i].mlp.experts.w13_weight_list
|
||||||
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \
|
||||||
|
self.model.model.layers[i].mlp.experts.w2_weight_list
|
||||||
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \
|
||||||
|
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
|
||||||
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
|
||||||
|
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
|
||||||
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
|
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
|
||||||
if self.model.quant_config is not None:
|
if self.model.quant_config is not None:
|
||||||
self.expert_weight_names = [
|
self.expert_weight_names = [
|
||||||
"w13_weight", "w2_weight", "w13_weight_scale",
|
"w13_weight_list", "w2_weight_list",
|
||||||
"w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
|
"w13_weight_scale_fp32_list", "w13_weight_offset",
|
||||||
|
"w2_weight_scale_list", "w2_weight_offset"
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||||
@@ -84,9 +95,14 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
for name in self.expert_weight_names:
|
for name in self.expert_weight_names:
|
||||||
complete_name = "model.layers." + str(
|
complete_name = "model.layers." + str(
|
||||||
self.num_dense_layers) + ".mlp.experts." + name
|
self.num_dense_layers) + ".mlp.experts." + name
|
||||||
expert_tensor = self.param_dict[complete_name].data[0]
|
if name in [
|
||||||
if name in ["w13_weight", "w2_weight"]:
|
"w13_weight_list", "w2_weight_list",
|
||||||
|
"w13_weight_scale_fp32_list", "w2_weight_scale_list"
|
||||||
|
]:
|
||||||
|
expert_tensor = self.param_dict[complete_name][0]
|
||||||
expert_tensor = expert_tensor.clone()
|
expert_tensor = expert_tensor.clone()
|
||||||
|
else:
|
||||||
|
expert_tensor = self.param_dict[complete_name][0].data[0]
|
||||||
buffer_tensor = torch.empty_like(expert_tensor)
|
buffer_tensor = torch.empty_like(expert_tensor)
|
||||||
self.buffer_tensor_list[buffer_id].append(buffer_tensor)
|
self.buffer_tensor_list[buffer_id].append(buffer_tensor)
|
||||||
|
|
||||||
@@ -97,12 +113,23 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
layer_idx = self.num_dense_layers + moe_layer_id
|
layer_idx = self.num_dense_layers + moe_layer_id
|
||||||
self.expert_param_per_layer[layer_idx] = list()
|
self.expert_param_per_layer[layer_idx] = list()
|
||||||
for local_expert_id in range(num_local_expert):
|
for local_expert_id in range(num_local_expert):
|
||||||
self.expert_param_per_layer[layer_idx].append([
|
per_expert_param = list()
|
||||||
self.param_dict["model.layers." + str(layer_idx) +
|
for name in self.expert_weight_names:
|
||||||
".mlp.experts." +
|
if name in [
|
||||||
name].data[local_expert_id]
|
"w13_weight_list", "w2_weight_list",
|
||||||
for name in self.expert_weight_names
|
"w13_weight_scale_fp32_list",
|
||||||
])
|
"w2_weight_scale_list"
|
||||||
|
]:
|
||||||
|
per_expert_param.append(
|
||||||
|
self.param_dict["model.layers." + str(layer_idx) +
|
||||||
|
".mlp.experts." +
|
||||||
|
name][local_expert_id])
|
||||||
|
else:
|
||||||
|
per_expert_param.append(
|
||||||
|
self.param_dict["model.layers." + str(layer_idx) +
|
||||||
|
".mlp.experts." +
|
||||||
|
name][0].data[local_expert_id])
|
||||||
|
self.expert_param_per_layer[layer_idx].append(per_expert_param)
|
||||||
|
|
||||||
def get_rank_expert_workload(self) -> torch.Tensor:
|
def get_rank_expert_workload(self) -> torch.Tensor:
|
||||||
self.moe_load = self.model.get_all_moe_loads()
|
self.moe_load = self.model.get_all_moe_loads()
|
||||||
|
|||||||
@@ -83,8 +83,8 @@ class MoECommMethod(ABC):
|
|||||||
def fused_experts(
|
def fused_experts(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor | list[torch.Tensor],
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor | list[torch.Tensor],
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -93,8 +93,8 @@ class MoECommMethod(ABC):
|
|||||||
use_int4_w4a8: bool = False,
|
use_int4_w4a8: bool = False,
|
||||||
global_num_experts: Optional[int] = None,
|
global_num_experts: Optional[int] = None,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||||
w1_scale_bias: torch.Tensor = None,
|
w1_scale_bias: torch.Tensor = None,
|
||||||
w2_scale_bias: torch.Tensor = None,
|
w2_scale_bias: torch.Tensor = None,
|
||||||
# For TorchAir graph
|
# For TorchAir graph
|
||||||
|
|||||||
@@ -23,7 +23,11 @@ from vllm.forward_context import get_forward_context
|
|||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
|
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
|
||||||
get_ascend_device_type)
|
enable_custom_op, get_ascend_device_type)
|
||||||
|
|
||||||
|
|
||||||
|
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||||
|
return fusion and dynamic_eplb and enable_custom_op()
|
||||||
|
|
||||||
|
|
||||||
def cumsum_group_list(group_list: torch.Tensor,
|
def cumsum_group_list(group_list: torch.Tensor,
|
||||||
@@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor,
|
|||||||
|
|
||||||
|
|
||||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: list[torch.Tensor],
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: list[torch.Tensor],
|
||||||
w2: torch.Tensor,
|
w2: list[torch.Tensor],
|
||||||
w2_scale: torch.Tensor,
|
w2_scale: list[torch.Tensor],
|
||||||
group_list: torch.Tensor,
|
group_list: torch.Tensor,
|
||||||
group_list_type: int = 1,
|
group_list_type: int = 1,
|
||||||
dynamic_scale: torch.Tensor = None,
|
dynamic_scale: torch.Tensor = None,
|
||||||
@@ -79,7 +83,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
quantized_hidden_states = hidden_states
|
quantized_hidden_states = hidden_states
|
||||||
|
|
||||||
bias1, bias2 = None, None
|
bias1, bias2 = None, None
|
||||||
_output_dtype = w2_scale.dtype
|
_output_dtype = w2_scale[0].dtype
|
||||||
|
|
||||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||||
if weight_prefetch_method:
|
if weight_prefetch_method:
|
||||||
@@ -87,23 +91,34 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
hidden_states)
|
hidden_states)
|
||||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||||
if w1_scale_bias is None and is_mc2:
|
if w1_scale_bias is None and is_mc2:
|
||||||
if fusion and not dynamic_eplb:
|
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||||
|
# gmm1: gate_up_proj & act_fn: swiglu
|
||||||
|
hidden_states, swiglu_out_scale, _ = (
|
||||||
|
torch.ops._C_ascend.
|
||||||
|
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
||||||
|
x=hidden_states,
|
||||||
|
weight=w1,
|
||||||
|
weight_scale=w1_scale,
|
||||||
|
x_scale=pertoken_scale,
|
||||||
|
group_list=cumsum_group_list(group_list, group_list_type),
|
||||||
|
))
|
||||||
|
elif fusion and not dynamic_eplb:
|
||||||
# gmm1: gate_up_proj & act_fn: swiglu
|
# gmm1: gate_up_proj & act_fn: swiglu
|
||||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight=w1,
|
weight=w1[0],
|
||||||
group_list=cumsum_group_list(group_list, group_list_type),
|
group_list=cumsum_group_list(group_list, group_list_type),
|
||||||
weight_scale=w1_scale,
|
weight_scale=w1_scale[0],
|
||||||
x_scale=pertoken_scale)
|
x_scale=pertoken_scale)
|
||||||
if quantized_hidden_states is not None:
|
if quantized_hidden_states is not None:
|
||||||
dispose_tensor(quantized_hidden_states)
|
dispose_tensor(quantized_hidden_states)
|
||||||
else:
|
else:
|
||||||
if w1_scale.dtype != torch.float32:
|
if w1_scale[0].dtype != torch.float32:
|
||||||
w1_scale = w1_scale.to(torch.float32)
|
w1_scale[0] = w1_scale[0].to(torch.float32)
|
||||||
# gmm1: gate_up_proj
|
# gmm1: gate_up_proj
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
weight=[w1],
|
weight=w1,
|
||||||
split_item=3,
|
split_item=3,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
group_type=0,
|
group_type=0,
|
||||||
@@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
# gmm2: down_proj
|
# gmm2: down_proj
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
weight=[w2],
|
weight=w2,
|
||||||
scale=[w2_scale],
|
scale=w2_scale,
|
||||||
per_token_scale=[swiglu_out_scale],
|
per_token_scale=[swiglu_out_scale],
|
||||||
split_item=2,
|
split_item=2,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
group_type=0,
|
group_type=0,
|
||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
output_dtype=w2_scale.dtype)[0]
|
output_dtype=w2_scale[0].dtype)[0]
|
||||||
else:
|
else:
|
||||||
if w1_scale_bias is not None:
|
if w1_scale_bias is not None:
|
||||||
if group_list_type == 0:
|
if group_list_type == 0:
|
||||||
@@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||||
_output_dtype = torch.bfloat16
|
_output_dtype = torch.bfloat16
|
||||||
|
|
||||||
if fusion and not dynamic_eplb:
|
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||||
|
# gmm1: gate_up_proj & act_fn: swiglu
|
||||||
|
hidden_states, swiglu_out_scale, _ = (
|
||||||
|
torch.ops._C_ascend.
|
||||||
|
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
||||||
|
x=hidden_states,
|
||||||
|
weight=w1,
|
||||||
|
weight_scale=w1_scale,
|
||||||
|
x_scale=pertoken_scale,
|
||||||
|
group_list=cumsum_group_list(group_list, group_list_type),
|
||||||
|
bias=bias1,
|
||||||
|
))
|
||||||
|
elif fusion and not dynamic_eplb:
|
||||||
# gmm1: gate_up_proj & act_fn: swiglu
|
# gmm1: gate_up_proj & act_fn: swiglu
|
||||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight=w1,
|
weight=w1[0],
|
||||||
bias=bias1,
|
bias=bias1,
|
||||||
group_list=cumsum_group_list(group_list, group_list_type),
|
group_list=cumsum_group_list(group_list, group_list_type),
|
||||||
weight_scale=w1_scale,
|
weight_scale=w1_scale[0],
|
||||||
x_scale=pertoken_scale)
|
x_scale=pertoken_scale)
|
||||||
if quantized_hidden_states is not None:
|
if quantized_hidden_states is not None:
|
||||||
dispose_tensor(quantized_hidden_states)
|
dispose_tensor(quantized_hidden_states)
|
||||||
else:
|
else:
|
||||||
|
w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype)
|
||||||
# gmm1: gate_up_proj
|
# gmm1: gate_up_proj
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
weight=[w1],
|
weight=w1,
|
||||||
scale=[w1_scale.to(w2_scale.dtype)],
|
scale=w1_scale,
|
||||||
bias=bias1,
|
bias=bias1,
|
||||||
per_token_scale=[pertoken_scale],
|
per_token_scale=[pertoken_scale],
|
||||||
split_item=2,
|
split_item=2,
|
||||||
@@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
# gmm2: down_proj
|
# gmm2: down_proj
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
weight=[w2],
|
weight=w2,
|
||||||
scale=[w2_scale],
|
scale=w2_scale,
|
||||||
bias=bias2,
|
bias=bias2,
|
||||||
per_token_scale=[swiglu_out_scale],
|
per_token_scale=[swiglu_out_scale],
|
||||||
split_item=2,
|
split_item=2,
|
||||||
@@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
|
|
||||||
|
|
||||||
def unified_apply_mlp(hidden_states: torch.Tensor,
|
def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor | list[torch.Tensor],
|
||||||
w1_scale: torch.Tensor,
|
w2: torch.Tensor | list[torch.Tensor],
|
||||||
w2: torch.Tensor,
|
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
group_list: torch.Tensor,
|
group_list: torch.Tensor,
|
||||||
|
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||||
|
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||||
dynamic_scale: torch.Tensor = None,
|
dynamic_scale: torch.Tensor = None,
|
||||||
group_list_type: int = 1,
|
group_list_type: int = 1,
|
||||||
w1_scale_bias: torch.Tensor = None,
|
w1_scale_bias: torch.Tensor = None,
|
||||||
@@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
need_trans: bool = True,
|
need_trans: bool = True,
|
||||||
dynamic_eplb: bool = False) -> torch.Tensor:
|
dynamic_eplb: bool = False) -> torch.Tensor:
|
||||||
if with_quant:
|
if with_quant:
|
||||||
|
assert w1_scale is not None and w2_scale is not None
|
||||||
return quant_apply_mlp(hidden_states=hidden_states,
|
return quant_apply_mlp(hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
|
|||||||
@@ -379,10 +379,10 @@ class AscendW4A8DynamicFusedMoEMethod:
|
|||||||
moe_comm_method = get_forward_context().moe_comm_method
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
return moe_comm_method.fused_experts(
|
return moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=[layer.w13_weight],
|
||||||
w2=layer.w2_weight,
|
w2=[layer.w2_weight],
|
||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=[layer.w13_weight_scale],
|
||||||
w2_scale=layer.w2_weight_scale,
|
w2_scale=[layer.w2_weight_scale],
|
||||||
w1_scale_bias=layer.w13_scale_bias,
|
w1_scale_bias=layer.w13_scale_bias,
|
||||||
w2_scale_bias=layer.w2_scale_bias,
|
w2_scale_bias=layer.w2_scale_bias,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
|
|||||||
@@ -236,13 +236,24 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
topk_weights = topk_weights.to(self.in_dtype)
|
topk_weights = topk_weights.to(self.in_dtype)
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
|
if self.dynamic_eplb:
|
||||||
|
w1 = layer.w13_weight_list
|
||||||
|
w1_scale = layer.w13_weight_scale_fp32_list
|
||||||
|
w2 = layer.w2_weight_list
|
||||||
|
w2_scale = layer.w2_weight_scale_list
|
||||||
|
else:
|
||||||
|
w1 = [layer.w13_weight]
|
||||||
|
w1_scale = [layer.w13_weight_scale_fp32]
|
||||||
|
w2 = [layer.w2_weight]
|
||||||
|
w2_scale = [layer.w2_weight_scale]
|
||||||
|
|
||||||
return moe_comm_method.fused_experts(
|
return moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
pertoken_scale=pertoken_scale,
|
pertoken_scale=pertoken_scale,
|
||||||
w1=layer.w13_weight,
|
w1=w1,
|
||||||
w1_scale=layer.w13_weight_scale_fp32,
|
w1_scale=w1_scale,
|
||||||
w2=layer.w2_weight,
|
w2=w2,
|
||||||
w2_scale=layer.w2_weight_scale,
|
w2_scale=w2_scale,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
use_int8_w8a8=True,
|
use_int8_w8a8=True,
|
||||||
@@ -274,3 +285,25 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
layer.w2_weight_scale.data.shape[0], -1)
|
layer.w2_weight_scale.data.shape[0], -1)
|
||||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||||
layer.w2_weight_offset.data.shape[0], -1)
|
layer.w2_weight_offset.data.shape[0], -1)
|
||||||
|
if self.dynamic_eplb:
|
||||||
|
layer.w13_weight_list = [
|
||||||
|
weight.clone()
|
||||||
|
for weight in layer.w13_weight.data.unbind(dim=0)
|
||||||
|
]
|
||||||
|
layer.w2_weight_list = [
|
||||||
|
weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)
|
||||||
|
]
|
||||||
|
layer.w13_weight_scale_fp32_list = [
|
||||||
|
weight.clone()
|
||||||
|
for weight in layer.w13_weight_scale.data.unbind(dim=0)
|
||||||
|
]
|
||||||
|
layer.w2_weight_scale_list = [
|
||||||
|
weight.clone()
|
||||||
|
for weight in layer.w2_weight_scale.data.unbind(dim=0)
|
||||||
|
]
|
||||||
|
del layer.w13_weight
|
||||||
|
del layer.w2_weight
|
||||||
|
del layer.w13_weight_scale
|
||||||
|
del layer.w13_weight_scale_fp32
|
||||||
|
del layer.w2_weight_scale
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user