[Bugfix] fix logging and d2h bug for flash comm1 (#3505)
### What this PR does / why we need it? Fix 3 bugs in flash comm1 of Allgather EP(https://github.com/vllm-project/vllm-ascend/pull/3334): 1. Calling `enable_sp()` with the argument `vllm_config` triggers a lot of warning logs; this PR caches its return value. 2. `num_tokens_after_padding` should be a CPU tensor, as it will be used as `num_tokens_across_dp_cpu` in `DPMetadata`; otherwise it causes many d2h copies when running the model. 3. In PD, the model runner will execute `kv_connector_no_forward`, where `num_tokens` is None - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -114,7 +114,7 @@ def set_ascend_forward_context(
|
|||||||
# the performance may degrade due to the switching of communication methods.
|
# the performance may degrade due to the switching of communication methods.
|
||||||
if is_moe_model(vllm_config):
|
if is_moe_model(vllm_config):
|
||||||
sp_enabled = enable_sp(vllm_config) and \
|
sp_enabled = enable_sp(vllm_config) and \
|
||||||
tp_world_size > 1
|
tp_world_size > 1 and num_tokens is not None
|
||||||
else:
|
else:
|
||||||
sp_enabled = enable_sp(vllm_config) and \
|
sp_enabled = enable_sp(vllm_config) and \
|
||||||
tp_world_size > 1 and \
|
tp_world_size > 1 and \
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ _ASCEND_CUSTOMOP_IS_REIGISTERED = False
|
|||||||
_DEFAULT_BUFFER_SIZE = 200
|
_DEFAULT_BUFFER_SIZE = 200
|
||||||
_MIN_DP_BUFFER_SIZE = 50
|
_MIN_DP_BUFFER_SIZE = 50
|
||||||
_IS_MOE_MODEL = None
|
_IS_MOE_MODEL = None
|
||||||
|
_ENABLE_SP = None
|
||||||
|
|
||||||
|
|
||||||
def is_310p():
|
def is_310p():
|
||||||
@@ -606,15 +607,20 @@ def dense_optim_enable() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def enable_sp(vllm_config=None) -> bool:
|
def enable_sp(vllm_config=None) -> bool:
|
||||||
if vllm_config is None:
|
global _ENABLE_SP
|
||||||
from vllm.config import get_current_vllm_config
|
if _ENABLE_SP is None:
|
||||||
vllm_config = get_current_vllm_config()
|
if vllm_config is None:
|
||||||
return (
|
from vllm.config import get_current_vllm_config
|
||||||
vllm_config.compilation_config.pass_config.enable_sequence_parallelism
|
vllm_config = get_current_vllm_config()
|
||||||
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
_ENABLE_SP = (
|
||||||
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
vllm_config.compilation_config.pass_config.
|
||||||
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
enable_sequence_parallelism
|
||||||
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
|
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||||
|
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||||
|
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
||||||
|
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
|
||||||
|
|
||||||
|
return _ENABLE_SP
|
||||||
|
|
||||||
|
|
||||||
# TODO remove it after vllm has this func
|
# TODO remove it after vllm has this func
|
||||||
|
|||||||
@@ -815,7 +815,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Create a tensor for num_tokens_after_padding
|
# Create a tensor for num_tokens_after_padding
|
||||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
|
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
|
||||||
self.dp_size,
|
self.dp_size,
|
||||||
device="npu",
|
device="cpu",
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
|
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
|
||||||
|
|||||||
Reference in New Issue
Block a user