[CustomOp] support TensorList for dispatchFFNCombine (#5665)

### What this PR does / why we need it?
Support TensorList inputs for `dispatch_ffn_combine`, in order to support EPLB (expert parallelism load balancing) adjustment.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Tested via single-operator testing (see the updated `TestDisptachFFNCombine` cases).

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: lhchg <lhao_cheng@163.com>
Co-authored-by: lihaocheng <lihaosheng1@h-partners.com>
This commit is contained in:
lhchg
2026-01-09 15:56:29 +08:00
committed by GitHub
parent 3ce5a34468
commit dc99cfdc15
16 changed files with 293 additions and 105 deletions

View File

@@ -87,13 +87,13 @@ class TestDisptachFFNCombine:
hcomm_info = hcomm_info_dist["default_pg_info"]
self.hcomm_info = hcomm_info
def run_npu_out(self) -> bool:
def run_tensor_list(self) -> bool:
torch_npu.npu.set_device(self.rank)
m = 2 # token-num 32
k = 4 # hidden_size 7168
n = 4 # mid-hidden-size 4096
topk = 2
e = 2 # expert-num-per-rank 16
m = 64
k = 1024
n = 1024
topk = 8
e = 8
k2 = n // 2
n2 = k
@@ -112,15 +112,79 @@ class TestDisptachFFNCombine:
scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
weight1_nz_npu = []
weight2_nz_npu = []
scale1_npu = []
scale2_npu = []
for i in range(e):
weight1_nz_npu.append(
torch_npu.npu_format_cast(weight1[i].npu(), 29))
scale1_npu.append(scale1[i].npu())
weight2_nz_npu.append(
torch_npu.npu_format_cast(weight2[i].npu(), 29))
scale2_npu.append(scale2[i].npu())
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
torch.ops._C_ascend.dispatch_ffn_combine(
x=x,
weight1=weight1,
weight2=weight2,
weight1=weight1_nz_npu,
weight2=weight2_nz_npu,
expert_idx=expert_idx,
scale1=scale1,
scale2=scale2,
scale1=scale1_npu,
scale2=scale2_npu,
probs=probs,
group=self.hcomm_info,
max_output_size=512,
out=out,
)
return True
def run_normal(self) -> bool:
torch_npu.npu.set_device(self.rank)
m = 64
k = 1024
n = 1024
topk = 8
e = 8
k2 = n // 2
n2 = k
torch_npu.npu.config.allow_internal_format = True
x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
weight1 = self.generate_random_tensor((e, k, n),
dtype=torch.int8).npu()
weight1 = torch_npu.npu_format_cast(weight1, 29)
weight2 = self.generate_random_tensor((e, k2, n2),
dtype=torch.int8).npu()
weight2 = torch_npu.npu_format_cast(weight2, 29)
expert_idx = torch.randint(0,
self.world_size * e, (m, topk),
dtype=torch.int32).npu()
scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
weight1_nz_npu = []
weight2_nz_npu = []
scale1_npu = []
scale2_npu = []
weight1_nz_npu.append(torch_npu.npu_format_cast(weight1.npu(), 29))
scale1_npu.append(scale1.npu())
weight2_nz_npu.append(torch_npu.npu_format_cast(weight2.npu(), 29))
scale2_npu.append(scale2.npu())
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
torch.ops._C_ascend.dispatch_ffn_combine(
x=x,
weight1=weight1_nz_npu,
weight2=weight2_nz_npu,
expert_idx=expert_idx,
scale1=scale1_npu,
scale2=scale2_npu,
probs=probs,
group=self.hcomm_info,
max_output_size=512,
@@ -142,8 +206,10 @@ class TestDisptachFFNCombine:
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
op = TestDisptachFFNCombine(rank, world_size, port)
op.generate_hcom()
out = op.run_npu_out()
q.put(out)
out1 = op.run_tensor_list()
q.put(out1)
out2 = op.run_normal()
q.put(out2)
@torch.inference_mode()