[Feature] DispatchGmmCombineDecode support bf16/float16 gmm1/gmm2 weight and support gmm weight with ND format (#6393)
### What this PR does / why we need it?
1. support ND format gmm weight input.
Before this pr, gmm1_weight and gmm2_weight could only be passed as
input to the DispatchGmmCombineDecode operator in NZ data format. After
the modification, they are allowed to be passed in ND data format.
2. support bf16/float16 gmm weight
The current PR modification enables the DispatchGmmCombineDecode
operator to support non-W8A8 scenarios, allowing gmm1_weight and
gmm2_weight to be passed as float16/bfloat16 which is correspond with
input token data type.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: lih827 <383084552@qq.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
@@ -28,7 +29,9 @@ BASE_KWARGS = {
|
||||
"enable_dynamic_bs": False,
|
||||
"test_graph": False,
|
||||
"with_mc2_mask": False,
|
||||
"dynamic_eplb": False
|
||||
"dynamic_eplb": False,
|
||||
"w8a8_dynamic": True,
|
||||
"is_nz": True
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +53,7 @@ def permute_weight(w: torch.Tensor, tile_n):
|
||||
|
||||
|
||||
def output_to_file(rank_id):
|
||||
return False
|
||||
return rank_id > 0
|
||||
|
||||
|
||||
class DecodeMoeOps(torch.nn.Module):
|
||||
@@ -68,8 +71,14 @@ class DecodeMoeOps(torch.nn.Module):
|
||||
moe_expert_num,
|
||||
global_rank_id,
|
||||
shared_expert_rank_num=0,
|
||||
dynamic_eplb=False):
|
||||
dynamic_eplb=False,
|
||||
w8a8_dynamic=True,
|
||||
is_nz=True):
|
||||
super().__init__()
|
||||
if w8a8_dynamic:
|
||||
assert (gmm1_weight_scale is not None and gmm2_weight_scale is not None), "gmm1_weight_scale and gmm2_weight_scale must be provided for w8a8_dynamic"
|
||||
else:
|
||||
assert (gmm1_weight_scale is None and gmm2_weight_scale is None), "gmm1_weight_scale and gmm2_weight_scale must be None for w8a8_dynamic"
|
||||
self.ep_hcomm_info = ep_hcomm_info
|
||||
self.batch_size = batch_size
|
||||
self.token_hidden_size = token_hidden_size
|
||||
@@ -84,38 +93,47 @@ class DecodeMoeOps(torch.nn.Module):
|
||||
self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
|
||||
self.ep_recv_count_size = self.local_expert_num * ep_world_size
|
||||
self.dynamic_eplb = dynamic_eplb
|
||||
self.w8a8_dynamic = w8a8_dynamic
|
||||
self.is_nz = is_nz
|
||||
self.gmm1_weight = torch.empty([
|
||||
self.local_expert_num, self.token_hidden_size,
|
||||
self.moe_intermediate_size * 2
|
||||
])
|
||||
self.gmm1_weight_scale = torch.empty(
|
||||
[self.local_expert_num, self.moe_intermediate_size * 2])
|
||||
self.gmm2_weight = torch.empty([
|
||||
self.local_expert_num, self.moe_intermediate_size,
|
||||
self.token_hidden_size
|
||||
])
|
||||
self.gmm2_weight_scale = torch.empty(
|
||||
[self.local_expert_num, self.token_hidden_size])
|
||||
if self.w8a8_dynamic:
|
||||
self.gmm1_weight_scale = torch.empty(
|
||||
[self.local_expert_num, self.moe_intermediate_size * 2])
|
||||
self.gmm2_weight_scale = torch.empty(
|
||||
[self.local_expert_num, self.token_hidden_size])
|
||||
else:
|
||||
self.gmm1_weight_scale = None
|
||||
self.gmm2_weight_scale = None
|
||||
self.gmm1_weight_scale_fp32 = None
|
||||
self.gmm2_weight_scale_fp32 = None
|
||||
self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale,
|
||||
gmm2_weight, gmm2_weight_scale)
|
||||
|
||||
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
|
||||
gmm2_weight, gmm2_weight_scale):
|
||||
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
|
||||
torch_npu.Format.FRACTAL_NZ)
|
||||
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
|
||||
torch_npu.Format.FRACTAL_NZ)
|
||||
if self.w8a8_dynamic:
|
||||
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
|
||||
torch_npu.Format.FRACTAL_NZ)
|
||||
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
|
||||
torch_npu.Format.FRACTAL_NZ)
|
||||
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
|
||||
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
|
||||
requires_grad=False)
|
||||
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
|
||||
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
|
||||
gmm1_weight_scale.float(), requires_grad=False)
|
||||
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
|
||||
gmm2_weight_scale.float(), requires_grad=False)
|
||||
if self.w8a8_dynamic:
|
||||
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
|
||||
requires_grad=False)
|
||||
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
|
||||
requires_grad=False)
|
||||
self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
|
||||
gmm1_weight_scale.float(), requires_grad=False)
|
||||
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
|
||||
gmm2_weight_scale.float(), requires_grad=False)
|
||||
|
||||
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
|
||||
x_active_mask):
|
||||
@@ -142,12 +160,15 @@ class SmallOps(DecodeMoeOps):
|
||||
moe_expert_num,
|
||||
global_rank_id,
|
||||
shared_expert_rank_num=0,
|
||||
dynamic_eplb=False):
|
||||
dynamic_eplb=False,
|
||||
w8a8_dynamic=True,
|
||||
is_nz=True):
|
||||
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
|
||||
gmm2_weight_scale, ep_hcomm_info, batch_size,
|
||||
token_hidden_size, moe_intermediate_size,
|
||||
ep_world_size, moe_expert_num, global_rank_id,
|
||||
shared_expert_rank_num, dynamic_eplb)
|
||||
shared_expert_rank_num, dynamic_eplb, w8a8_dynamic,
|
||||
is_nz)
|
||||
self.tp_hcomm_info = ""
|
||||
|
||||
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
|
||||
@@ -167,7 +188,7 @@ class SmallOps(DecodeMoeOps):
|
||||
expert_shard_type=0,
|
||||
shared_expert_num=1,
|
||||
shared_expert_rank_num=self.shared_expert_rank_num,
|
||||
quant_mode=2,
|
||||
quant_mode=2 if self.w8a8_dynamic else 0,
|
||||
global_bs=self.batch_size * self.ep_world_size,
|
||||
expert_token_nums_type=1, # 0代表前缀和,1代表各自数量
|
||||
)
|
||||
@@ -181,22 +202,26 @@ class SmallOps(DecodeMoeOps):
|
||||
group_list_type=1, # 默认为0,代表前缀和形式
|
||||
group_type=0, # 0代表m轴分组
|
||||
group_list=expert_token_nums,
|
||||
output_dtype=torch.int32)[0]
|
||||
y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=y1_int32,
|
||||
weight_scale=self.gmm1_weight_scale.to(torch.float32),
|
||||
activation_scale=dynamic_scales,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=expert_token_nums,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
output_dtype=torch.int32 if self.w8a8_dynamic else output_dtype)[0]
|
||||
y1_scale = None
|
||||
if self.w8a8_dynamic:
|
||||
y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=y1_int32,
|
||||
weight_scale=self.gmm1_weight_scale.to(torch.float32),
|
||||
activation_scale=dynamic_scales,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=expert_token_nums,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
else:
|
||||
y1 = torch_npu.npu_swiglu(y1_int32)
|
||||
y2 = torch_npu.npu_grouped_matmul(x=[y1],
|
||||
weight=[self.gmm2_weight],
|
||||
scale=[self.gmm2_weight_scale],
|
||||
per_token_scale=[y1_scale],
|
||||
scale=[self.gmm2_weight_scale] if self.w8a8_dynamic else None,
|
||||
per_token_scale=[y1_scale] if self.w8a8_dynamic else None,
|
||||
split_item=2,
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
@@ -240,15 +265,19 @@ class FusionOp(DecodeMoeOps):
|
||||
moe_expert_num,
|
||||
global_rank_id,
|
||||
shared_expert_rank_num=0,
|
||||
dynamic_eplb=False):
|
||||
dynamic_eplb=False,
|
||||
w8a8_dynamic=True,
|
||||
is_nz=True):
|
||||
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
|
||||
gmm2_weight_scale, ep_hcomm_info, batch_size,
|
||||
token_hidden_size, moe_intermediate_size,
|
||||
ep_world_size, moe_expert_num, global_rank_id,
|
||||
shared_expert_rank_num, dynamic_eplb)
|
||||
shared_expert_rank_num, dynamic_eplb, w8a8_dynamic,
|
||||
is_nz)
|
||||
|
||||
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
|
||||
x_active_mask):
|
||||
smooth_scales = torch.zeros(128 * 1024 * 1024).npu()
|
||||
output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
|
||||
x=x,
|
||||
expert_ids=expert_ids,
|
||||
@@ -271,29 +300,35 @@ class FusionOp(DecodeMoeOps):
|
||||
|
||||
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
|
||||
gmm2_weight, gmm2_weight_scale):
|
||||
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
|
||||
torch_npu.Format.FRACTAL_NZ)
|
||||
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
|
||||
torch_npu.Format.FRACTAL_NZ)
|
||||
if self.is_nz:
|
||||
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
|
||||
torch_npu.Format.FRACTAL_NZ)
|
||||
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
|
||||
torch_npu.Format.FRACTAL_NZ)
|
||||
|
||||
if self.dynamic_eplb:
|
||||
self.gmm1_weight = [
|
||||
weight.clone() for weight in gmm1_weight.unbind(dim=0)
|
||||
]
|
||||
self.gmm1_weight_scale_fp32 = [
|
||||
weight.clone() for weight in gmm1_weight_scale.unbind(dim=0)
|
||||
]
|
||||
self.gmm2_weight = [
|
||||
weight.clone() for weight in gmm2_weight.unbind(dim=0)
|
||||
]
|
||||
self.gmm2_weight_scale_fp32 = [
|
||||
weight.clone() for weight in gmm2_weight_scale.unbind(dim=0)
|
||||
]
|
||||
if self.w8a8_dynamic:
|
||||
self.gmm1_weight_scale_fp32 = [
|
||||
weight.clone() for weight in gmm1_weight_scale.unbind(dim=0)
|
||||
]
|
||||
self.gmm2_weight_scale_fp32 = [
|
||||
weight.clone() for weight in gmm2_weight_scale.unbind(dim=0)
|
||||
]
|
||||
else:
|
||||
self.gmm1_weight = [gmm1_weight.clone()]
|
||||
self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()]
|
||||
self.gmm2_weight = [gmm2_weight.clone()]
|
||||
self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()]
|
||||
if self.w8a8_dynamic:
|
||||
self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()]
|
||||
self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()]
|
||||
else:
|
||||
self.gmm1_weight_scale_fp32 = [torch.ones(1).npu().to(gmm1_weight.dtype)]
|
||||
self.gmm2_weight_scale_fp32 = [torch.ones(1).npu().to(gmm2_weight.dtype)]
|
||||
|
||||
|
||||
def generate_datas(batch_size,
|
||||
@@ -306,7 +341,8 @@ def generate_datas(batch_size,
|
||||
top_k=8,
|
||||
test_bfloat16=True,
|
||||
enable_dynamic_bs=False,
|
||||
with_mc2_mask=False):
|
||||
with_mc2_mask=False,
|
||||
w8a8_dynamic=True):
|
||||
is_shared_expert = global_rank_id < shared_expert_rank_num
|
||||
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
|
||||
shared_expert_rank_num)
|
||||
@@ -318,41 +354,59 @@ def generate_datas(batch_size,
|
||||
gmm1_output_dim = moe_intermediate_size * 2
|
||||
gmm2_input_dim = moe_intermediate_size
|
||||
gmm2_output_dim = token_hidden_size
|
||||
x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5
|
||||
x = torch.rand([actual_bs, token_hidden_size]) * 0.5 - 0.5
|
||||
expert_ids = torch.arange(
|
||||
global_rank_id * batch_size * top_k,
|
||||
global_rank_id * batch_size * top_k + actual_bs * top_k).to(
|
||||
torch.int32).view(actual_bs, top_k)
|
||||
expert_ids = expert_ids % moe_expert_num
|
||||
if is_shared_expert:
|
||||
gmm1_weight = torch.ones([
|
||||
local_expert_num, gmm1_input_dim, gmm1_output_dim
|
||||
]).to(torch.int8) * 4
|
||||
gmm2_weight = torch.ones([
|
||||
local_expert_num, gmm2_input_dim, gmm2_output_dim
|
||||
]).to(torch.int8) * 4
|
||||
gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
|
||||
gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
|
||||
gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim
|
||||
]) * 0.0015
|
||||
gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim
|
||||
]) * 0.0015
|
||||
gmm1_weight_scale = None
|
||||
gmm2_weight_scale = None
|
||||
if w8a8_dynamic:
|
||||
if is_shared_expert:
|
||||
gmm1_weight = torch.ones([
|
||||
local_expert_num, gmm1_input_dim, gmm1_output_dim
|
||||
]).to(torch.int8) * 4
|
||||
gmm2_weight = torch.ones([
|
||||
local_expert_num, gmm2_input_dim, gmm2_output_dim
|
||||
]).to(torch.int8) * 4
|
||||
gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
|
||||
gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
|
||||
gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim
|
||||
]) * 0.0015
|
||||
gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim
|
||||
]) * 0.0015
|
||||
else:
|
||||
gmm1_weight = torch.randint(
|
||||
-16, 16,
|
||||
[local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8)
|
||||
gmm2_weight = torch.randint(
|
||||
-16, 16,
|
||||
[local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8)
|
||||
gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim
|
||||
]) * 0.003 + 0.0015
|
||||
gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim
|
||||
]) * 0.003 + 0.0015
|
||||
else:
|
||||
gmm1_weight = torch.randint(
|
||||
-16, 16,
|
||||
[local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8)
|
||||
gmm2_weight = torch.randint(
|
||||
-16, 16,
|
||||
[local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8)
|
||||
gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim
|
||||
]) * 0.003 + 0.0015
|
||||
gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim
|
||||
]) * 0.003 + 0.0015
|
||||
if is_shared_expert:
|
||||
gmm1_weight = torch.ones([
|
||||
local_expert_num, gmm1_input_dim, gmm1_output_dim
|
||||
]).to(torch.bfloat16 if test_bfloat16 else torch.float16) * 0.5
|
||||
gmm2_weight = torch.ones([
|
||||
local_expert_num, gmm2_input_dim, gmm2_output_dim
|
||||
]).to(torch.bfloat16 if test_bfloat16 else torch.float16) * 0.5
|
||||
else:
|
||||
gmm1_weight = torch.rand([local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.bfloat16 if test_bfloat16 else torch.float16) * 0.25
|
||||
gmm2_weight = torch.rand([local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.bfloat16 if test_bfloat16 else torch.float16) * 0.25
|
||||
gmm1_weight[:, ::2, :] = gmm1_weight[:, ::2, :] * -1
|
||||
gmm2_weight[:, ::2, :] = gmm2_weight[:, ::2, :] * -1
|
||||
expert_scales = torch.rand(actual_bs, top_k)
|
||||
if test_bfloat16:
|
||||
x = x.bfloat16()
|
||||
gmm1_weight_scale = gmm1_weight_scale.bfloat16()
|
||||
gmm2_weight_scale = gmm2_weight_scale.bfloat16()
|
||||
if w8a8_dynamic:
|
||||
assert (gmm1_weight_scale is not None and gmm2_weight_scale is not None), "gmm1_weight_scale and gmm2_weight_scale must be provided for w8a8_dynamic"
|
||||
gmm1_weight_scale = gmm1_weight_scale.bfloat16()
|
||||
gmm2_weight_scale = gmm2_weight_scale.bfloat16()
|
||||
else:
|
||||
x = x.half()
|
||||
smooth_sales = None
|
||||
@@ -380,7 +434,9 @@ def run_once(local_rank_id,
|
||||
enable_dynamic_bs=False,
|
||||
test_graph=False,
|
||||
with_mc2_mask=False,
|
||||
dynamic_eplb=False):
|
||||
dynamic_eplb=False,
|
||||
w8a8_dynamic=True,
|
||||
is_nz=True):
|
||||
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
|
||||
) if output_to_file(local_rank_id) else None
|
||||
global_rank_id = local_rank_id # 单机
|
||||
@@ -407,7 +463,7 @@ def run_once(local_rank_id,
|
||||
ep_world_size, moe_expert_num, global_rank_id,
|
||||
shared_expert_rank_num)
|
||||
input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
|
||||
*parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
|
||||
*parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask, w8a8_dynamic)
|
||||
input_datas = [
|
||||
data.npu() if data is not None else None for data in input_datas
|
||||
]
|
||||
@@ -415,27 +471,52 @@ def run_once(local_rank_id,
|
||||
data.npu() if data is not None else None for data in weight_datas
|
||||
]
|
||||
small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter,
|
||||
dynamic_eplb).npu() # type: ignore
|
||||
dynamic_eplb, w8a8_dynamic, is_nz).npu() # type: ignore
|
||||
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter,
|
||||
dynamic_eplb).npu() # type: ignore
|
||||
dynamic_eplb, w8a8_dynamic, is_nz).npu() # type: ignore
|
||||
if test_graph:
|
||||
config = torchair.CompilerConfig()
|
||||
config.mode = "reduce-overhead"
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
fused_ops = torch.compile(fused_ops, backend=npu_backend)
|
||||
|
||||
# test performance
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(100):
|
||||
small_op_token_output, small_op_count_output = small_ops(*input_datas)
|
||||
torch_npu.npu.synchronize(device_id)
|
||||
end_time = time.perf_counter()
|
||||
elapsed_time = end_time - start_time
|
||||
elapsed_time_us = elapsed_time * 1000000
|
||||
print(f"rank-{global_rank_id} small {elapsed_time_us} us")
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(100):
|
||||
fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
|
||||
torch_npu.npu.synchronize(device_id)
|
||||
end_time = time.perf_counter()
|
||||
elapsed_time = end_time - start_time
|
||||
elapsed_time_us = elapsed_time * 1000000
|
||||
print(f"rank-{global_rank_id} fused {elapsed_time_us} us")
|
||||
small_op_token_output, small_op_count_output = small_ops(*input_datas)
|
||||
torch_npu.npu.synchronize(device_id)
|
||||
print(f"rank-{global_rank_id} Small op End")
|
||||
fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
|
||||
torch_npu.npu.synchronize(device_id)
|
||||
print(f"rank-{global_rank_id} Fused op End")
|
||||
dist.destroy_process_group()
|
||||
if log_file is not None:
|
||||
log_file.close()
|
||||
|
||||
torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
|
||||
fused_op_token_output[0:valid_token_num].cpu(),
|
||||
atol=2.0,
|
||||
rtol=0.02)
|
||||
torch.testing.assert_close(small_op_count_output.cpu(),
|
||||
fused_op_count_output.cpu())
|
||||
try:
|
||||
torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
|
||||
fused_op_token_output[0:valid_token_num].cpu(),
|
||||
atol=2.0,
|
||||
rtol=0.02)
|
||||
torch.testing.assert_close(small_op_count_output.cpu(),
|
||||
fused_op_count_output.cpu())
|
||||
except Exception as e:
|
||||
print(f"rank-{global_rank_id} Assert close Failed: {e}")
|
||||
else:
|
||||
print(f"rank-{global_rank_id} Assert close Pass")
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
@@ -444,9 +525,16 @@ def run_once(local_rank_id,
|
||||
@torch.inference_mode()
|
||||
def test_dispatch_gmm_combine_decode_base():
|
||||
custom_kwargs = BASE_KWARGS
|
||||
custom_kwargs["batch_size"] = 32
|
||||
custom_kwargs["ep_world_size"] = 8
|
||||
custom_kwargs["moe_expert_num"] = 32
|
||||
custom_kwargs["w8a8_dynamic"] = False
|
||||
custom_kwargs["is_nz"] = True
|
||||
ep_world_size = custom_kwargs["ep_world_size"]
|
||||
custom_args = tuple(custom_kwargs.values())
|
||||
print(f"{custom_kwargs=}")
|
||||
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
|
||||
print(f"{custom_kwargs=}")
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -465,3 +553,6 @@ def test_dispatch_gmm_combine_decode_dynamic_eplb():
|
||||
ep_world_size = custom_kwargs["ep_world_size"]
|
||||
custom_args = tuple(custom_kwargs.values())
|
||||
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dispatch_gmm_combine_decode_base()
|
||||
|
||||
Reference in New Issue
Block a user