[Feature] EPLB: Adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers (#5552)
#### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to EPLB tensor lists and expert token numbers.
The operator now accepts gmm1, gmm2, gmm1Scale, and gmm2Scale in list format.
The operator also reports how many tokens each local expert receives via expertTokensNum.
- vLLM version: v0.13.0
- vLLM main: 7157596103
For more information about this operator, please refer to the RFC issue: https://github.com/vllm-project/vllm-ascend/issues/5476
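
As a quick illustration of the two new behaviours exercised by the test changes below, here is a minimal plain-PyTorch sketch (no NPU calls; the shapes and helper variables are illustrative assumptions, not part of the operator's API):

```python
import torch

# Illustrative sizes only (not taken from the operator's real configuration).
local_expert_num, hidden, inter = 4, 16, 32

# Stacked gmm1 weight as loaded: [local_expert_num, hidden, 2 * inter].
gmm1_weight = torch.randn(local_expert_num, hidden, 2 * inter)

dynamic_eplb = True
if dynamic_eplb:
    # EPLB tensor-list format: one contiguous tensor per local expert,
    # mirroring `weight.clone() for weight in gmm1_weight.unbind(dim=0)` in the test.
    gmm1_weight_list = [w.clone() for w in gmm1_weight.unbind(dim=0)]
else:
    # Legacy path: the single stacked tensor wrapped in a one-element list.
    gmm1_weight_list = [gmm1_weight.clone()]

# expertTokensNum reports how many tokens each local expert received,
# i.e. per-expert counts rather than a cumulative (prefix-sum) view.
expert_ids = torch.randint(0, local_expert_num, (10,))  # routed token -> local expert
expert_token_nums = torch.bincount(expert_ids, minlength=local_expert_num)
print(expert_token_nums)  # e.g. tensor([3, 2, 4, 1])
```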
@@ -27,7 +27,8 @@ BASE_KWARGS = {
     "test_bfloat16": True,
     "enable_dynamic_bs": False,
     "test_graph": False,
-    "with_mc2_mask": False
+    "with_mc2_mask": False,
+    "dynamic_eplb": False
 }


@@ -48,20 +49,6 @@ def permute_weight(w: torch.Tensor, tile_n):
                      n).contiguous()


-def from_inclusive_prefix_sum(pref):
-    if isinstance(pref, torch.Tensor):
-        if pref.numel() == 0:
-            return pref
-        return torch.cat([pref[:1], pref[1:] - pref[:-1]])
-
-    if not pref:
-        return []
-    out = [pref[0]]
-    for i in range(1, len(pref)):
-        out.append(pref[i] - pref[i - 1])
-    return out
-
-
 def output_to_file(rank_id):
     return False

@@ -80,7 +67,8 @@ class DecodeMoeOps(torch.nn.Module):
                  ep_world_size,
                  moe_expert_num,
                  global_rank_id,
-                 shared_expert_rank_num=0):
+                 shared_expert_rank_num=0,
+                 dynamic_eplb=False):
         super().__init__()
         self.ep_hcomm_info = ep_hcomm_info
         self.batch_size = batch_size
@@ -95,6 +83,7 @@ class DecodeMoeOps(torch.nn.Module):
                                            shared_expert_rank_num)
         self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
         self.ep_recv_count_size = self.local_expert_num * ep_world_size
+        self.dynamic_eplb = dynamic_eplb
         self.gmm1_weight = torch.empty([
             self.local_expert_num, self.token_hidden_size,
             self.moe_intermediate_size * 2
@@ -152,12 +141,13 @@ class SmallOps(DecodeMoeOps):
                  ep_world_size,
                  moe_expert_num,
                  global_rank_id,
-                 shared_expert_rank_num=0):
+                 shared_expert_rank_num=0,
+                 dynamic_eplb=False):
         super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
                          gmm2_weight_scale, ep_hcomm_info, batch_size,
                          token_hidden_size, moe_intermediate_size,
                          ep_world_size, moe_expert_num, global_rank_id,
-                         shared_expert_rank_num)
+                         shared_expert_rank_num, dynamic_eplb)
         self.tp_hcomm_info = ""

     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
@@ -232,7 +222,7 @@ class SmallOps(DecodeMoeOps):
             shared_expert_num=1,
             shared_expert_rank_num=self.shared_expert_rank_num,
             global_bs=self.batch_size * self.ep_world_size)
-        return (combine_output, ep_send_counts[:self.ep_recv_count_size])
+        return (combine_output, expert_token_nums)


 class FusionOp(DecodeMoeOps):
@@ -249,12 +239,13 @@ class FusionOp(DecodeMoeOps):
                  ep_world_size,
                  moe_expert_num,
                  global_rank_id,
-                 shared_expert_rank_num=0):
+                 shared_expert_rank_num=0,
+                 dynamic_eplb=False):
         super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
                          gmm2_weight_scale, ep_hcomm_info, batch_size,
                          token_hidden_size, moe_intermediate_size,
                          ep_world_size, moe_expert_num, global_rank_id,
-                         shared_expert_rank_num)
+                         shared_expert_rank_num, dynamic_eplb)

     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
                    x_active_mask):
@@ -278,6 +269,34 @@ class FusionOp(DecodeMoeOps):
             global_bs=self.batch_size * self.ep_world_size)
         return output

+    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
+                                       gmm2_weight, gmm2_weight_scale):
+        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        gmm1_weight_scale = gmm1_weight_scale.float()
+        gmm2_weight_scale = gmm2_weight_scale.float()
+
+        if self.dynamic_eplb:
+            self.gmm1_weight = [
+                weight.clone() for weight in gmm1_weight.unbind(dim=0)
+            ]
+            self.gmm1_weight_scale_fp32 = [
+                weight.clone() for weight in gmm1_weight_scale.unbind(dim=0)
+            ]
+            self.gmm2_weight = [
+                weight.clone() for weight in gmm2_weight.unbind(dim=0)
+            ]
+            self.gmm2_weight_scale_fp32 = [
+                weight.clone() for weight in gmm2_weight_scale.unbind(dim=0)
+            ]
+        else:
+            self.gmm1_weight = [gmm1_weight.clone()]
+            self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()]
+            self.gmm2_weight = [gmm2_weight.clone()]
+            self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()]
+

 def generate_datas(batch_size,
                    token_hidden_size,
@@ -362,7 +381,8 @@ def run_once(local_rank_id,
              test_bfloat16=True,
              enable_dynamic_bs=False,
              test_graph=False,
-             with_mc2_mask=False):
+             with_mc2_mask=False,
+             dynamic_eplb=False):
     log_file = redirect_output(f"local_rank_{local_rank_id}.log"
                                ) if output_to_file(local_rank_id) else None
     global_rank_id = local_rank_id  # 单机
@@ -396,10 +416,10 @@ def run_once(local_rank_id,
     weight_datas = [
         data.npu() if data is not None else None for data in weight_datas
     ]
-    small_ops = SmallOps(*weight_datas, ep_hcomm_info_small,
-                         *parameter).npu()  # type: ignore
-    fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
-                         *parameter).npu()  # type: ignore
+    small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter,
+                         dynamic_eplb).npu()  # type: ignore
+    fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter,
+                         dynamic_eplb).npu()  # type: ignore
     if test_graph:
         config = torchair.CompilerConfig()
         config.mode = "reduce-overhead"
@@ -411,7 +431,7 @@ def run_once(local_rank_id,
     dist.destroy_process_group()
     if log_file is not None:
         log_file.close()
-    small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
+
     torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
                                fused_op_token_output[0:valid_token_num].cpu(),
                                atol=2.0,
@@ -431,9 +451,19 @@ def test_dispatch_gmm_combine_decode_base():
     mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)


 @torch.inference_mode()
 def test_dispatch_gmm_combine_decode_with_mc2_mask():
     custom_kwargs = BASE_KWARGS
     custom_kwargs["with_mc2_mask"] = True
     ep_world_size = custom_kwargs["ep_world_size"]
     custom_args = tuple(custom_kwargs.values())
     mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
+
+
+@torch.inference_mode()
+def test_dispatch_gmm_combine_decode_dynamic_eplb():
+    custom_kwargs = BASE_KWARGS
+    custom_kwargs["dynamic_eplb"] = True
+    ep_world_size = custom_kwargs["ep_world_size"]
+    custom_args = tuple(custom_kwargs.values())
+    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
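
For context on the deleted `from_inclusive_prefix_sum` helper: per the diff above, the small-op path previously appears to have returned a cumulative `ep_send_counts` slice, which the test converted back into per-expert counts; with the operator now returning `expert_token_nums` directly, that conversion step is no longer needed. A minimal sketch of the conversion the removed helper performed (values are illustrative):

```python
import torch

# Recover per-expert counts from an inclusive prefix sum,
# as the removed from_inclusive_prefix_sum helper did for tensors.
prefix = torch.tensor([3, 5, 10])                        # inclusive prefix sum
counts = torch.cat([prefix[:1], prefix[1:] - prefix[:-1]])
print(counts)                                            # tensor([3, 2, 5])
```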