forked from EngineX-Ascend/enginex-ascend-910-vllm
init v0.11.0rc0
@@ -1,14 +1,10 @@
from __future__ import annotations

import os

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


@pytest.fixture
def sampling_config():
@@ -20,9 +16,10 @@ def model_name():
    return "wemaster/deepseek_mtp_main_random_bf16"


-def test_mtp_correctness(
+def mtp_correctness(
        sampling_config: SamplingParams,
        model_name: str,
+       num_speculative_tokens: int,
):
    example_prompts = [
        "Hello, my name is",
@@ -38,7 +35,7 @@ def test_mtp_correctness(
            tensor_parallel_size=1,
            gpu_memory_utilization=0.7,
            max_model_len=256,
-           enforce_eager=True) as ref_llm:
+           enforce_eager=False) as ref_llm:
        ref_outputs = ref_llm.generate(example_prompts, sampling_config)

    with VllmRunner(
@@ -50,9 +47,9 @@ def test_mtp_correctness(
            enable_expert_parallel=True,
            speculative_config={
                "method": "deepseek_mtp",
-               "num_speculative_tokens": 1,
+               "num_speculative_tokens": num_speculative_tokens,
            },
-           enforce_eager=True,
+           enforce_eager=False,
            max_model_len=2000,
            additional_config={"ascend_scheduler_config": {
                "enabled": False
@@ -74,3 +71,18 @@ def test_mtp_correctness(
    # Heuristic: expect at least 66% of the prompts to match exactly
    # Upon failure, inspect the outputs to check for inaccuracy.
    assert matches > int(0.66 * len(ref_outputs))
+   del spec_llm
+
+
+def test_mtp1_correctness(
+       sampling_config: SamplingParams,
+       model_name: str,
+):
+   mtp_correctness(sampling_config, model_name, 1)
+
+
+def test_mtp2_correctness(
+       sampling_config: SamplingParams,
+       model_name: str,
+):
+   mtp_correctness(sampling_config, model_name, 2)
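Note: the 66% heuristic above asserts on a `matches` count whose computation lies outside the displayed hunks. A minimal sketch of such an exact-match count, assuming `ref_outputs` and `spec_outputs` are the lists of vllm `RequestOutput` objects returned by `generate()` with one completion per prompt (the helper name below is illustrative, not part of this diff):

def count_exact_matches(ref_outputs, spec_outputs):
    # Count prompts whose completion text is identical with and without
    # speculative decoding; each RequestOutput carries one completion here.
    matches = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        if ref_output.outputs[0].text == spec_output.outputs[0].text:
            matches += 1
    return matches

# Usage mirroring the assertion in the hunk above:
# matches = count_exact_matches(ref_outputs, spec_outputs)
# assert matches > int(0.66 * len(ref_outputs))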
@@ -1,14 +1,10 @@
from __future__ import annotations

import os

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


@pytest.fixture
def sampling_config():

@@ -99,7 +99,6 @@ def test_ngram_correctness(
    assert matches > int(0.7 * len(ref_outputs))


-@pytest.mark.skipif(True, reason="oom in CI, fix me")
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
        test_prompts: list[list[dict[str, Any]]],
@@ -111,8 +110,6 @@ def test_eagle_correctness(
    Compare the outputs of a original LLM and a speculative LLM
    should be the same when using eagle speculative decoding.
    '''
-   if not use_eagle3:
-       pytest.skip("Not current support for the test.")

    ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -121,7 +118,6 @@ def test_eagle_correctness(
    spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
    with VllmRunner(
            model_name,
            trust_remote_code=True,
            enable_chunked_prefill=True,
            max_num_seqs=1,
            max_num_batched_tokens=2048,
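The last hunk is truncated before the speculative_config that pairs with the use_eagle3 parametrization, so that part is not shown here. Purely as an illustration of what an EAGLE-style configuration looks like, assuming a recent vLLM where LLM accepts a speculative_config dict (model paths and the token count below are placeholders, not values from this diff):

from vllm import LLM

# Illustrative sketch only: EAGLE / EAGLE-3 speculative decoding configured
# through vLLM's speculative_config dict; paths and counts are placeholders.
use_eagle3 = True
spec_llm = LLM(
    model="path/to/target-model",              # placeholder target model
    speculative_config={
        "method": "eagle3" if use_eagle3 else "eagle",
        "model": "path/to/eagle-draft-model",  # placeholder draft model
        "num_speculative_tokens": 2,
    },
    max_model_len=2048,
    enforce_eager=True,
)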