[Test] Remove VLLM_USE_V1 in example and tests (#1733)
V1 is enabled by default, no need to set it by hand now. This PR remove
the useless setting in example and tests
- vLLM version: v0.9.2
- vLLM main:
9ad0a4588b
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -53,7 +53,6 @@ def model_name():
|
||||
@pytest.mark.skipif(
|
||||
True, reason="TODO: Enable me after test_mtp_correctness is fixed")
|
||||
def test_mtp_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
@@ -62,33 +61,30 @@ def test_mtp_correctness(
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using mtp speculative decoding.
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
spec_llm = LLM(model=model_name,
|
||||
trust_remote_code=True,
|
||||
speculative_config={
|
||||
"method": "deepseek_mtp",
|
||||
"num_speculative_tokens": 1,
|
||||
},
|
||||
max_model_len=256,
|
||||
enforce_eager=True)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
spec_llm = LLM(model=model_name,
|
||||
trust_remote_code=True,
|
||||
speculative_config={
|
||||
"method": "deepseek_mtp",
|
||||
"num_speculative_tokens": 1,
|
||||
},
|
||||
max_model_len=256,
|
||||
enforce_eager=True)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
|
||||
@@ -60,7 +60,6 @@ def eagle3_model_name():
|
||||
|
||||
|
||||
def test_ngram_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
@@ -70,44 +69,40 @@ def test_ngram_correctness(
|
||||
should be the same when using ngram speculative decoding.
|
||||
'''
|
||||
pytest.skip("Not current support for the test.")
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
max_model_len=1024,
|
||||
enforce_eager=True,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
max_model_len=1024,
|
||||
enforce_eager=True,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 70% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.7 * len(ref_outputs))
|
||||
del spec_llm
|
||||
# Heuristic: expect at least 70% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.7 * len(ref_outputs))
|
||||
del spec_llm
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
@@ -119,43 +114,40 @@ def test_eagle_correctness(
|
||||
'''
|
||||
if not use_eagle3:
|
||||
pytest.skip("Not current support for the test.")
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
spec_model_name = eagle3_model_name(
|
||||
) if use_eagle3 else eagle_model_name()
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_seqs=1,
|
||||
max_num_batched_tokens=2048,
|
||||
gpu_memory_utilization=0.6,
|
||||
speculative_config={
|
||||
"method": "eagle3" if use_eagle3 else "eagle",
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 2,
|
||||
"max_model_len": 128,
|
||||
},
|
||||
max_model_len=128,
|
||||
enforce_eager=True,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_seqs=1,
|
||||
max_num_batched_tokens=2048,
|
||||
gpu_memory_utilization=0.6,
|
||||
speculative_config={
|
||||
"method": "eagle3" if use_eagle3 else "eagle",
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 2,
|
||||
"max_model_len": 128,
|
||||
},
|
||||
max_model_len=128,
|
||||
enforce_eager=True,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
|
||||
Reference in New Issue
Block a user