Support pipeline parallel in V1 Engine (#1700)
### What this PR does / why we need it?
This patch supports pipeline parallel in V1 Engine
### Does this PR introduce _any_ user-facing change?
Yes, users can run PP in V1
### How was this patch tested?
Manully test
- vLLM version: v0.9.2
- vLLM main:
31d5c1797f
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
43
tests/e2e/multicard/test_pipeline_parallel.py
Normal file
43
tests/e2e/multicard/test_pipeline_parallel.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
import pytest
|
||||
|
||||
from tests.conftest import VllmRunner
|
||||
|
||||
MODELS = [
|
||||
"Qwen/Qwen3-0.6B",
|
||||
]
|
||||
|
||||
TENSOR_PARALLELS = [2]
|
||||
PIPELINE_PARALLELS = [2]
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
|
||||
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
|
||||
def test_models(model: str, tp_size: int, pp_size: int) -> None:
|
||||
with VllmRunner(model,
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.7) as vllm_model:
|
||||
vllm_model.generate_greedy(prompts, 64)
|
||||
@@ -37,7 +37,8 @@ from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_dp_group, get_pp_group
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import logger
|
||||
@@ -146,6 +147,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
@@ -921,8 +923,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
||||
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
||||
num_scheduled_tokens)
|
||||
sample_indices = cu_num_tokens - 1
|
||||
sample_indices = torch.from_numpy(sample_indices).to(self.device,
|
||||
logits_indices = cu_num_tokens - 1
|
||||
logits_indices = torch.from_numpy(logits_indices).to(self.device,
|
||||
non_blocking=True)
|
||||
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
||||
|
||||
@@ -1153,14 +1155,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
sample_indices = spec_decode_metadata.logits_indices
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
|
||||
aux_hidden_states = None
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
total_num_scheduled_tokens, sample_indices, aux_hidden_states,
|
||||
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
|
||||
num_scheduled_tokens)
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
@@ -1397,16 +1399,42 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Return empty ModelRunnerOuptut if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
(attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
num_scheduled_tokens, sample_indices, aux_hidden_states,
|
||||
num_scheduled_tokens, logits_indices, aux_hidden_states,
|
||||
num_scheduled_tokens_np) = (self._process_reqs(
|
||||
scheduler_output, intermediate_tensors))
|
||||
|
||||
with ProfileExecuteDuration().capture_async("post process"):
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np)
|
||||
logits = self.model.compute_logits(hidden_states[sample_indices],
|
||||
None)
|
||||
# Broadcast PP output for external_launcher (torchrun)
|
||||
# to make sure we are synced across pp ranks
|
||||
# TODO: Support overlapping mirco-batches
|
||||
# https://github.com/vllm-project/vllm/issues/18019
|
||||
broadcast_pp_output = \
|
||||
self.parallel_config.distributed_executor_backend \
|
||||
== "external_launcher" and len(get_pp_group().ranks) > 0
|
||||
if not get_pp_group().is_last_rank:
|
||||
# For mid-pipeline stages, return the hidden states.
|
||||
if not broadcast_pp_output:
|
||||
return hidden_states
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(
|
||||
hidden_states.tensors, all_gather_group=get_tp_group())
|
||||
logits = None
|
||||
else:
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np)
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
model_output_broadcast_data = {
|
||||
"logits": logits.contiguous(),
|
||||
} if logits is not None else {}
|
||||
model_output_broadcast_data = get_pp_group(
|
||||
).broadcast_tensor_dict(model_output_broadcast_data,
|
||||
src=len(get_pp_group().ranks) - 1)
|
||||
assert model_output_broadcast_data is not None
|
||||
logits = model_output_broadcast_data["logits"]
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
logits = self.apply_grammar_bitmask(scheduler_output, logits)
|
||||
@@ -1423,6 +1451,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
|
||||
assert logits is not None
|
||||
bonus_logits = logits[
|
||||
spec_decode_metadata.bonus_logits_indices]
|
||||
sampler_output = self.sampler(
|
||||
|
||||
@@ -28,8 +28,10 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
@@ -206,7 +208,22 @@ class NPUWorker(WorkerBase):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
output = self.model_runner.execute_model(scheduler_output)
|
||||
intermediate_tensors = None
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group()))
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output,
|
||||
intermediate_tensors)
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
if parallel_config.distributed_executor_backend != "external_launcher" \
|
||||
and not get_pp_group().is_last_rank:
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(output.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
return None
|
||||
assert isinstance(output, ModelRunnerOutput)
|
||||
return output if self.is_driver_worker else None
|
||||
|
||||
def load_model(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user