[Feat]support sequence parallelism by pass for VL models (#5632)

This commit is contained in:
realliujiaxu
2026-02-27 08:27:41 +08:00
committed by GitHub
parent ed175d6d92
commit 5def28dcd3
22 changed files with 460 additions and 101 deletions

View File

@@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.utils import (
AscendDeviceType,
check_ascend_device_type,
enable_sp,
enable_flash_comm_v1,
get_ascend_device_type,
register_ascend_customop,
)
@@ -376,7 +376,7 @@ class NPUWorker(WorkerBase):
if forward_pass and not get_pp_group().is_first_rank:
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
# it will conflict with the all-gather operation in flashcomm1.
if enable_sp():
if enable_flash_comm_v1():
all_gather_group = None
else:
all_gather_group = get_tp_group()
@@ -393,7 +393,7 @@ class NPUWorker(WorkerBase):
assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
# it will conflict with the all-gather operation in flashcomm1.
if enable_sp():
if enable_flash_comm_v1():
all_gather_group = None
else:
all_gather_group = get_tp_group()