[Bugfix] fix logging and d2h bug for flash comm1 (#3505)
### What this PR does / why we need it? Fix 3 bugs in flash comm1 of Allgather EP(https://github.com/vllm-project/vllm-ascend/pull/3334): 1. call `enable_sp()` with argument `vllm_config` trigger a lot of warning log, this PR caches its return value. 2. `num_tokens_after_padding` should be cpu tensor as it will used as `num_tokens_across_dp_cpu` in `DPMetadata`. It will causes may d2h copy when running model. 3. In PD, model runner will execute `kv_connector_no_forward`,where `num_tokens` is None - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -815,7 +815,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Create a tensor for num_tokens_after_padding
|
||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
|
||||
self.dp_size,
|
||||
device="npu",
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
|
||||
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
|
||||
|
||||
Reference in New Issue
Block a user