[Feat]support sequence parallelism by pass for VL models (#5632)

This commit is contained in:
realliujiaxu
2026-02-27 08:27:41 +08:00
committed by GitHub
parent ed175d6d92
commit 5def28dcd3
22 changed files with 460 additions and 101 deletions

View File

@@ -1171,7 +1171,7 @@ class EagleProposer(VllmEagleProposer):
positions = positions.squeeze(-1)
else:
forward_context = get_forward_context()
if forward_context.sp_enabled:
if forward_context.flash_comm_v1_enabled:
hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
return hidden_states, positions
@@ -1191,7 +1191,7 @@ class EagleProposer(VllmEagleProposer):
hidden_states = last_hidden_states
else:
forward_context = get_forward_context()
if forward_context.sp_enabled:
if forward_context.flash_comm_v1_enabled:
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True
)