[BugFix] Fix world size bug in model_runner (#2915)
- Fix world size bug in model_runner to make sure EP > 16 runs with MC2
- Enable the e2e test for VL
Co-Authored-By: whx-sjtu <2952154980@qq.com>
Co-Authored-By: Icey <1790571317@qq.com>
- vLLM version: v0.10.2
- vLLM main: 3e903b6cb4
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
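
Background for the fix: with data parallelism, `parallel_config.world_size` only counts the ranks inside one DP group, so the MC2 threshold of 16 expert-parallel ranks was never met even when the deployment as a whole spanned 16 or more NPUs; `world_size_across_dp` counts ranks across all DP groups. A minimal sketch of the assumed relationship (the helper below is illustrative, not part of the codebase):

    # Assumption: world_size_across_dp is the per-DP-group world size
    # (TP x PP) multiplied by the data-parallel size.
    def world_size_across_dp(world_size: int, data_parallel_size: int) -> int:
        return world_size * data_parallel_size

    # Example: 8 ranks per DP group, DP = 2.
    # Old check: world_size >= 16         -> 8 >= 16 is False, MC2 never chosen.
    # New check: world_size_across_dp >= 16 -> 16 >= 16 is True, MC2 can be chosen.
    assert world_size_across_dp(8, 2) >= 16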
@@ -22,17 +22,16 @@ Run `pytest tests/test_offline_inference.py`.
 """
 import os
 
-import pytest
 from vllm import SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 
 from tests.e2e.conftest import VllmRunner
 
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
 
-@pytest.mark.skip(reason="fix me")
 def test_multimodal_vl(prompt_template):
     image = ImageAsset("cherry_blossom") \
         .pil_image.convert("RGB")
@@ -52,9 +51,12 @@ def test_multimodal_vl(prompt_template):
                         "fps": 1,
                     },
                     enforce_eager=True) as vllm_model:
-        vllm_model.generate_greedy(prompts=prompts,
-                                   images=images,
-                                   max_tokens=64)
+        outputs = vllm_model.generate_greedy(prompts=prompts,
+                                             images=images,
+                                             max_tokens=64)
+    assert len(outputs) == len(prompts)
+    for _, output_str in outputs:
+        assert output_str, "Generated output should not be empty."
 
 
 def test_multimodal_audio():
@@ -86,4 +88,7 @@ def test_multimodal_audio():
                     dtype="bfloat16",
                     limit_mm_per_prompt={"audio": 2},
                     gpu_memory_utilization=0.9) as runner:
-        runner.generate(inputs, sampling_params=sampling_params)
+        outputs = runner.generate(inputs, sampling_params=sampling_params)
+
+    assert outputs is not None, "Generated outputs should not be None."
+    assert len(outputs) > 0, "Generated outputs should not be empty."
@@ -57,7 +57,7 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
     mock_runner = MagicMock(spec=NPUModelRunner)
     mock_runner.parallel_config = MagicMock()
     mock_runner.parallel_config.enable_expert_parallel = enable_expert_parallel
-    mock_runner.parallel_config.world_size = world_size
+    mock_runner.parallel_config.world_size_across_dp = world_size
     mock_runner.mc2_tokens_capacity = mc2_tokens_capacity
 
     # Patch the helper functions
@@ -1539,7 +1539,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if not self.parallel_config.enable_expert_parallel:
             moe_comm_method = "allgather"
         elif soc_version in {AscendSocVersion.A2}:
-            if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size >= 16:
+            if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16:
                 moe_comm_method = "mc2"
             else:
                 moe_comm_method = "allgather"
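
For reference, a standalone sketch of the A2 branch above, restated outside the class (the function name and the example numbers are illustrative only):

    def select_moe_comm_method_a2(enable_expert_parallel: bool, num_tokens: int,
                                  mc2_tokens_capacity: int,
                                  world_size_across_dp: int) -> str:
        # Mirrors the if/elif chain from NPUModelRunner for the A2 SoC case.
        if not enable_expert_parallel:
            return "allgather"
        if num_tokens <= mc2_tokens_capacity and world_size_across_dp >= 16:
            return "mc2"
        return "allgather"

    # 128 tokens fit the assumed MC2 capacity of 512 and 16 ranks meet the threshold.
    assert select_moe_comm_method_a2(True, 128, 512, 16) == "mc2"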