[MLA][Graph] Improve assertion on Graph mode with MLA (#933)

### What this PR does / why we need it?
Improve assertion on Graph mode with MLA.

When running deepseek with graph mode, the fused MLA op only support
`numHeads / numKvHeads ∈ {32, 64, 128}`, thus we improve the assertion
info here to avoid users confused with this.

### Does this PR introduce _any_ user-facing change?
Adjusting tp size is required when running deepseek-v3/r1 with graph
mode. deepseek-v2-lite is not supported in graph mode.

### How was this patch tested?
Test locally as the CI machine could not run V3 due to the HBM limits.

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-06-10 22:26:53 +08:00
committed by GitHub
parent 291c216898
commit 8dd686dfa2
4 changed files with 33 additions and 1 deletions

View File

@@ -119,7 +119,7 @@ class MultiStepWorker(NPUWorker):
# execute_model_req
assert execute_model_req.last_sampled_token_ids is not None
model_input.last_sampled_token_ids = (
execute_model_req.last_sampled_token_ids.cuda())
execute_model_req.last_sampled_token_ids.npu())
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)