[Feat]support sequence parallelism by pass for VL models (#5632)

This commit is contained in:
realliujiaxu
2026-02-27 08:27:41 +08:00
committed by GitHub
parent ed175d6d92
commit 5def28dcd3
22 changed files with 460 additions and 101 deletions

View File

@@ -112,6 +112,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (
enable_flash_comm_v1,
enable_sp,
is_drafter_moe_model,
is_moe_model,
@@ -1677,7 +1678,7 @@ class NPUModelRunner(GPUModelRunner):
self.speculative_config,
positions.shape[0],
)
if get_forward_context().sp_enabled and not isinstance(hidden_states, IntermediateTensors):
if get_forward_context().flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors):
hidden_states = self._all_gather_hidden_states_and_aux(hidden_states)
return hidden_states
@@ -1685,7 +1686,7 @@ class NPUModelRunner(GPUModelRunner):
# Pad tokens to a multiple of tensor_parallel_size when
# collective fusion for SP is enabled
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if enable_sp():
if enable_sp(self.vllm_config):
return round_up(num_scheduled_tokens, tp_size)
return num_scheduled_tokens
@@ -2223,7 +2224,7 @@ class NPUModelRunner(GPUModelRunner):
# tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
# to incorrect memory estimation and potentially causing OOM.
intermediate_tokens = num_tokens_padded
if enable_sp():
if enable_flash_comm_v1():
tp_size = get_tensor_model_parallel_world_size()
intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
if self.intermediate_tensors is None:

View File

@@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.utils import (
AscendDeviceType,
check_ascend_device_type,
enable_sp,
enable_flash_comm_v1,
get_ascend_device_type,
register_ascend_customop,
)
@@ -376,7 +376,7 @@ class NPUWorker(WorkerBase):
if forward_pass and not get_pp_group().is_first_rank:
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
# it will conflict with the all-gather operation in flashcomm1.
if enable_sp():
if enable_flash_comm_v1():
all_gather_group = None
else:
all_gather_group = get_tp_group()
@@ -393,7 +393,7 @@ class NPUWorker(WorkerBase):
assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
# it will conflict with the all-gather operation in flashcomm1.
if enable_sp():
if enable_flash_comm_v1():
all_gather_group = None
else:
all_gather_group = get_tp_group()