[V1][PP] Support pp with ray backend in V1 (#1800)
### What this PR does / why we need it?
Support pipeline parallelism with the ray distributed executor backend in the V1 engine.
Fixes #1751
### Does this PR introduce _any_ user-facing change?
Users can now specify ray as the distributed executor backend when running inference with pipeline parallelism.
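For illustration, a minimal offline-inference sketch of the new option. The model name and sampling settings below are placeholders, not taken from this PR; the relevant part is selecting `distributed_executor_backend="ray"` together with `pipeline_parallel_size`:

```python
from vllm import LLM, SamplingParams

# Placeholder model; any pipeline-parallel-capable model applies.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    # With this PR, "ray" can be selected as the distributed executor
    # backend for pipeline parallelism on the V1 engine.
    distributed_executor_backend="ray",
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```

The equivalent serving flags, `--distributed-executor-backend ray` and `--pipeline-parallel-size`, should apply in the same way.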
### How was this patch tested?
CI passed with the newly added test.
- vLLM version: v0.9.2
- vLLM main: 32142b3c62
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
@@ -6,7 +6,6 @@ pytest >= 6.0
 pytest-asyncio
 pytest-mock
 lm-eval
-ray
 types-jsonschema
 xgrammar
 zmq
@@ -14,3 +13,5 @@ types-psutil
 pytest-cov
 regex
 sentence_transformers
+ray>=2.47.1
+protobuf==4.25.6
@@ -24,6 +24,7 @@ MODELS = [
 
 TENSOR_PARALLELS = [2]
 PIPELINE_PARALLELS = [2]
+DIST_EXECUTOR_BACKEND = ["mp", "ray"]
 
 prompts = [
     "Hello, my name is",
@@ -34,10 +35,13 @@ prompts = [
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
 @pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
-def test_models(model: str, tp_size: int, pp_size: int) -> None:
+@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
+def test_models(model: str, tp_size: int, pp_size: int,
+                distributed_executor_backend: str) -> None:
     with VllmRunner(model,
                     tensor_parallel_size=tp_size,
                     pipeline_parallel_size=pp_size,
+                    distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate_greedy(prompts, 64)
@@ -400,19 +400,13 @@ class TestAscendAttentionBackendImpl(TestBase):
         layer = self.layer_no_quant
         mock_vanilla_prefill.return_value = MagicMock()
 
-        def mock_tensor(data, device=None, **kwargs):
-            if device == "npu":
-                return metadata.attn_mask
-            return torch.tensor(data, **kwargs)
-
-        with patch("torch.tensor", side_effect=mock_tensor):
-            output = self.impl_192.forward(layer,
-                                           query,
-                                           key,
-                                           value,
-                                           kv_cache,
-                                           metadata,
-                                           trace_flag=False)
+        output = self.impl_192.forward(layer,
+                                       query,
+                                       key,
+                                       value,
+                                       kv_cache,
+                                       metadata,
+                                       trace_flag=False)
 
         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)
@@ -396,8 +396,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if self.head_size == 192:
             cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
             cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-            cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
-            cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
+            cu_seqlen_q = torch.tensor(cu_seqlen_q,
+                                       device=query.device)
+            cu_seqlen_k = torch.tensor(cu_seqlen_k,
+                                       device=query.device)
             cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
             cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
             max_seqlen_q = torch.max(attn_metadata.query_lens)
@@ -233,7 +233,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                         2048,
                                                         dtype=torch.bool),
-                                             diagonal=1).to("npu")
+                                             diagonal=1).to(self.device)
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
@@ -1120,6 +1120,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         input_ids = self.input_ids[:padded_batch_size]
         positions = self.positions[:padded_batch_size]
 
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            assert intermediate_tensors is not None
+            assert self.intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                self.intermediate_tensors[k][:num_input_tokens].copy_(
+                    v[:num_input_tokens], non_blocking=True)
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_input_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,