### What this PR does / why we need it?
Fix https://github.com/vllm-project/vllm-ascend/issues/7308.
- Subtract `init_non_torch_memory` (which may be memory already in use by the first instance) from the total `non_torch_memory` when calculating `available_kv_cache_memory`.
- Directly use `non_torch_memory_increase` (already contained in `non_kv_cache_memory`) to calculate `available_kv_cache_memory`.
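In pseudocode (a minimal sketch using the field names from the debug logs below; the actual variable names in the worker memory-profiling code may differ):

```python
# Before: the non-torch memory cleared by empty_cache() was subtracted on top of
# non_kv_cache_memory. On a shared card this term also contains memory held by
# the first instance, so the second instance's KV cache budget goes negative.
available_kv_cache_memory = (requested_memory
                             - non_kv_cache_memory
                             - non_torch_memory_cleared_by_empty_cache)

# After: only this instance's own growth (non_torch_memory_increase, which is
# already included in non_kv_cache_memory) is accounted for, so nothing
# allocated by other instances is subtracted.
available_kv_cache_memory = requested_memory - non_kv_cache_memory
```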
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Launch two vllm-ascend instances sequentially on a single card.
```bash
# Launch first instance
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B \
--port 8100 \
--host 0.0.0.0 \
--additional-config='{"enable_cpu_binding":true}' \
--gpu-memory-utilization 0.3 \
--max-num-seqs 1 \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--no-enable-prefix-caching \
--enforce-eager
# Launch second instance
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B \
--port 8101 \
--host 0.0.0.0 \
--additional-config='{"enable_cpu_binding":true}' \
--gpu-memory-utilization 0.3 \
--max-num-seqs 1 \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--no-enable-prefix-caching \
--enforce-eager
```
**Before this PR:**
```bash
# First instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2340388298034668 GiB
init_non_torch_memory: 0.3616676330566406 GiB
non_torch_memory_before_empty_cache: 0.3896217346191406 GiB
non_torch_memory_increase: 0.0279541015625 GiB
non_torch_memory_cleared_by_empty_cache: 0.3616676330566406 GiB
------------------------------------------------------------------
# Second instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2336344718933105 GiB
init_non_torch_memory: 18.37220001220703 GiB
non_torch_memory_before_empty_cache: 18.399906158447266 GiB
non_torch_memory_increase: 0.02754974365234375 GiB
non_torch_memory_cleared_by_empty_cache: 18.372356414794922 GiB
------------------------------------------------------------------
# available_kv_cache_memory = requested_memory - non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
Available KV cache memory: -1.32 GiB
```
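For the second instance, the formula above works out to 18.29 GiB - 1.23 GiB - 18.37 GiB ≈ -1.32 GiB, i.e. the non-torch memory held by the first instance is charged against the second instance's KV cache budget.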
**After this PR:**
```bash
# First instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2340540885925293 GiB
init_non_torch_memory: 0.36182403564453125 GiB
non_torch_memory_before_empty_cache: 0.38979339599609375 GiB
non_torch_memory_increase: 0.0279693603515625 GiB
non_torch_memory_cleared_by_empty_cache: 0.0 GiB
------------------------------------------------------------------
# Second instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.233344554901123 GiB
init_non_torch_memory: 18.74309539794922 GiB
non_torch_memory_before_empty_cache: 18.770355224609375 GiB
non_torch_memory_increase: 0.02725982666015625 GiB
non_torch_memory_cleared_by_empty_cache: 0.0 GiB
------------------------------------------------------------------
# available_kv_cache_memory = requested_memory - non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
Available KV cache memory: 17.05 GiB
```
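With this PR, nothing held by the first instance is subtracted: 18.29 GiB - 1.23 GiB - 0.0 GiB ≈ 17.05 GiB, matching the reported available KV cache memory.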
- vLLM version: v0.17.0
- vLLM main: 4497431df6
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/test_offline_inference.py`.
"""
import os
from unittest.mock import patch

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
from tests.e2e.model_utils import check_outputs_equal

os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

QWEN_DENSE_MODELS = [
    "vllm-ascend/Qwen3-0.6B-W8A8",
]

QWEN_W4A8_MODELS = [
    "vllm-ascend/Qwen3-1.7B-W4A8-V1",
]

QWEN_W4A4_MODELS = [
    "Eco-Tech/Qwen3-32B-w4a4-LAOS",
]

DEEPSEEK_W4A8_MODELS = [
    "vllm-ascend/DeepSeek-V3.1-W4A8-puring",
]

GPT_OSS_MODELS = [
    "unsloth/gpt-oss-20b-BF16",
]


def test_deepseek_multistream_moe_tp2():
    example_prompts = [
        "Hello, my name is",
    ]
    dtype = "half"
    max_tokens = 5
    with VllmRunner(
            "vllm-ascend/DeepSeek-V3-Pruning",
            dtype=dtype,
            tensor_parallel_size=2,
            cudagraph_capture_sizes=[1, 2, 4, 8],
            distributed_executor_backend="mp",
            additional_config={
                "enable_multistream_moe": True,
                "refresh": True,
            },
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)


@pytest.mark.parametrize("model", QWEN_W4A8_MODELS)
def test_qwen3_w4a8_dynamic_tp2(model):
    prompts = [
        "Hello, my name is",
    ]
    max_tokens = 5
    with VllmRunner(
            model,
            max_model_len=8192,
            dtype="auto",
            tensor_parallel_size=2,
            cudagraph_capture_sizes=[1, 2, 4, 8],
            quantization="ascend",
    ) as vllm_model:
        vllm_model.generate_greedy(prompts, max_tokens)


@wait_until_npu_memory_free(target_free_percentage=0.95)
def test_qwen3_moe_sp_tp2() -> None:
    example_prompts = [
        "Hello, my name is",
    ]
    sampling_params = SamplingParams(max_tokens=5,
                                     temperature=0.0,
                                     top_k=50,
                                     top_p=0.9)

    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype="auto",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
            compilation_config={"pass_config": {"enable_sp": True}},
            enable_expert_parallel=True,
            enforce_eager=True,
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "2048"})
@wait_until_npu_memory_free(target_free_percentage=0.95)
def test_deepseek_w4a8_accuracy_tp2(model):
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs",
    ]
    vllm_ds_w4a8_answers = [
        "逍遙而至地去 accrued",
        "平行于我udo madreHelen",
        "ysteepaolis backwards Kj",
    ]
    sampling_params = SamplingParams(max_tokens=5, temperature=0.0)
    with VllmRunner(
            model,
            dtype="auto",
            tensor_parallel_size=2,
            cudagraph_capture_sizes=[1, 2, 4, 8],
            quantization="ascend",
            enable_expert_parallel=True,
    ) as vllm_model:
        vllm_quant_outputs = vllm_model.model.generate(prompts, sampling_params)

    vllm_quant_outputs_list = []
    for output in vllm_quant_outputs:
        vllm_quant_outputs_list.append(
            ([output.outputs[0].index], output.outputs[0].text))
    vllm_answer_list = [([0], answer) for answer in vllm_ds_w4a8_answers]

    check_outputs_equal(
        outputs_0_lst=vllm_answer_list,
        outputs_1_lst=vllm_quant_outputs_list,
        name_0="vllm_quant_outputs",
        name_1="vllm_answer_outputs",
    )


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"})
def test_qwen3_moe_fc2_tp2() -> None:
    example_prompts = [
        "Hello, my name is",
    ]
    sampling_params = SamplingParams(max_tokens=5,
                                     temperature=0.0,
                                     top_k=50,
                                     top_p=0.9)

    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype="auto",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
            enable_expert_parallel=True,
            enforce_eager=True,
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"})
def test_qwen3_moe_fc2_oshard_tp2() -> None:
    example_prompts = [
        "Hello, my name is",
    ]
    sampling_params = SamplingParams(max_tokens=5,
                                     temperature=0.0,
                                     top_k=50,
                                     top_p=0.9)

    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype="auto",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
            enable_expert_parallel=True,
            enforce_eager=True,  # TODO(Levi-JQ): support graph mode for fc2 in Qwen
            additional_config={"layer_sharding": ["o_proj"]},
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_deepseek_v2_lite_fc1_tp2() -> None:
    example_prompts = [
        "test" * 1001,
    ]
    sampling_params = SamplingParams(max_tokens=5,
                                     temperature=0.0,
                                     top_k=50,
                                     top_p=0.9)
    with VllmRunner(
            "vllm-ascend/DeepSeek-V2-Lite-W8A8",
            dtype="auto",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
            enable_expert_parallel=True,
            enforce_eager=True,
            quantization="ascend",
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_qwen3_dense_fc1_tp2(model):
    example_prompts = [
        "Hello, my name is",
    ]
    max_tokens = 5

    with VllmRunner(
            model,
            max_model_len=8192,
            dtype="auto",
            tensor_parallel_size=2,
            cudagraph_capture_sizes=[1, 2, 4, 8],
            quantization="ascend",
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)


@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
    example_prompts = [
        "Hello, my name is",
    ]
    max_tokens = 5

    with VllmRunner(
            model,
            max_model_len=8192,
            dtype="auto",
            tensor_parallel_size=2,
            cudagraph_capture_sizes=[1, 2, 4, 8],
            quantization="ascend",
            additional_config={"weight_prefetch_config": {"enabled": True}},
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)


@patch.dict(os.environ, {"HCCL_OP_EXPANSION_MODE": "AIV"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
@wait_until_npu_memory_free()
def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
    short_example_prompts = [
        "Hello ",
    ]
    # "max_position_embeddings": 163840,
    long_example_prompts = ["Hello " * (163839 - 500) + "Hello"]
    max_tokens = 500
    with VllmRunner(
            "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning",
            tensor_parallel_size=2,
            quantization="ascend",
            enable_expert_parallel=True,
            max_model_len=163840,
            compilation_config={
                "cudagraph_capture_sizes": [2, 4, 6, 8, 10, 12],
                "cudagraph_mode": "FULL_DECODE_ONLY",
            },
            speculative_config={
                "num_speculative_tokens": 1,
                "method": "deepseek_mtp",
            },
            additional_config={"layer_sharding": ["q_b_proj", "o_proj"]},
            reasoning_parser="deepseek_v3",
            tokenizer_mode="deepseek_v32",
            gpu_memory_utilization=0.8,
    ) as vllm_model:
        vllm_model.generate_greedy(short_example_prompts, max_tokens)
        vllm_model.generate_greedy(long_example_prompts, max_tokens)


@patch.dict(os.environ, {"HCCL_OP_EXPANSION_MODE": "AIV"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
@wait_until_npu_memory_free()
def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
    short_example_prompts = [
        "Hello ",
    ]
    # "max_position_embeddings": 163840,
    long_example_prompts = ["Hello " * (163839 - 500) + "Hello"]
    max_tokens = 500
    with VllmRunner(
            "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning",
            tensor_parallel_size=2,
            quantization="ascend",
            enable_expert_parallel=True,
            max_model_len=163840,
            compilation_config={
                "cudagraph_capture_sizes": [2, 4, 6, 8, 10, 12],
                "cudagraph_mode": "FULL_DECODE_ONLY",
            },
            speculative_config={
                "num_speculative_tokens": 1,
                "method": "deepseek_mtp",
            },
            additional_config={
                "layer_sharding": ["q_b_proj", "o_proj"],
                "enable_sparse_c8": True,
            },
            reasoning_parser="deepseek_v3",
            tokenizer_mode="deepseek_v32",
            gpu_memory_utilization=0.8,
    ) as vllm_model:
        vllm_model.generate_greedy(short_example_prompts, max_tokens)
        vllm_model.generate_greedy(long_example_prompts, max_tokens)


@pytest.mark.parametrize("model", QWEN_W4A4_MODELS)
def test_qwen3_w4a4_distributed_tp2(model):
    example_prompts = [
        "Hello, my name is",
    ]
    max_tokens = 5
    with VllmRunner(
            model,
            tensor_parallel_size=2,
            cudagraph_capture_sizes=[1, 2, 4, 8],
            quantization="ascend",
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)


@pytest.mark.parametrize("model", GPT_OSS_MODELS)
def test_gpt_oss_distributed_tp2(model):
    example_prompts = [
        "Hello, my name is",
    ]
    max_tokens = 5
    with VllmRunner(
            model,
            tensor_parallel_size=2,
            enforce_eager=True,
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)
```