diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 96aa0d5..1334328 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -123,7 +123,11 @@ jobs: --ignore=tests/singlecard/test_camem.py else pytest -sv tests/multicard/test_ilama_lora_tp2.py - VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py + # To avoid oom, we need to run the test in a single process. + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py fi - name: Run vllm-project/vllm-ascend test on V0 engine @@ -149,7 +153,9 @@ jobs: else pytest -sv tests/multicard/test_ilama_lora_tp2.py # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error. + # To avoid oom, we need to run the test in a single process. VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py fi diff --git a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py index 8af68c1..66e7aa5 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py +++ b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py @@ -56,6 +56,12 @@ def create_worker( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + + # TODO(Yizhou): A quick fix, must be refactored ASAP + draft_worker_kwargs["vllm_config"].parallel_config.expert_parallel_size = 1 + draft_worker_kwargs[ + "vllm_config"].parallel_config.expert_tensor_parallel_size = 1 + draft_model_config = draft_worker_kwargs["vllm_config"].model_config draft_parallel_config: ParallelConfig = draft_worker_kwargs[ 'vllm_config'].parallel_config