[Feat] Make full graph mode compatible with MTP (#3276)
### What this PR does / why we need it? Enable Full Graph mode to run with MTP. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
|
||||
@@ -20,6 +21,7 @@ def mtp_correctness(
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
num_speculative_tokens: int,
|
||||
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
|
||||
):
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
@@ -38,6 +40,10 @@ def mtp_correctness(
|
||||
enforce_eager=False) as ref_llm:
|
||||
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||
|
||||
graph_mode_str = "PIECEWISE"
|
||||
if graph_mode == CUDAGraphMode.FULL:
|
||||
graph_mode_str = "FULL"
|
||||
|
||||
with VllmRunner(
|
||||
model_name,
|
||||
tensor_parallel_size=1,
|
||||
@@ -51,6 +57,8 @@ def mtp_correctness(
|
||||
},
|
||||
enforce_eager=False,
|
||||
max_model_len=2000,
|
||||
compilation_config=CompilationConfig(
|
||||
cudagraph_mode=graph_mode_str),
|
||||
additional_config={"ascend_scheduler_config": {
|
||||
"enabled": False
|
||||
}}) as spec_llm:
|
||||
@@ -74,15 +82,29 @@ def mtp_correctness(
|
||||
del spec_llm
|
||||
|
||||
|
||||
def test_mtp1_correctness_piecewise_graph(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check MTP output correctness with 1 speculative token under the
    default PIECEWISE cudagraph mode.

    NOTE(review): the diff rendering kept both the pre-rename header
    (``test_mtp1_correctness``) and the renamed one; only the renamed
    function survives in the final file.
    """
    # graph_mode defaults to CUDAGraphMode.PIECEWISE inside mtp_correctness.
    mtp_correctness(sampling_config, model_name, 1)
|
||||
|
||||
|
||||
def test_mtp2_correctness_piecewise_graph(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check MTP output correctness with 2 speculative tokens under the
    default PIECEWISE cudagraph mode.

    NOTE(review): the diff rendering kept both the pre-rename header
    (``test_mtp2_correctness``) and the renamed one; only the renamed
    function survives in the final file.
    """
    # graph_mode defaults to CUDAGraphMode.PIECEWISE inside mtp_correctness.
    mtp_correctness(sampling_config, model_name, 2)
|
||||
|
||||
|
||||
def test_mtp1_correctness_full_graph(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check MTP output correctness with 1 speculative token when the
    cudagraph mode is FULL instead of the PIECEWISE default."""
    mtp_correctness(sampling_config,
                    model_name,
                    num_speculative_tokens=1,
                    graph_mode=CUDAGraphMode.FULL)
|
||||
|
||||
|
||||
def test_mtp2_correctness_full_graph(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check MTP output correctness with 2 speculative tokens when the
    cudagraph mode is FULL instead of the PIECEWISE default."""
    mtp_correctness(sampling_config,
                    model_name,
                    num_speculative_tokens=2,
                    graph_mode=CUDAGraphMode.FULL)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
|
||||
@@ -16,9 +17,10 @@ def model_name():
|
||||
return "wemaster/deepseek_mtp_main_random_bf16"
|
||||
|
||||
|
||||
def test_mtp_torchair_correctness(
|
||||
def mtp_torchair_correctness(
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
|
||||
):
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
@@ -44,6 +46,11 @@ def test_mtp_torchair_correctness(
|
||||
"multistream_overlap_shared_expert": "True"
|
||||
}) as ref_llm:
|
||||
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||
|
||||
graph_mode_str = "PIECEWISE"
|
||||
if graph_mode == CUDAGraphMode.FULL:
|
||||
graph_mode_str = "FULL"
|
||||
|
||||
with VllmRunner(model_name,
|
||||
tensor_parallel_size=1,
|
||||
max_num_seqs=256,
|
||||
@@ -56,6 +63,8 @@ def test_mtp_torchair_correctness(
|
||||
},
|
||||
enforce_eager=False,
|
||||
max_model_len=2000,
|
||||
compilation_config=CompilationConfig(
|
||||
cudagraph_mode=graph_mode_str),
|
||||
additional_config={
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
@@ -81,3 +90,17 @@ def test_mtp_torchair_correctness(
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
|
||||
|
||||
def test_mtp_torchair_correctness_piecewise(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check MTP correctness with the torchair graph config enabled,
    using the PIECEWISE cudagraph mode (the helper's default)."""
    mtp_torchair_correctness(sampling_config=sampling_config,
                             model_name=model_name)
|
||||
|
||||
|
||||
def test_mtp_torchair_correctness_full(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check MTP correctness with the torchair graph config enabled,
    using the FULL cudagraph mode."""
    mtp_torchair_correctness(sampling_config=sampling_config,
                             model_name=model_name,
                             graph_mode=CUDAGraphMode.FULL)
|
||||
|
||||
@@ -448,6 +448,7 @@ class TestNPUWorker(TestBase):
|
||||
worker.compilation_config = MagicMock()
|
||||
worker.compilation_config.cudagraph_mode = MagicMock()
|
||||
mock_model_runner = MagicMock()
|
||||
mock_decode_token_per_req = mock_model_runner.decode_token_per_req
|
||||
worker.model_runner = mock_model_runner
|
||||
|
||||
# Test execute_dummy_batch
|
||||
@@ -455,7 +456,9 @@ class TestNPUWorker(TestBase):
|
||||
|
||||
# Verify call
|
||||
mock_model_runner._dummy_run.assert_called_once_with(
|
||||
num_tokens=1, uniform_decode=True, force_attention=False)
|
||||
num_tokens=mock_decode_token_per_req,
|
||||
uniform_decode=True,
|
||||
force_attention=False)
|
||||
|
||||
@patch("vllm_ascend.worker.worker_v1.envs_vllm")
|
||||
@patch("vllm_ascend.worker.worker_v1.logger")
|
||||
|
||||
Reference in New Issue
Block a user