[CustomOp] support TensorList for dispatchFFNCombine (#5665)
### What this PR does / why we need it?
Add TensorList support to dispatch_ffn_combine so that the operator can adapt to EPLB.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Single Operator Testing
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: lhchg <lhao_cheng@163.com>
Co-authored-by: lihaocheng <lihaosheng1@h-partners.com>
This commit is contained in:
@@ -87,13 +87,13 @@ class TestDisptachFFNCombine:
|
||||
hcomm_info = hcomm_info_dist["default_pg_info"]
|
||||
self.hcomm_info = hcomm_info
|
||||
|
||||
def run_npu_out(self) -> bool:
|
||||
def run_tensor_list(self) -> bool:
|
||||
torch_npu.npu.set_device(self.rank)
|
||||
m = 2 # token-num 32
|
||||
k = 4 # hidden_size 7168
|
||||
n = 4 # mid-hidden-size 4096
|
||||
topk = 2
|
||||
e = 2 # expert-num-per-rank 16
|
||||
m = 64
|
||||
k = 1024
|
||||
n = 1024
|
||||
topk = 8
|
||||
e = 8
|
||||
k2 = n // 2
|
||||
n2 = k
|
||||
|
||||
@@ -112,15 +112,79 @@ class TestDisptachFFNCombine:
|
||||
scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
|
||||
scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
|
||||
probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
|
||||
|
||||
weight1_nz_npu = []
|
||||
weight2_nz_npu = []
|
||||
scale1_npu = []
|
||||
scale2_npu = []
|
||||
for i in range(e):
|
||||
weight1_nz_npu.append(
|
||||
torch_npu.npu_format_cast(weight1[i].npu(), 29))
|
||||
scale1_npu.append(scale1[i].npu())
|
||||
weight2_nz_npu.append(
|
||||
torch_npu.npu_format_cast(weight2[i].npu(), 29))
|
||||
scale2_npu.append(scale2[i].npu())
|
||||
|
||||
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
||||
|
||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||
x=x,
|
||||
weight1=weight1,
|
||||
weight2=weight2,
|
||||
weight1=weight1_nz_npu,
|
||||
weight2=weight2_nz_npu,
|
||||
expert_idx=expert_idx,
|
||||
scale1=scale1,
|
||||
scale2=scale2,
|
||||
scale1=scale1_npu,
|
||||
scale2=scale2_npu,
|
||||
probs=probs,
|
||||
group=self.hcomm_info,
|
||||
max_output_size=512,
|
||||
out=out,
|
||||
)
|
||||
return True
|
||||
|
||||
def run_normal(self) -> bool:
|
||||
torch_npu.npu.set_device(self.rank)
|
||||
m = 64
|
||||
k = 1024
|
||||
n = 1024
|
||||
topk = 8
|
||||
e = 8
|
||||
k2 = n // 2
|
||||
n2 = k
|
||||
|
||||
torch_npu.npu.config.allow_internal_format = True
|
||||
x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
||||
weight1 = self.generate_random_tensor((e, k, n),
|
||||
dtype=torch.int8).npu()
|
||||
weight1 = torch_npu.npu_format_cast(weight1, 29)
|
||||
weight2 = self.generate_random_tensor((e, k2, n2),
|
||||
dtype=torch.int8).npu()
|
||||
weight2 = torch_npu.npu_format_cast(weight2, 29)
|
||||
|
||||
expert_idx = torch.randint(0,
|
||||
self.world_size * e, (m, topk),
|
||||
dtype=torch.int32).npu()
|
||||
scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
|
||||
scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
|
||||
probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
|
||||
|
||||
weight1_nz_npu = []
|
||||
weight2_nz_npu = []
|
||||
scale1_npu = []
|
||||
scale2_npu = []
|
||||
weight1_nz_npu.append(torch_npu.npu_format_cast(weight1.npu(), 29))
|
||||
scale1_npu.append(scale1.npu())
|
||||
weight2_nz_npu.append(torch_npu.npu_format_cast(weight2.npu(), 29))
|
||||
scale2_npu.append(scale2.npu())
|
||||
|
||||
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
||||
|
||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||
x=x,
|
||||
weight1=weight1_nz_npu,
|
||||
weight2=weight2_nz_npu,
|
||||
expert_idx=expert_idx,
|
||||
scale1=scale1_npu,
|
||||
scale2=scale2_npu,
|
||||
probs=probs,
|
||||
group=self.hcomm_info,
|
||||
max_output_size=512,
|
||||
@@ -142,8 +206,10 @@ class TestDisptachFFNCombine:
|
||||
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
|
||||
op = TestDisptachFFNCombine(rank, world_size, port)
|
||||
op.generate_hcom()
|
||||
out = op.run_npu_out()
|
||||
q.put(out)
|
||||
out1 = op.run_tensor_list()
|
||||
q.put(out1)
|
||||
out2 = op.run_normal()
|
||||
q.put(out2)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
Reference in New Issue
Block a user