[Bugfix] fix pcp + eplb error (#5561)
### What this PR does / why we need it?
Fix the bug in the PCP overlay feature
1、Fix the bug related to PCP and EPLB overlap by including PCP size in
the word_size calculation.
2、In the PCP pooling scenario, a prompt has been added for setting the
cp_kv_cache_interleave_size.
- vLLM version: v0.13.0
- vLLM main:
7157596103
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -24,7 +24,6 @@ import pytest
|
|||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
from tests.e2e.model_utils import check_outputs_equal
|
from tests.e2e.model_utils import check_outputs_equal
|
||||||
from vllm_ascend.utils import vllm_version_is
|
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
"Qwen/Qwen3-8B",
|
"Qwen/Qwen3-8B",
|
||||||
@@ -32,8 +31,6 @@ MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("max_tokens", [10])
|
@pytest.mark.parametrize("max_tokens", [10])
|
||||||
def test_models_long_sequence_output_between_tp_and_cp(
|
def test_models_long_sequence_output_between_tp_and_cp(
|
||||||
|
|||||||
@@ -23,17 +23,13 @@ Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
from vllm_ascend.utils import vllm_version_is
|
|
||||||
|
|
||||||
os.environ["HCCL_BUFFSIZE"] = "768"
|
os.environ["HCCL_BUFFSIZE"] = "768"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_models_pcp_dcp_basic():
|
def test_models_pcp_dcp_basic():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -67,8 +63,6 @@ def test_models_pcp_dcp_basic():
|
|||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_models_pcp_dcp_full_graph():
|
def test_models_pcp_dcp_full_graph():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -106,8 +100,6 @@ def test_models_pcp_dcp_full_graph():
|
|||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_models_pcp_dcp_piece_wise():
|
def test_models_pcp_dcp_piece_wise():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -139,8 +131,6 @@ def test_models_pcp_dcp_piece_wise():
|
|||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_pcp_basic():
|
def test_pcp_basic():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -160,8 +150,6 @@ def test_pcp_basic():
|
|||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_pcp_full_graph():
|
def test_pcp_full_graph():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -185,8 +173,6 @@ def test_pcp_full_graph():
|
|||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_pcp_piece_wise():
|
def test_pcp_piece_wise():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -206,8 +192,6 @@ def test_pcp_piece_wise():
|
|||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_dcp_basic():
|
def test_dcp_basic():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -227,8 +211,6 @@ def test_dcp_basic():
|
|||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_dcp_full_graph():
|
def test_dcp_full_graph():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -252,8 +234,6 @@ def test_dcp_full_graph():
|
|||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_dcp_piece_wise():
|
def test_dcp_piece_wise():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
|
|||||||
@@ -19,16 +19,11 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
from vllm_ascend.utils import vllm_version_is
|
|
||||||
|
|
||||||
os.environ["HCCL_BUFFSIZE"] = "512"
|
os.environ["HCCL_BUFFSIZE"] = "512"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_pcp_dcp_mtp1_eager():
|
def test_pcp_dcp_mtp1_eager():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -53,8 +48,6 @@ def test_pcp_dcp_mtp1_eager():
|
|||||||
runner.generate_greedy(prompts, 32)
|
runner.generate_greedy(prompts, 32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_pcp_dcp_mtp3_eager():
|
def test_pcp_dcp_mtp3_eager():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -79,8 +72,6 @@ def test_pcp_dcp_mtp3_eager():
|
|||||||
runner.generate_greedy(prompts, 32)
|
runner.generate_greedy(prompts, 32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_pcp_dcp_mtp3_piecewise_graph():
|
def test_pcp_dcp_mtp3_piecewise_graph():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -108,8 +99,6 @@ def test_pcp_dcp_mtp3_piecewise_graph():
|
|||||||
runner.generate_greedy(prompts, 32)
|
runner.generate_greedy(prompts, 32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_pcp_dcp_mtp3_full_graph():
|
def test_pcp_dcp_mtp3_full_graph():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
@@ -137,8 +126,6 @@ def test_pcp_dcp_mtp3_full_graph():
|
|||||||
runner.generate_greedy(prompts, 32)
|
runner.generate_greedy(prompts, 32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(vllm_version_is('0.12.0'),
|
|
||||||
reason="0.12.0 is not supported for context sequence.")
|
|
||||||
def test_dcp_mtp3_full_graph():
|
def test_dcp_mtp3_full_graph():
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ env_common:
|
|||||||
HCCL_BUFFSIZE: 1024
|
HCCL_BUFFSIZE: 1024
|
||||||
SERVER_PORT: 8080
|
SERVER_PORT: 8080
|
||||||
NUMEXPR_MAX_THREADS: 128
|
NUMEXPR_MAX_THREADS: 128
|
||||||
|
DYNAMIC_EPLB: true
|
||||||
disaggregated_prefill:
|
disaggregated_prefill:
|
||||||
enabled: true
|
enabled: true
|
||||||
prefiller_host_index: [0]
|
prefiller_host_index: [0]
|
||||||
@@ -52,6 +53,8 @@ deployment:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}'
|
}'
|
||||||
|
--additional-config
|
||||||
|
'{"dynamic_eplb":true}'
|
||||||
|
|
||||||
-
|
-
|
||||||
server_cmd: >
|
server_cmd: >
|
||||||
@@ -90,4 +93,6 @@ deployment:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}'
|
}'
|
||||||
|
--additional-config
|
||||||
|
'{"dynamic_eplb":true}'
|
||||||
benchmarks:
|
benchmarks:
|
||||||
|
|||||||
@@ -36,10 +36,12 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
self.local_world_size = self.parallel_config.local_world_size
|
self.local_world_size = self.parallel_config.local_world_size
|
||||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||||
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||||
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
|
pcp_parallel_size = self.parallel_config.prefill_context_parallel_size
|
||||||
|
assert self.world_size == tensor_parallel_size * pp_parallel_size * pcp_parallel_size, (
|
||||||
f"world_size ({self.world_size}) must be equal to the "
|
f"world_size ({self.world_size}) must be equal to the "
|
||||||
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
||||||
f"_parallel_size ({pp_parallel_size}). ")
|
f"_parallel_size ({pp_parallel_size}) x prefill_context"
|
||||||
|
f"_parallel_size ({pcp_parallel_size}). ")
|
||||||
|
|
||||||
# Set multiprocessing envs
|
# Set multiprocessing envs
|
||||||
set_multiprocessing_worker_envs()
|
set_multiprocessing_worker_envs()
|
||||||
|
|||||||
@@ -329,7 +329,7 @@ class NPUPlatform(Platform):
|
|||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
|
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
|
||||||
f"and block_size({cache_config.block_size}) "
|
f"and block_size({cache_config.block_size}) "
|
||||||
"needs to be equal if use cp or dcp > 1 in P/D disaggregate scenario."
|
"needs to be equal if use pcp or dcp > 1 in P/D disaggregate and kv pool scenario."
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_vl_model(vllm_config):
|
if is_vl_model(vllm_config):
|
||||||
|
|||||||
Reference in New Issue
Block a user