[Feature] Add token mask for DispatchGmmCombineDecode operator (#5171)
### What this PR does / why we need it?
In this PR, DispatchGmmCombineDecode gains an optional input
`x_active_mask`; only tokens whose mask entry is True are dispatched
and handled.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
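
For reviewers unfamiliar with the mask semantics: `x_active_mask` is a per-token boolean tensor, and only positions set to True take part in dispatch and combine. Below is a minimal host-side sketch of how such a mask can be built and interpreted; the shapes and variable names are illustrative only, not the operator's internals.

```python
import torch

# Illustrative sketch: the first `valid_token_num` tokens are active, the rest
# are padding, mirroring how the test below constructs x_active_mask.
batch_size, hidden_size = 8, 16
valid_token_num = 5
x_active_mask = torch.cat(
    (torch.ones(valid_token_num),
     torch.zeros(batch_size - valid_token_num))).bool()

x = torch.randn(batch_size, hidden_size)
# Only rows where the mask is True would be dispatched to experts; masked-off
# rows are ignored, which is why the test compares only output[0:valid_token_num].
active_tokens = x[x_active_mask]
assert active_tokens.shape == (valid_token_num, hidden_size)
```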
@@ -16,6 +16,19 @@ torch.manual_seed(42)
 torch_npu.npu.config.allow_internal_format = True
 enable_custom_op()
+LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
+BASE_KWARGS = {
+    "batch_size": 64,
+    "token_hidden_size": 7168,
+    "moe_intermediate_size": 2048,
+    "ep_world_size": 16,
+    "moe_expert_num": 64,
+    "shared_expert_rank_num": 0,
+    "top_k": 8,
+    "test_bfloat16": True,
+    "enable_dynamic_bs": False,
+    "test_graph": False,
+    "with_mc2_mask": False
+}
 
 
 def redirect_output(log_file_path):
@@ -115,11 +128,14 @@ class DecodeMoeOps(torch.nn.Module):
         self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
             gmm2_weight_scale.float(), requires_grad=False)
 
-    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+                   x_active_mask):
         raise NotImplementedError("To be implemented in subclass")
 
-    def forward(self, x, expert_ids, smooth_scales, expert_scales):
-        return self._apply_ops(x, expert_ids, smooth_scales, expert_scales)
+    def forward(self, x, expert_ids, smooth_scales, expert_scales,
+                x_active_mask):
+        return self._apply_ops(x, expert_ids, smooth_scales, expert_scales,
+                               x_active_mask)
 
 
 class SmallOps(DecodeMoeOps):
@@ -144,11 +160,13 @@ class SmallOps(DecodeMoeOps):
                          shared_expert_rank_num)
         self.tp_hcomm_info = ""
 
-    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+                   x_active_mask):
         outputs = torch_npu.npu_moe_distribute_dispatch_v2(
             x=x,
             expert_ids=expert_ids,
             expert_scales=expert_scales,
+            x_active_mask=x_active_mask,
             group_ep=self.ep_hcomm_info,
             ep_world_size=self.ep_world_size,
             ep_rank_id=self.global_rank_id,
@@ -200,6 +218,7 @@ class SmallOps(DecodeMoeOps):
             assist_info_for_combine=assist_info_for_combine,
             ep_send_counts=ep_send_counts,
             expert_scales=expert_scales,
+            x_active_mask=x_active_mask,
             group_ep=self.ep_hcomm_info,
             ep_world_size=self.ep_world_size,
             ep_rank_id=self.global_rank_id,
@@ -237,7 +256,8 @@ class FusionOp(DecodeMoeOps):
                          ep_world_size, moe_expert_num, global_rank_id,
                          shared_expert_rank_num)
 
-    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+                   x_active_mask):
         output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
             x=x,
             expert_ids=expert_ids,
@@ -245,8 +265,9 @@ class FusionOp(DecodeMoeOps):
             gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
             gmm2_weight=self.gmm2_weight,
             gmm2_weight_scale=self.gmm2_weight_scale_fp32,
-            expert_smooth_scales=smooth_scales,
             expert_scales=expert_scales,
+            expert_smooth_scales=smooth_scales,
+            x_active_mask=x_active_mask,
             group_ep=self.ep_hcomm_info,
             ep_rank_size=self.ep_world_size,
             ep_rank_id=self.global_rank_id,
@@ -267,12 +288,13 @@ def generate_datas(batch_size,
                    shared_expert_rank_num=0,
                    top_k=8,
                    test_bfloat16=True,
-                   enable_dynamic_bs=False):
+                   enable_dynamic_bs=False,
+                   with_mc2_mask=False):
     is_shared_expert = global_rank_id < shared_expert_rank_num
     moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
                                                  shared_expert_rank_num)
     actual_bs = int(
-        torch.randint(1, batch_size, [1]).item(
+        torch.randint(2 if with_mc2_mask else 1, batch_size, [1]).item(
         ) if enable_dynamic_bs else batch_size)
     local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
     gmm1_input_dim = token_hidden_size
@@ -317,9 +339,16 @@ def generate_datas(batch_size,
     else:
         x = x.half()
         smooth_sales = None
-    return (x, expert_ids, smooth_sales, expert_scales), \
+    x_active_mask = None
+    valid_token_num = actual_bs
+    if with_mc2_mask:
+        valid_token_num = int(torch.randint(1, actual_bs, [1]).item())
+        x_active_mask = torch.cat(
+            (torch.ones(valid_token_num),
+             torch.zeros(actual_bs - valid_token_num))).bool()
+    return (x, expert_ids, smooth_sales, expert_scales, x_active_mask), \
         (gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
-        actual_bs
+        actual_bs, valid_token_num
 
 
 def run_once(local_rank_id,
@@ -332,7 +361,8 @@ def run_once(local_rank_id,
              top_k=8,
              test_bfloat16=True,
              enable_dynamic_bs=False,
-             test_graph=False):
+             test_graph=False,
+             with_mc2_mask=False):
     log_file = redirect_output(f"local_rank_{local_rank_id}.log"
                                ) if output_to_file(local_rank_id) else None
     global_rank_id = local_rank_id  # single node
@@ -358,8 +388,8 @@ def run_once(local_rank_id,
     parameter = (batch_size, token_hidden_size, moe_intermediate_size,
                  ep_world_size, moe_expert_num, global_rank_id,
                  shared_expert_rank_num)
-    input_datas, weight_datas, actual_bs = generate_datas(
-        *parameter, top_k, test_bfloat16, enable_dynamic_bs)
+    input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
+        *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
     input_datas = [
         data.npu() if data is not None else None for data in input_datas
     ]
@@ -382,8 +412,8 @@ def run_once(local_rank_id,
     if log_file is not None:
         log_file.close()
     small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
-    torch.testing.assert_close(small_op_token_output.cpu(),
-                               fused_op_token_output.cpu(),
+    torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
+                               fused_op_token_output[0:valid_token_num].cpu(),
                                atol=2.0,
                                rtol=0.02)
     torch.testing.assert_close(small_op_count_output.cpu(),
@@ -394,18 +424,16 @@ def run_once(local_rank_id,
 
 
 @torch.inference_mode()
-def test():
-    batch_size = 64
-    token_hidden_size = 7168
-    moe_intermediate_size = 2048
-    ep_world_size = 16
-    moe_expert_num = 64
-    shared_expert_rank_num = 0
-    top_k = 8
-    test_bfloat16 = True
-    enable_dynamic_bs = False
-    test_graph = False
-    args = (batch_size, token_hidden_size, moe_intermediate_size,
-            ep_world_size, moe_expert_num, shared_expert_rank_num, top_k,
-            test_bfloat16, enable_dynamic_bs, test_graph)
-    mp.spawn(run_once, args=args, nprocs=ep_world_size, join=True)
+def test_dispatch_gmm_combine_decode_base():
+    custom_kwargs = BASE_KWARGS
+    ep_world_size = custom_kwargs["ep_world_size"]
+    custom_args = tuple(custom_kwargs.values())
+    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
+
+
+def test_dispatch_gmm_combine_decode_with_mc2_mask():
+    custom_kwargs = BASE_KWARGS
+    custom_kwargs["with_mc2_mask"] = True
+    ep_world_size = custom_kwargs["ep_world_size"]
+    custom_args = tuple(custom_kwargs.values())
+    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
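
One detail in the refactored tests above: `mp.spawn` forwards `args` positionally, so `tuple(custom_kwargs.values())` only works because the key order of `BASE_KWARGS` matches the parameter order of `run_once` after `local_rank_id`. The self-contained sketch below illustrates that coupling with a stand-in worker; the dict copy and the simplified kwargs are illustrative, not what the test itself does.

```python
import torch.multiprocessing as mp

# Illustrative stand-in for the real test setup: a tiny kwargs dict and worker.
BASE_KWARGS = {"batch_size": 64, "ep_world_size": 2, "with_mc2_mask": False}


def run_once(local_rank_id, batch_size, ep_world_size, with_mc2_mask):
    # mp.spawn passes args *positionally*, so the dict's insertion order must
    # line up with this signature (after local_rank_id).
    print(local_rank_id, batch_size, ep_world_size, with_mc2_mask)


if __name__ == "__main__":
    custom_kwargs = dict(BASE_KWARGS)        # copy instead of mutating the base dict
    custom_kwargs["with_mc2_mask"] = True
    mp.spawn(run_once,
             args=tuple(custom_kwargs.values()),
             nprocs=custom_kwargs["ep_world_size"],
             join=True)
```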