[Feat]support sequence parallelism by pass for VL models (#5632)
This commit is contained in:
@@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
from vllm_ascend.utils import (
|
||||
AscendDeviceType,
|
||||
check_ascend_device_type,
|
||||
enable_sp,
|
||||
enable_flash_comm_v1,
|
||||
get_ascend_device_type,
|
||||
register_ascend_customop,
|
||||
)
|
||||
@@ -376,7 +376,7 @@ class NPUWorker(WorkerBase):
|
||||
if forward_pass and not get_pp_group().is_first_rank:
|
||||
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
||||
# it will conflict with the all-gather operation in flashcomm1.
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
all_gather_group = None
|
||||
else:
|
||||
all_gather_group = get_tp_group()
|
||||
@@ -393,7 +393,7 @@ class NPUWorker(WorkerBase):
|
||||
assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
|
||||
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
||||
# it will conflict with the all-gather operation in flashcomm1.
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
all_gather_group = None
|
||||
else:
|
||||
all_gather_group = get_tp_group()
|
||||
|
||||
Reference in New Issue
Block a user