Handle with_prefill_across_dp for multistream mla (#1322)
### What this PR does / why we need it?

After #1094, decode may run in non-compiled (eager) mode even when `torchair_graph_config.enabled` is set. This breaks multistream MLA, which assumed that decode always uses the torchair-compiled mode whenever `torchair_graph_config.enabled == True`. This PR adjusts that assumption so multistream MLA is only engaged on the compiled decode path.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Tested offline and via the graph-mode MLA e2e test case.

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
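The gist of the change, as a minimal sketch in plain Python (the class and method names below are illustrative only, not the actual vllm-ascend code): whether decode is on the torchair-compiled path, and therefore whether multistream MLA may kick in, has to consider not only `torchair_graph_config.enabled` but also whether any DP rank still carries a prefill in the current step (`with_prefill_across_dp`).

```python
# Hypothetical sketch of the relaxed assumption; real attribute names in
# vllm-ascend's MLA backend and model runner may differ.
class AscendMLAImplSketch:
    def __init__(self, torchair_graph_enabled: bool, enable_multistream_mla: bool):
        self.torchair_graph_enabled = torchair_graph_enabled
        self.enable_multistream_mla = enable_multistream_mla

    def running_in_graph(self, is_decode: bool, with_prefill_across_dp: bool) -> bool:
        # Before: decode + torchair_graph_enabled was assumed to mean "compiled".
        # After #1094, a decode step may still run eagerly when any DP rank
        # has a prefill, so that flag must be consulted as well.
        return (self.torchair_graph_enabled and is_decode
                and not with_prefill_across_dp)

    def use_multistream_mla(self, is_decode: bool, with_prefill_across_dp: bool) -> bool:
        # Multistream MLA is only safe on the torchair-compiled decode path.
        return (self.enable_multistream_mla
                and self.running_in_graph(is_decode, with_prefill_across_dp))


impl = AscendMLAImplSketch(torchair_graph_enabled=True, enable_multistream_mla=True)
# Decode step, but some DP rank still has a prefill: stay on the eager path.
assert not impl.use_multistream_mla(is_decode=True, with_prefill_across_dp=True)
# Pure decode across all DP ranks: the compiled multistream path is safe.
assert impl.use_multistream_mla(is_decode=True, with_prefill_across_dp=False)
```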
```diff
@@ -20,6 +20,7 @@
 Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 """
 import os
+from typing import Dict
 
 import pytest
 
@@ -28,53 +29,73 @@ from tests.conftest import VllmRunner
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
 
+def _deepseek_torchair_test_fixture(
+    additional_config: Dict,
+    *,
+    tensor_parallel_size=4,
+):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # torchair is only work without chunked-prefill now
+    kwargs = {
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+    additional_config.update(**kwargs)
+
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype="half",
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend="mp",
+            enforce_eager=False,
+            additional_config=additional_config,
+    ) as vllm_model:
+        # use greedy sampler to make sure the generated results are fix
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+    # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
+    # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
+    # inaccurate. This will only change if accuracy improves with the
+    # official weights of DeepSeek-V3.
+    golden_results = [
+        'Hello, my name is feasibility伸 spazio debtor添',
+        'The president of the United States is begg"""\n杭州风和 bestimm',
+        'The capital of France is frequentlyশามalinkAllowed',
+        'The future of AI is deleting俯احت怎么样了حراف',
+    ]
+
+    assert len(golden_results) == len(vllm_output)
+    for i in range(len(vllm_output)):
+        assert golden_results[i] == vllm_output[i][1]
+        print(f"Generated text: {vllm_output[i][1]!r}")
+
+
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="torchair graph is not supported on v0")
-def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_MODELSCOPE", "True")
-        m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+def test_e2e_deepseekv3_with_torchair():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+        },
+    }
+    _deepseek_torchair_test_fixture(additional_config)
 
-        example_prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        dtype = "half"
-        max_tokens = 5
-        # torchair is only work without chunked-prefill now
-        with VllmRunner(
-                "vllm-ascend/DeepSeek-V3-Pruning",
-                dtype=dtype,
-                tensor_parallel_size=4,
-                distributed_executor_backend="mp",
-                additional_config={
-                    "torchair_graph_config": {
-                        "enabled": True,
-                    },
-                    "ascend_scheduler_config": {
-                        "enabled": True,
-                    },
-                    "refresh": True,
-                },
-                enforce_eager=False,
-        ) as vllm_model:
-            # use greedy sampler to make sure the generated results are fix
-            vllm_output = vllm_model.generate_greedy(example_prompts,
-                                                     max_tokens)
-            # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
-            # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
-            # inaccurate. This will only change if accuracy improves with the
-            # official weights of DeepSeek-V3.
-            golden_results = [
-                'Hello, my name is feasibility伸 spazio debtor添',
-                'The president of the United States is begg"""\n杭州风和 bestimm',
-                'The capital of France is frequentlyশามalinkAllowed',
-                'The future of AI is deleting俯احت怎么样了حراف',
-            ]
-
-            assert len(golden_results) == len(vllm_output)
-            for i in range(len(vllm_output)):
-                assert golden_results[i] == vllm_output[i][1]
-                print(f"Generated text: {vllm_output[i][1]!r}")
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="torchair graph is not supported on v0")
+def test_e2e_deepseekv3_with_torchair_ms_mla():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+            "enable_multistream_mla": True,
+        },
+    }
+    _deepseek_torchair_test_fixture(additional_config)
```
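For context, the configuration exercised by the new `test_e2e_deepseekv3_with_torchair_ms_mla` case maps onto an offline run roughly like the following. This is a sketch only; it assumes vLLM's `additional_config` engine argument is forwarded to vllm-ascend in the same way the test fixture above does.

```python
from vllm import LLM, SamplingParams

# Sketch only: mirrors the additional_config used by the e2e test above.
llm = LLM(
    model="vllm-ascend/DeepSeek-V3-Pruning",
    tensor_parallel_size=4,
    enforce_eager=False,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_mla": True,
        },
        "ascend_scheduler_config": {"enabled": True},
    },
)

# Greedy decoding keeps the output deterministic, matching the test's intent.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
```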