[Feature]EPLB:Adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers (#5552)

#### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the EPLB tensor list and
expert token numbers.

This operator supports gmm1, gmm2, gmm1Scale and gmm2Scale in the format of
a list.
This operator supports counting how many tokens each local expert receives
via expertTokensNum.


- vLLM version: v0.13.0
- vLLM main:
7157596103

More info about this operator, please refer to RFC: issue
https://github.com/vllm-project/vllm-ascend/issues/5476
This commit is contained in:
wangyibo1005
2026-01-07 11:23:42 +08:00
committed by GitHub
parent 086c093347
commit 25baf6df09
18 changed files with 425 additions and 195 deletions

View File

@@ -27,7 +27,8 @@ BASE_KWARGS = {
"test_bfloat16": True,
"enable_dynamic_bs": False,
"test_graph": False,
"with_mc2_mask": False
"with_mc2_mask": False,
"dynamic_eplb": False
}
@@ -48,20 +49,6 @@ def permute_weight(w: torch.Tensor, tile_n):
n).contiguous()
def from_inclusive_prefix_sum(pref):
    """Recover per-element counts from an inclusive prefix sum.

    Accepts either a 1-D ``torch.Tensor`` or a plain Python sequence and
    returns the same kind of container. An empty input is returned as-is
    (the tensor itself, or an empty list).
    """
    if isinstance(pref, torch.Tensor):
        if pref.numel() == 0:
            return pref
        # First count equals the first partial sum; the rest are the
        # differences between adjacent partial sums.
        return torch.cat([pref[:1], torch.diff(pref)])
    if not pref:
        return []
    return [pref[0]] + [cur - prev for prev, cur in zip(pref, pref[1:])]
def output_to_file(rank_id):
    """Return whether rank *rank_id* should redirect its output to a log file.

    Currently always disabled; flip the return value to capture per-rank logs.
    """
    return False
@@ -80,7 +67,8 @@ class DecodeMoeOps(torch.nn.Module):
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
shared_expert_rank_num=0,
dynamic_eplb=False):
super().__init__()
self.ep_hcomm_info = ep_hcomm_info
self.batch_size = batch_size
@@ -95,6 +83,7 @@ class DecodeMoeOps(torch.nn.Module):
shared_expert_rank_num)
self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
self.ep_recv_count_size = self.local_expert_num * ep_world_size
self.dynamic_eplb = dynamic_eplb
self.gmm1_weight = torch.empty([
self.local_expert_num, self.token_hidden_size,
self.moe_intermediate_size * 2
@@ -152,12 +141,13 @@ class SmallOps(DecodeMoeOps):
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
shared_expert_rank_num=0,
dynamic_eplb=False):
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
gmm2_weight_scale, ep_hcomm_info, batch_size,
token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
shared_expert_rank_num, dynamic_eplb)
self.tp_hcomm_info = ""
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
@@ -232,7 +222,7 @@ class SmallOps(DecodeMoeOps):
shared_expert_num=1,
shared_expert_rank_num=self.shared_expert_rank_num,
global_bs=self.batch_size * self.ep_world_size)
return (combine_output, ep_send_counts[:self.ep_recv_count_size])
return (combine_output, expert_token_nums)
class FusionOp(DecodeMoeOps):
@@ -249,12 +239,13 @@ class FusionOp(DecodeMoeOps):
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
shared_expert_rank_num=0,
dynamic_eplb=False):
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
gmm2_weight_scale, ep_hcomm_info, batch_size,
token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
shared_expert_rank_num, dynamic_eplb)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
@@ -278,6 +269,34 @@ class FusionOp(DecodeMoeOps):
global_bs=self.batch_size * self.ep_world_size)
return output
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
                                   gmm2_weight, gmm2_weight_scale):
    """Cast GMM weights to FRACTAL_NZ and stage them as lists of tensors.

    With ``dynamic_eplb`` enabled, each local expert gets its own cloned
    tensor (split along dim 0); otherwise the stacked tensor is kept whole,
    wrapped in a one-element list so downstream code handles both layouts
    uniformly. Scales are converted to fp32 before staging.
    """
    gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
                                            torch_npu.Format.FRACTAL_NZ)
    gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
                                            torch_npu.Format.FRACTAL_NZ)
    gmm1_weight_scale = gmm1_weight_scale.float()
    gmm2_weight_scale = gmm2_weight_scale.float()

    def split_per_expert(tensor):
        # One cloned slice per local expert along the expert dimension.
        return [piece.clone() for piece in tensor.unbind(dim=0)]

    if self.dynamic_eplb:
        self.gmm1_weight = split_per_expert(gmm1_weight)
        self.gmm1_weight_scale_fp32 = split_per_expert(gmm1_weight_scale)
        self.gmm2_weight = split_per_expert(gmm2_weight)
        self.gmm2_weight_scale_fp32 = split_per_expert(gmm2_weight_scale)
    else:
        self.gmm1_weight = [gmm1_weight.clone()]
        self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()]
        self.gmm2_weight = [gmm2_weight.clone()]
        self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()]
def generate_datas(batch_size,
token_hidden_size,
@@ -362,7 +381,8 @@ def run_once(local_rank_id,
test_bfloat16=True,
enable_dynamic_bs=False,
test_graph=False,
with_mc2_mask=False):
with_mc2_mask=False,
dynamic_eplb=False):
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
) if output_to_file(local_rank_id) else None
global_rank_id = local_rank_id # 单机
@@ -396,10 +416,10 @@ def run_once(local_rank_id,
weight_datas = [
data.npu() if data is not None else None for data in weight_datas
]
small_ops = SmallOps(*weight_datas, ep_hcomm_info_small,
*parameter).npu() # type: ignore
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
*parameter).npu() # type: ignore
small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter,
dynamic_eplb).npu() # type: ignore
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter,
dynamic_eplb).npu() # type: ignore
if test_graph:
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
@@ -411,7 +431,7 @@ def run_once(local_rank_id,
dist.destroy_process_group()
if log_file is not None:
log_file.close()
small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
fused_op_token_output[0:valid_token_num].cpu(),
atol=2.0,
@@ -431,9 +451,19 @@ def test_dispatch_gmm_combine_decode_base():
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
@torch.inference_mode()
def test_dispatch_gmm_combine_decode_with_mc2_mask():
    """Run the decode test with the MC2 active-token mask enabled."""
    # Copy BASE_KWARGS instead of aliasing it: assigning the dict directly
    # would mutate the shared module-level config and leak "with_mc2_mask"
    # into every other test that reads BASE_KWARGS afterwards.
    custom_kwargs = dict(BASE_KWARGS)
    custom_kwargs["with_mc2_mask"] = True
    ep_world_size = custom_kwargs["ep_world_size"]
    # run_once takes these as positional args, in BASE_KWARGS insertion order.
    custom_args = tuple(custom_kwargs.values())
    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
@torch.inference_mode()
def test_dispatch_gmm_combine_decode_dynamic_eplb():
    """Run the decode test with dynamic EPLB (per-expert weight lists) on."""
    # Copy BASE_KWARGS instead of aliasing it: assigning the dict directly
    # would mutate the shared module-level config and leak "dynamic_eplb"
    # into every other test that reads BASE_KWARGS afterwards.
    custom_kwargs = dict(BASE_KWARGS)
    custom_kwargs["dynamic_eplb"] = True
    ep_world_size = custom_kwargs["ep_world_size"]
    # run_once takes these as positional args, in BASE_KWARGS insertion order.
    custom_args = tuple(custom_kwargs.values())
    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)