[Feat] Flash comm allgher ep (#3334)
Support flash comm v1(Sequence Parallelism) for Allgather EP. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com> Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
@@ -55,6 +55,7 @@ _PREFETCH_STREAM = None
|
||||
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
|
||||
_DEFAULT_BUFFER_SIZE = 200
|
||||
_MIN_DP_BUFFER_SIZE = 50
|
||||
_IS_MOE_MODEL = None
|
||||
|
||||
|
||||
def is_310p():
|
||||
@@ -609,12 +610,24 @@ def enable_sp(vllm_config=None) -> bool:
|
||||
vllm_config = get_current_vllm_config()
|
||||
return (
|
||||
vllm_config.compilation_config.pass_config.enable_sequence_parallelism
|
||||
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)
|
||||
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
||||
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
|
||||
|
||||
|
||||
# TODO remove it after vllm has this func
|
||||
def shared_expert_dp_enabled() -> bool:
|
||||
return get_ascend_config().enable_shared_expert_dp or enable_sp()
|
||||
|
||||
|
||||
def is_moe_model(vllm_config: VllmConfig):
|
||||
config = vllm_config.model_config.hf_config
|
||||
return any('experts' in key.lower() for key in config.to_dict())
|
||||
global _IS_MOE_MODEL
|
||||
if _IS_MOE_MODEL is None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
_IS_MOE_MODEL = any('experts' in key.lower()
|
||||
for key in config.to_dict())
|
||||
return _IS_MOE_MODEL
|
||||
|
||||
|
||||
def weak_ref_tensor(tensor: Any) -> Any:
|
||||
|
||||
Reference in New Issue
Block a user