[Perf] Optimize fused_experts quantization code to save npu memory (#784)
### What this PR does / why we need it? In the w8a8 quantization code of `fused_experts`, the output of almost every operator is assigned a new variable name. If we want to save NPU memory, we must manually `del` these variables to end their lifecycles, which fills the code with `del` statements and looks inelegant. Therefore, I plan to name the output of most operators `hidden_states`, thereby ending the lifecycle of the previous `hidden_states`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
@@ -222,9 +222,6 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
if self.n_shared_experts is not None:
|
|
||||||
shared_output = self.shared_experts(hidden_states)
|
|
||||||
|
|
||||||
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
|
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
|
||||||
chunks = torch.chunk(hidden_states,
|
chunks = torch.chunk(hidden_states,
|
||||||
get_tp_group().world_size,
|
get_tp_group().world_size,
|
||||||
@@ -248,8 +245,8 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
else:
|
else:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||||
final_hidden_states)
|
final_hidden_states)
|
||||||
|
if self.n_shared_experts is not None:
|
||||||
if shared_output is not None:
|
shared_output = self.shared_experts(hidden_states)
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
@@ -25,7 +25,7 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
|
|||||||
from vllm_ascend.ops.fused_moe import select_experts
|
from vllm_ascend.ops.fused_moe import select_experts
|
||||||
|
|
||||||
|
|
||||||
def apply_mlp(x: torch.Tensor,
|
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
@@ -37,7 +37,7 @@ def apply_mlp(x: torch.Tensor,
|
|||||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: input hidden states with shape (num_tokens, hidden_size).
|
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
||||||
w1: expert weights1 with shape
|
w1: expert weights1 with shape
|
||||||
(num_experts, hidden_size, intermediate_size * 2)
|
(num_experts, hidden_size, intermediate_size * 2)
|
||||||
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
||||||
@@ -56,23 +56,27 @@ def apply_mlp(x: torch.Tensor,
|
|||||||
hidden_states: output hidden states after MLP.
|
hidden_states: output hidden states after MLP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
assert len(hidden_states_wrapper) == 1
|
||||||
|
hidden_states = hidden_states_wrapper.pop()
|
||||||
if dynamic_scale is None:
|
if dynamic_scale is None:
|
||||||
h, pertoken_scale = torch_npu.npu_dynamic_quant(x)
|
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||||
|
hidden_states)
|
||||||
else:
|
else:
|
||||||
h = x
|
|
||||||
pertoken_scale = dynamic_scale
|
pertoken_scale = dynamic_scale
|
||||||
|
|
||||||
# gmm1: gate_up_proj
|
# gmm1: gate_up_proj
|
||||||
gate_up_out = torch_npu.npu_grouped_matmul(x=[h],
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
weight=[w1],
|
x=[hidden_states],
|
||||||
split_item=3,
|
weight=[w1],
|
||||||
group_list_type=group_list_type,
|
split_item=3,
|
||||||
group_type=0,
|
group_list_type=group_list_type,
|
||||||
group_list=group_list,
|
group_type=0,
|
||||||
output_dtype=torch.int32)[0]
|
group_list=group_list,
|
||||||
|
output_dtype=torch.int32)[0]
|
||||||
|
|
||||||
swiglu_out, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
# act_fn: swiglu
|
||||||
x=gate_up_out,
|
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||||
|
x=hidden_states,
|
||||||
weight_scale=w1_scale,
|
weight_scale=w1_scale,
|
||||||
activation_scale=pertoken_scale,
|
activation_scale=pertoken_scale,
|
||||||
bias=None,
|
bias=None,
|
||||||
@@ -83,17 +87,18 @@ def apply_mlp(x: torch.Tensor,
|
|||||||
quant_mode=1,
|
quant_mode=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# down_proj
|
# gmm2: down_proj
|
||||||
down_out = torch_npu.npu_grouped_matmul(x=[swiglu_out],
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
weight=[w2],
|
x=[hidden_states],
|
||||||
scale=[w2_scale],
|
weight=[w2],
|
||||||
per_token_scale=[swiglu_out_scale],
|
scale=[w2_scale],
|
||||||
split_item=2,
|
per_token_scale=[swiglu_out_scale],
|
||||||
group_list_type=group_list_type,
|
split_item=2,
|
||||||
group_type=0,
|
group_list_type=group_list_type,
|
||||||
group_list=group_list,
|
group_type=0,
|
||||||
output_dtype=w2_scale.dtype)[0]
|
group_list=group_list,
|
||||||
return down_out
|
output_dtype=w2_scale.dtype)[0]
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def fused_experts_with_mc2(
|
def fused_experts_with_mc2(
|
||||||
@@ -152,7 +157,11 @@ def fused_experts_with_mc2(
|
|||||||
if quant_mode == 0:
|
if quant_mode == 0:
|
||||||
dynamic_scale = None
|
dynamic_scale = None
|
||||||
|
|
||||||
down_out_list = apply_mlp(expand_x,
|
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
|
||||||
|
hidden_states_wrapper = [expand_x]
|
||||||
|
del expand_x
|
||||||
|
|
||||||
|
down_out_list = apply_mlp(hidden_states_wrapper,
|
||||||
w1,
|
w1,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2,
|
w2,
|
||||||
@@ -246,7 +255,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
||||||
expert_tokens = token_counts[:num_experts]
|
expert_tokens = token_counts[:num_experts]
|
||||||
# Rearrange hidden_states
|
# Rearrange hidden_states
|
||||||
sorted_hidden_states = hidden_states[sorted_token_indices]
|
hidden_states = hidden_states[sorted_token_indices]
|
||||||
group_list_type = 1
|
group_list_type = 1
|
||||||
else:
|
else:
|
||||||
row_idx_len = num_tokens * top_k
|
row_idx_len = num_tokens * top_k
|
||||||
@@ -255,19 +264,22 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=topk_weights.device).view(
|
device=topk_weights.device).view(
|
||||||
top_k, -1).permute(1, 0).contiguous()
|
top_k, -1).permute(1, 0).contiguous()
|
||||||
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
row_idx=row_idx,
|
row_idx=row_idx,
|
||||||
expert_idx=topk_ids,
|
expert_idx=topk_ids,
|
||||||
active_num=num_tokens)
|
active_num=num_tokens)
|
||||||
del hidden_states
|
|
||||||
|
|
||||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||||
expanded_expert_idx, num_experts)
|
expanded_expert_idx, num_experts)
|
||||||
expert_tokens = expert_tokens.to(torch.int64)
|
expert_tokens = expert_tokens.to(torch.int64)
|
||||||
group_list_type = 0
|
group_list_type = 0
|
||||||
|
|
||||||
down_out_list = apply_mlp(sorted_hidden_states,
|
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
|
||||||
|
hidden_states_wrapper = [hidden_states]
|
||||||
|
del hidden_states
|
||||||
|
|
||||||
|
hidden_states = apply_mlp(hidden_states_wrapper,
|
||||||
w1,
|
w1,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2,
|
w2,
|
||||||
@@ -276,23 +288,23 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
group_list_type=group_list_type)
|
group_list_type=group_list_type)
|
||||||
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
down_out_list.mul_(sorted_weights.unsqueeze(1))
|
hidden_states.mul_(sorted_weights.unsqueeze(1))
|
||||||
final_hidden_states = torch.zeros(*original_shape,
|
final_hidden_states = torch.zeros(*original_shape,
|
||||||
device=hidden_states.device,
|
device=device,
|
||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
|
|
||||||
num_valid_tokens = mask.sum()
|
num_valid_tokens = mask.sum()
|
||||||
valid_token_mask = torch.arange(
|
valid_token_mask = torch.arange(
|
||||||
0, sorted_token_indices.shape[0],
|
0, sorted_token_indices.shape[0],
|
||||||
device=device).unsqueeze(1) < num_valid_tokens
|
device=device).unsqueeze(1) < num_valid_tokens
|
||||||
down_out_list = down_out_list.masked_fill_(~valid_token_mask,
|
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
|
||||||
0).to(dtype)
|
0).to(dtype)
|
||||||
final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
|
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
|
||||||
else:
|
else:
|
||||||
# TODO: Reorder device memory 2 times here, replace the current
|
# TODO: Reorder device memory 2 times here, replace the current
|
||||||
# implementation here when suitable operators become available.
|
# implementation here when suitable operators become available.
|
||||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||||
down_out_list,
|
hidden_states,
|
||||||
skip1=None,
|
skip1=None,
|
||||||
skip2=None,
|
skip2=None,
|
||||||
bias=None,
|
bias=None,
|
||||||
@@ -300,7 +312,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
expanded_src_to_dst_row=expanded_row_idx,
|
expanded_src_to_dst_row=expanded_row_idx,
|
||||||
export_for_source_row=topk_ids,
|
export_for_source_row=topk_ids,
|
||||||
)
|
)
|
||||||
del down_out_list
|
|
||||||
if len(original_shape) == 3:
|
if len(original_shape) == 3:
|
||||||
final_hidden_states = final_hidden_states.view(original_shape)
|
final_hidden_states = final_hidden_states.view(original_shape)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|||||||
Reference in New Issue
Block a user