v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

0
tests/__init__.py Normal file
View File

View File

@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import pytest
import vllm # noqa: F401
from vllm import SamplingParams
import vllm_ascend # noqa: F401
from tests.e2e.conftest import VllmRunner
MODELS = ["Qwen/Qwen3-0.6B", "Qwen/Qwen2.5-7B-Instruct"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [5])
def test_models(model: str, dtype: str, max_tokens: int) -> None:
example_prompts = [
"Hello, my name is",
"The future of AI is",
]
with VllmRunner(model,
tensor_parallel_size=1,
dtype=dtype,
max_model_len=2048,
enforce_eager=True,
compilation_config={
"custom_ops":
["none", "+rms_norm", "+rotary_embedding"]
}) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
VL_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
def test_vl_model_with_samples(model: str, dtype: str) -> None:
example_prompts = [
"Hello, my name is",
"The future of AI is",
]
with VllmRunner(model,
tensor_parallel_size=1,
dtype=dtype,
max_model_len=2048,
enforce_eager=True,
compilation_config={
"custom_ops":
["none", "+rms_norm", "+rotary_embedding"]
}) as vllm_model:
sampling_params = SamplingParams(max_tokens=100,
top_p=0.95,
top_k=50,
temperature=0.6)
vllm_model.generate(example_prompts, sampling_params)

View File

@@ -0,0 +1,62 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import pytest
import vllm # noqa: F401
import vllm_ascend # noqa: F401
from tests.e2e.conftest import VllmRunner
# Pangu local model path
MODELS = [
"IntervitensInc/pangu-pro-moe-model",
]
# set additional config for ascend scheduler and torchair graph
ADDITIONAL_CONFIG = [{
"additional_config": {
"torchair_graph_config": {
"enabled": True
},
"ascend_scheduler_config": {
"enabled": True,
}
}
}]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enfore_eager", [True, False])
@pytest.mark.parametrize("additional_config", ADDITIONAL_CONFIG)
def test_pangu_model(model: str, dtype: str, max_tokens: int,
enfore_eager: bool, additional_config: dict) -> None:
if enfore_eager:
additional_config = {}
example_prompts = [
"Hello, my name is",
"The future of AI is",
]
with VllmRunner(model,
tensor_parallel_size=4,
dtype=dtype,
max_model_len=1024,
enforce_eager=True,
enable_expert_parallel=True,
additional_config=additional_config,
distributed_executor_backend="mp") as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

0
tests/e2e/__init__.py Normal file
View File

74
tests/e2e/common.sh Normal file
View File

@@ -0,0 +1,74 @@
# bash fonts colors
cyan='\e[96m'
yellow='\e[33m'
red='\e[31m'
none='\e[0m'
_cyan() { echo -e "${cyan}$*${none}"; }
_yellow() { echo -e "${yellow}$*${none}"; }
_red() { echo -e "${red}$*${none}"; }
_info() { _cyan "Info: $*"; }
_warn() { _yellow "Warn: $*"; }
_err() { _red "Error: $*" && exit 1; }
CURL_TIMEOUT=1
CURL_COOLDOWN=5
CURL_MAX_TRIES=180
function wait_url_ready() {
local serve_name="$1"
local url="$2"
i=0
while true; do
_info "===> Waiting for ${serve_name} to be ready...${i}s"
i=$((i + CURL_COOLDOWN))
set +e
curl --silent --max-time "$CURL_TIMEOUT" "${url}" >/dev/null
result=$?
set -e
if [ "$result" -eq 0 ]; then
break
fi
if [ "$i" -gt "$CURL_MAX_TRIES" ]; then
_info "===> ${CURL_MAX_TRIES}s exceeded waiting for ${serve_name} to be ready"
return 1
fi
sleep "$CURL_COOLDOWN"
done
_info "===> ${serve_name} is ready."
}
function wait_for_exit() {
local VLLM_PID="$1"
while kill -0 "$VLLM_PID"; do
_info "===> Wait for ${VLLM_PID} to exit."
sleep 1
done
_info "===> Process ${VLLM_PID} has exited."
}
VENV_PATH=/tmp/vllm_venv
function clean_venv() {
if [[ -n "$VENV_PATH" && -d "$VENV_PATH" ]]; then
_info "Cleaning up default virtual env path: ${VENV_PATH}"
deactivate || true
rm -rf "$VENV_PATH"
fi
}
function create_vllm_venv() {
# make a clean env path
clean_venv
_info "Creating vllm virtual environment at ${VENV_PATH}"
python3 -m venv ${VENV_PATH}
source ${VENV_PATH}/bin/activate
}
function get_version() {
local VERSION_NAME="$1"
python3 "${SCRIPT_DIR}/../../docs/source/conf.py" | jq .${VERSION_NAME} | tr -d '"'
}
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)

431
tests/e2e/conftest.py Normal file
View File

@@ -0,0 +1,431 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import contextlib
import gc
import os
from typing import Any, List, Optional, Tuple, TypeVar, Union
import numpy as np
import pytest
import torch
from modelscope import snapshot_download # type: ignore[import-untyped]
from PIL import Image
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm import LLM, SamplingParams
from vllm.config import TaskOption, _get_and_verify_dtype
from vllm.inputs import TextPrompt
from vllm.outputs import RequestOutput
from vllm.transformers_utils.utils import maybe_model_redirect
from tests.e2e.model_utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs)
from vllm_ascend.ascend_config import clear_ascend_config
# TODO: remove this part after the patch merged into vllm, if
# we not explicitly patch here, some of them might be effectiveless
# in pytest scenario
from vllm_ascend.utils import adapt_patch # noqa E402
adapt_patch(True)
adapt_patch(False)
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
_TEST_DIR = os.path.dirname(__file__)
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
if shutdown_ray:
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
class VllmRunner:
def __init__(
self,
model_name: str,
task: TaskOption = "auto",
tokenizer_name: Optional[str] = None,
tokenizer_mode: str = "auto",
# Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit.
max_model_len: int = 1024,
dtype: str = "auto",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = False,
swap_space: int = 4,
enforce_eager: Optional[bool] = False,
quantization: Optional[str] = None,
**kwargs,
) -> None:
self.model = LLM(
model=model_name,
task=task,
tokenizer=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=True,
dtype=dtype,
swap_space=swap_space,
enforce_eager=enforce_eager,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_model_len,
block_size=block_size,
enable_chunked_prefill=enable_chunked_prefill,
quantization=quantization,
**kwargs,
)
def get_inputs(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[TextPrompt]:
if images is not None:
assert len(prompts) == len(images)
if videos is not None:
assert len(prompts) == len(videos)
if audios is not None:
assert len(prompts) == len(audios)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
if image is not None:
inputs[i]["multi_modal_data"] = {"image": image}
if videos is not None:
for i, video in enumerate(videos):
if video is not None:
inputs[i]["multi_modal_data"] = {"video": video}
if audios is not None:
for i, audio in enumerate(audios):
if audio is not None:
inputs[i]["multi_modal_data"] = {"audio": audio}
return inputs
def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids: List[List[int]] = []
req_sample_output_strs: List[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: List[RequestOutput],
) -> List[TokensTextLogprobsPromptLogprobs]:
outputs: List[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
assert len(req_output.outputs) > 0
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs,
req_output.prompt_logprobs))
return outputs
def generate_w_logprobs(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
# Omit prompt logprobs if not required by sampling params
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
if sampling_params.prompt_logprobs is None else
toks_str_logsprobs_prompt_logprobs)
def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts,
greedy_params,
images=images,
videos=videos,
audios=audios)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]
def generate_greedy_logprobs(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
num_prompt_logprobs: Optional[int] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=num_prompt_logprobs,
stop_token_ids=stop_token_ids,
stop=stop)
return self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)
def encode(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[List[float]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.embed(inputs)
return [req_output.outputs.embedding for req_output in req_outputs]
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
clear_ascend_config()
cleanup_dist_env_and_memory()
class HfRunner:
def get_default_device(self):
from vllm.platforms import current_platform
return ("cpu"
if current_platform.is_cpu() else current_platform.device_type)
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if x is None or isinstance(x, (bool, )):
return x
if device is None:
device = self.device
if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}
if hasattr(x, "device") and x.device.type == device:
return x
return x.to(device)
def __init__(
self,
model_name: str,
dtype: str = "auto",
*,
model_kwargs: Optional[dict[str, Any]] = None,
trust_remote_code: bool = True,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
) -> None:
model_name = maybe_model_redirect(model_name)
self.model_name = model_name
self.config = AutoConfig.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
)
self.device = self.get_default_device()
self.dtype = torch_dtype = _get_and_verify_dtype(
self.model_name,
self.config,
dtype=dtype,
is_pooling_model=is_sentence_transformer or is_cross_encoder,
)
model_kwargs = model_kwargs if model_kwargs is not None else {}
model_kwargs.setdefault("torch_dtype", torch_dtype)
if is_sentence_transformer:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_name,
device=self.device,
model_kwargs=model_kwargs,
trust_remote_code=trust_remote_code,
)
elif is_cross_encoder:
# Lazy init required for AMD CI
from sentence_transformers import CrossEncoder
self.model = CrossEncoder(
model_name,
device=self.device,
automodel_args=model_kwargs,
trust_remote_code=trust_remote_code,
)
else:
model = auto_cls.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
**model_kwargs,
)
# in case some unquantized custom models are not in same dtype
if (getattr(model, "quantization_method", None) is None
and any(p.dtype != self.dtype
for p in model.parameters())):
model = model.to(dtype=self.dtype)
if (getattr(model, "quantization_method", None) != "bitsandbytes"
and len({p.device
for p in model.parameters()}) < 2):
model = model.to(device=self.device)
self.model = model
if not skip_tokenizer_init:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
if skip_tokenizer_init:
self.tokenizer = self.processor.tokenizer
def encode(self, prompts: list[str], *args,
**kwargs) -> list[list[torch.Tensor]]:
return self.model.encode(prompts, *args, **kwargs)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup_dist_env_and_memory()
@pytest.fixture(scope="session")
def ilama_lora_files():
return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider")
def qwen_prompt(questions: List[str]) -> List[str]:
placeholder = "<|image_pad|>"
return [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{q}<|im_end|>\n<|im_start|>assistant\n") for q in questions]
PROMPT_TEMPLATES = {
"qwen2.5vl": qwen_prompt,
}
@pytest.fixture(params=list(PROMPT_TEMPLATES.keys()))
def prompt_template(request):
return PROMPT_TEMPLATES[request.param]

View File

@@ -0,0 +1,64 @@
#!/bin/bash
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
function install_system_packages() {
if command -v apt-get >/dev/null; then
sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list
apt-get update -y && apt install -y curl
elif command -v yum >/dev/null; then
yum update -y && yum install -y curl
else
echo "Unknown package manager. Please install gcc, g++, numactl-devel, git, curl, and jq manually."
fi
}
function simple_test() {
# Do real import test
python3 -c "import vllm; print(vllm.__version__)"
}
function quickstart_offline_test() {
# Do real script test
python3 "${SCRIPT_DIR}/../../examples/offline_inference_npu.py"
}
function quickstart_online_test() {
install_system_packages
vllm serve Qwen/Qwen2.5-0.5B-Instruct &
wait_url_ready "vllm serve" "localhost:8000/v1/models"
# Do real curl test
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-0.5B-Instruct",
"prompt": "Beijing is a",
"max_tokens": 5,
"temperature": 0
}' | python3 -m json.tool
VLLM_PID=$(pgrep -f "vllm serve")
_info "===> Try kill -2 ${VLLM_PID} to exit."
kill -2 "$VLLM_PID"
wait_for_exit "$VLLM_PID"
}
_info "====> Start simple_test"
simple_test
_info "====> Start quickstart_offline_test"
quickstart_offline_test
_info "====> Start quickstart_online_test"
quickstart_online_test

View File

@@ -0,0 +1,62 @@
#!/bin/bash
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
trap clean_venv EXIT
function install_system_packages() {
if command -v apt-get >/dev/null; then
sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list
apt-get update -y && apt-get install -y gcc g++ cmake libnuma-dev wget git curl jq
elif command -v yum >/dev/null; then
yum update -y && yum install -y gcc g++ cmake numactl-devel wget git curl jq
else
echo "Unknown package manager. Please install curl manually."
fi
}
function config_pip_mirror() {
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
}
function install_binary_test() {
install_system_packages
config_pip_mirror
create_vllm_venv
PIP_VLLM_VERSION=$(get_version pip_vllm_version)
PIP_VLLM_ASCEND_VERSION=$(get_version pip_vllm_ascend_version)
_info "====> Install vllm==${PIP_VLLM_VERSION} and vllm-ascend ${PIP_VLLM_ASCEND_VERSION}"
# Setup extra-index-url for x86 & torch_npu dev version
pip config set global.extra-index-url "https://download.pytorch.org/whl/cpu/ https://mirrors.huaweicloud.com/ascend/repos/pypi"
pip install vllm=="$(get_version pip_vllm_version)"
pip install vllm-ascend=="$(get_version pip_vllm_ascend_version)"
pip list | grep vllm
# Verify the installation
_info "====> Run offline example test"
pip install modelscope
python3 "${SCRIPT_DIR}/../../examples/offline_inference_npu.py"
}
_info "====> Start install_binary_test"
install_binary_test

74
tests/e2e/model_utils.py Normal file
View File

@@ -0,0 +1,74 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py
#
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import PromptLogprobs, SampleLogprobs
TokensText = Tuple[List[int], str]
def check_outputs_equal(
*,
outputs_0_lst: Sequence[TokensText],
outputs_1_lst: Sequence[TokensText],
name_0: str,
name_1: str,
):
"""
Compare the two sequences generated by different models,
which should be equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
# The text and token outputs should exactly match
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_str_0 == output_str_1, fail_msg
assert output_ids_0 == output_ids_1, fail_msg
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * List of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]],
SampleLogprobs]]]
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
# * Optional list of top prompt logprobs for each prompt token
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs = Tuple[
List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]

View File

@@ -0,0 +1,13 @@
model_name: "deepseek-ai/DeepSeek-V2-Lite"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.375
- name: "exact_match,flexible-extract"
value: 0.375
tensor_parallel_size: 2
apply_chat_template: False
fewshot_as_multiturn: False
trust_remote_code: True
enforce_eager: True

View File

@@ -0,0 +1,8 @@
model_name: "Qwen/Qwen2.5-VL-7B-Instruct"
model: "vllm-vlm"
tasks:
- name: "mmmu_val"
metrics:
- name: "acc,none"
value: 0.51
max_model_len: 8192

View File

@@ -0,0 +1,18 @@
model_name: "Qwen/Qwen3-30B-A3B"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.89
- name: "exact_match,flexible-extract"
value: 0.85
- name: "ceval-valid"
metrics:
- name: "acc,none"
value: 0.84
num_fewshot: 5
gpu_memory_utilization: 0.6
enable_expert_parallel: True
tensor_parallel_size: 2
apply_chat_template: False
fewshot_as_multiturn: False

View File

@@ -0,0 +1,13 @@
model_name: "Qwen/Qwen3-8B-Base"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.82
- name: "exact_match,flexible-extract"
value: 0.83
- name: "ceval-valid"
metrics:
- name: "acc,none"
value: 0.82
num_fewshot: 5

View File

@@ -0,0 +1,3 @@
Qwen3-8B-Base.yaml
Qwen2.5-VL-7B-Instruct.yaml
Qwen3-30B-A3B.yaml

View File

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
import pytest
def pytest_addoption(parser):
parser.addoption(
"--config-list-file",
action="store",
default=None,
help="Path to the file listing model config YAMLs (one per line)",
)
parser.addoption(
"--tp-size",
action="store",
default="1",
help="Tensor parallel size to use for evaluation",
)
parser.addoption(
"--config",
action="store",
default="./tests/e2e/models/configs/Qwen3-8B-Base.yaml",
help="Path to the model config YAML file",
)
parser.addoption(
"--report-dir",
action="store",
default="./benchmarks/accuracy",
help="Directory to store report files",
)
@pytest.fixture(scope="session")
def config_list_file(pytestconfig, config_dir):
rel_path = pytestconfig.getoption("--config-list-file")
return config_dir / rel_path
@pytest.fixture(scope="session")
def tp_size(pytestconfig):
return pytestconfig.getoption("--tp-size")
@pytest.fixture(scope="session")
def config(pytestconfig):
return pytestconfig.getoption("--config")
@pytest.fixture(scope="session")
def report_dir(pytestconfig):
return pytestconfig.getoption("report_dir")
def pytest_generate_tests(metafunc):
if "config_filename" in metafunc.fixturenames:
if metafunc.config.getoption("--config-list-file"):
rel_path = metafunc.config.getoption("--config-list-file")
config_list_file = Path(rel_path).resolve()
config_dir = config_list_file.parent
with open(config_list_file, encoding="utf-8") as f:
configs = [
config_dir / line.strip() for line in f
if line.strip() and not line.startswith("#")
]
metafunc.parametrize("config_filename", configs)
else:
single_config = metafunc.config.getoption("--config")
config_path = Path(single_config).resolve()
metafunc.parametrize("config_filename", [config_path])

View File

@@ -0,0 +1,21 @@
# {{ model_name }}
- **vLLM Version**: vLLM: {{ vllm_version }} ([{{ vllm_commit[:7] }}](https://github.com/vllm-project/vllm/commit/{{ vllm_commit }})), **vLLM Ascend Version**: {{ vllm_ascend_version }} ([{{ vllm_ascend_commit[:7] }}](https://github.com/vllm-project/vllm-ascend/commit/{{ vllm_ascend_commit }}))
- **Software Environment**: **CANN**: {{ cann_version }}, **PyTorch**: {{ torch_version }}, **torch-npu**: {{ torch_npu_version }}
- **Hardware Environment**: Atlas A2 Series
- **Parallel mode**: {{ parallel_mode }}
- **Execution mode**: ACLGraph
**Command**:
```bash
export MODEL_ARGS={{ model_args }}
lm_eval --model {{ model_type }} --model_args $MODEL_ARGS --tasks {{ datasets }} \
{% if apply_chat_template %} --apply_chat_template {{ apply_chat_template }} {% endif %} {% if fewshot_as_multiturn %} --fewshot_as_multiturn {{ fewshot_as_multiturn }} {% endif %} {% if num_fewshot is defined and num_fewshot != "N/A" %} --num_fewshot {{ num_fewshot }} {% endif %} {% if limit is defined and limit != "N/A" %} --limit {{ limit }} {% endif %} --batch_size {{ batch_size}}
```
| Task | Metric | Value | Stderr |
|-----------------------|-------------|----------:|-------:|
{% for row in rows -%}
| {{ row.task }} | {{ row.metric }} | {{ row.value }} | ± {{ "%.4f" | format(row.stderr | float) }} |
{% endfor %}

View File

@@ -0,0 +1,153 @@
import os
from dataclasses import dataclass
import lm_eval
import numpy as np
import pytest
import yaml
from jinja2 import Environment, FileSystemLoader
RTOL = 0.03
TEST_DIR = os.path.dirname(__file__)
@dataclass
class EnvConfig:
vllm_version: str
vllm_commit: str
vllm_ascend_version: str
vllm_ascend_commit: str
cann_version: str
torch_version: str
torch_npu_version: str
@pytest.fixture
def env_config() -> EnvConfig:
return EnvConfig(vllm_version=os.getenv('VLLM_VERSION', 'unknown'),
vllm_commit=os.getenv('VLLM_COMMIT', 'unknown'),
vllm_ascend_version=os.getenv('VLLM_ASCEND_VERSION',
'unknown'),
vllm_ascend_commit=os.getenv('VLLM_ASCEND_COMMIT',
'unknown'),
cann_version=os.getenv('CANN_VERSION', 'unknown'),
torch_version=os.getenv('TORCH_VERSION', 'unknown'),
torch_npu_version=os.getenv('TORCH_NPU_VERSION',
'unknown'))
def build_model_args(eval_config, tp_size):
trust_remote_code = eval_config.get("trust_remote_code", False)
max_model_len = eval_config.get("max_model_len", 4096)
model_args = {
"pretrained": eval_config["model_name"],
"tensor_parallel_size": tp_size,
"dtype": "auto",
"trust_remote_code": trust_remote_code,
"max_model_len": max_model_len,
}
for s in [
"max_images", "gpu_memory_utilization", "enable_expert_parallel",
"tensor_parallel_size", "enforce_eager"
]:
val = eval_config.get(s, None)
if val is not None:
model_args[s] = val
print("Model Parameters:")
print(model_args)
return model_args
def generate_report(tp_size, eval_config, report_data, report_dir, env_config):
env = Environment(loader=FileSystemLoader(TEST_DIR))
template = env.get_template("report_template.md")
model_args = build_model_args(eval_config, tp_size)
parallel_mode = f"TP{model_args.get('tensor_parallel_size', 1)}"
if model_args.get('enable_expert_parallel', False):
parallel_mode += " + EP"
report_content = template.render(
vllm_version=env_config.vllm_version,
vllm_commit=env_config.vllm_commit,
vllm_ascend_version=env_config.vllm_ascend_version,
vllm_ascend_commit=env_config.vllm_ascend_commit,
cann_version=env_config.cann_version,
torch_version=env_config.torch_version,
torch_npu_version=env_config.torch_npu_version,
model_name=eval_config["model_name"],
model_args=f"'{','.join(f'{k}={v}' for k, v in model_args.items())}'",
model_type=eval_config.get("model", "vllm"),
datasets=",".join([task["name"] for task in eval_config["tasks"]]),
apply_chat_template=eval_config.get("apply_chat_template", True),
fewshot_as_multiturn=eval_config.get("fewshot_as_multiturn", True),
limit=eval_config.get("limit", "N/A"),
batch_size="auto",
num_fewshot=eval_config.get("num_fewshot", "N/A"),
rows=report_data["rows"],
parallel_mode=parallel_mode)
report_output = os.path.join(
report_dir, f"{os.path.basename(eval_config['model_name'])}.md")
os.makedirs(os.path.dirname(report_output), exist_ok=True)
with open(report_output, 'w', encoding='utf-8') as f:
f.write(report_content)
def test_lm_eval_correctness_param(config_filename, tp_size, report_dir,
env_config):
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
model_args = build_model_args(eval_config, tp_size)
success = True
report_data: dict[str, list[dict]] = {"rows": []}
eval_params = {
"model": eval_config.get("model", "vllm"),
"model_args": model_args,
"tasks": [task["name"] for task in eval_config["tasks"]],
"apply_chat_template": eval_config.get("apply_chat_template", True),
"fewshot_as_multiturn": eval_config.get("fewshot_as_multiturn", True),
"limit": eval_config.get("limit", None),
"batch_size": "auto",
}
for s in ["num_fewshot", "fewshot_as_multiturn", "apply_chat_template"]:
val = eval_config.get(s, None)
if val is not None:
eval_params[s] = val
print("Eval Parameters:")
print(eval_params)
results = lm_eval.simple_evaluate(**eval_params)
for task in eval_config["tasks"]:
task_name = task["name"]
task_result = results["results"][task_name]
for metric in task["metrics"]:
metric_name = metric["name"]
ground_truth = metric["value"]
measured_value = round(task_result[metric_name], 4)
task_success = bool(
np.isclose(ground_truth, measured_value, rtol=RTOL))
success = success and task_success
print(f"{task_name} | {metric_name}: "
f"ground_truth={ground_truth} | measured={measured_value} | "
f"success={'' if task_success else ''}")
report_data["rows"].append({
"task":
task_name,
"metric":
metric_name,
"value":
f"{measured_value}" if success else f"{measured_value}",
"stderr":
task_result[
metric_name.replace(',', '_stderr,') if metric_name ==
"acc,none" else metric_name.replace(',', '_stderr,')]
})
generate_report(tp_size, eval_config, report_data, report_dir, env_config)
assert success

View File

@@ -0,0 +1,73 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without aclgraph.
Run `pytest tests/multicard/test_data_parallel.py`.
"""
import os
import subprocess
import sys
from unittest.mock import patch
import pytest
MODELS = ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-30B-A3B"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
def test_data_parallel_inference(model, max_tokens):
script = "examples/offline_data_parallel.py"
env = os.environ.copy()
cmd = [
sys.executable,
script,
"--model",
model,
"--dp-size",
"2",
"--tp-size",
"1",
"--node-size",
"1",
"--node-rank",
"0",
"--trust-remote-code",
"--enforce-eager",
]
if model == "Qwen/Qwen3-30B-A3B":
cmd.append("--enable-expert-parallel")
print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600)
output = proc.stdout.decode()
print(output)
assert "DP rank 0 needs to process" in output
assert "DP rank 1 needs to process" in output
assert "Generated text:" in output
assert proc.returncode == 0

View File

@@ -0,0 +1,32 @@
import pytest
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
@pytest.mark.parametrize("model_name", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
def test_e2e_ep_correctness(model_name):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
max_tokens = 5
with VllmRunner(model_name, tensor_parallel_size=2,
enforce_eager=True) as vllm_model:
tp_output = vllm_model.generate_greedy(example_prompts, max_tokens)
with VllmRunner(model_name,
tensor_parallel_size=2,
enable_expert_parallel=True,
enforce_eager=True) as vllm_model:
ep_output = vllm_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=ep_output,
outputs_1_lst=tp_output,
name_0="ep_output",
name_1="tp_output",
)

View File

@@ -0,0 +1,187 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without aclgraph.
Run `pytest tests/multicard/test_external_launcher.py`.
"""
import os
import subprocess
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
import torch_npu
MODELS = ["Qwen/Qwen3-0.6B"]
MOE_MODELS = ["Qwen/Qwen3-30B-A3B"]
DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
@pytest.mark.parametrize("model", MODELS)
def test_external_launcher(model):
script = Path(
__file__
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
env = os.environ.copy()
# TODO: Change to 2 when ci machine has 4 cards
cmd = [
sys.executable,
str(script),
"--model",
model,
"--tp-size",
"1",
"--node-size",
"1",
"--node-rank",
"0",
"--proc-per-node",
"2",
"--trust-remote-code",
]
print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(
cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600,
)
output = proc.stdout.decode()
print(output)
assert "TP RANKS: [0]" in output
assert "TP RANKS: [1]" in output
assert "Generated text:" in output
assert proc.returncode == 0
@pytest.mark.parametrize("model", MOE_MODELS)
def test_moe_external_launcher(model):
script = Path(
__file__
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
env = os.environ.copy()
# TODO: Change to 2 when ci machine has 4 cards
cmd = [
sys.executable,
str(script), "--model", model, "--tp-size", "2", "--node-size", "1",
"--node-rank", "0", "--proc-per-node", "2", "--trust-remote-code",
"--enable-expert-parallel"
]
print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(
cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600,
)
output = proc.stdout.decode()
print(output)
assert "TP RANKS: [0, 1]" in output
assert "Generated text:" in output
assert proc.returncode == 0
def test_external_launcher_and_sleepmode():
script = Path(
__file__
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
env = os.environ.copy()
# TODO: Change to 2 when ci machine has 4 cards
cmd = [
sys.executable,
str(script),
"--model",
"Qwen/Qwen3-8B",
"--tp-size",
"1",
"--node-size",
"1",
"--node-rank",
"0",
"--proc-per-node",
"2",
"--trust-remote-code",
"--enable-sleep-mode",
"--temperature",
"0",
"--model-weight-gib",
"16",
]
print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(
cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=300,
)
output = proc.stdout.decode()
print(output)
assert "TP RANKS: [0]" in output
assert "TP RANKS: [1]" in output
assert "Generated text:" in output
assert "Sleep and wake up successfully!!" in output
assert proc.returncode == 0
@pytest.mark.skipif(
DEVICE_NAME != "Ascend910B",
reason="This test is only for Ascend910B devices.",
)
@pytest.mark.parametrize("model", MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": "1"})
def test_mm_allreduce(model):
script = Path(
__file__
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
env = os.environ.copy()
cmd = [
sys.executable,
str(script),
"--model",
model,
"--trust-remote-code",
]
print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(
cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600,
)
output = proc.stdout.decode()
print(output)
assert "Generated text:" in output
assert proc.returncode == 0

View File

@@ -0,0 +1,86 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Execute the inference of fused_moe_allgather_ep and fused_moe_alltoall_ep.
Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'.
"""
import os
from unittest.mock import patch
import pytest
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
@pytest.mark.skipif(
True,
reason=
"Current disaggregated pd implementation may cause memory pulse, which will cause this test OOM, skip this test until the ringmla is ready "
)
@patch.dict(
os.environ, {
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1",
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"
})
def test_generate_with_allgather():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
tensor_parallel_size=2,
max_model_len=1024,
dtype="auto",
enable_expert_parallel=True,
additional_config={
"ascend_scheduler_config": {
"enabled": True,
"chunked_prefill_enabled": False,
},
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
@pytest.mark.skipif(
True,
reason=
"Current disaggregated pd implementation may cause memory pulse, which will cause this test OOM, skip this test until the ringmla is ready "
)
@patch.dict(os.environ, {
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1"
})
def test_generate_with_alltoall():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
tensor_parallel_size=2,
max_model_len=1024,
dtype="auto",
enable_expert_parallel=True,
additional_config={
"ascend_scheduler_config": {
"enabled": True,
"chunked_prefill_enabled": False,
},
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)

View File

@@ -0,0 +1,23 @@
import pytest
from modelscope import snapshot_download # type: ignore
from tests.e2e.conftest import VllmRunner
from tests.e2e.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT,
MODEL_PATH, do_sample)
@pytest.mark.parametrize("distributed_executor_backend", ["mp"])
def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files):
with VllmRunner(snapshot_download(MODEL_PATH),
enable_lora=True,
max_loras=4,
dtype="half",
max_model_len=1024,
max_num_seqs=16,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
output = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output[i] == EXPECTED_LORA_OUTPUT[i]

View File

@@ -0,0 +1,152 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/test_offline_inference.py`.
"""
import os
from unittest.mock import patch
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
def test_models_distributed_QwQ():
example_prompts = [
"Hello, my name is",
]
dtype = "half"
max_tokens = 5
with VllmRunner(
"Qwen/QwQ-32B",
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend="mp",
enforce_eager=True,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_DeepSeek_multistream_moe():
example_prompts = [
"Hello, my name is",
]
dtype = "half"
max_tokens = 5
with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend="mp",
additional_config={
"torchair_graph_config": {
"enabled": True,
"enable_multistream_moe": True,
},
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
},
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_W8A8():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download("vllm-ascend/Qwen3-8B-W8A8"),
max_model_len=8192,
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_W4A8DYNAMIC():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download("vllm-ascend/Qwen3-8B-W4A8"),
max_model_len=8192,
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
@patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
def test_models_distributed_DeepSeek_W4A8DYNAMIC():
prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download("vllm-ascend/DeepSeek-V3-W4A8-Pruing"),
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
enforce_eager=True,
enable_expert_parallel=True,
additional_config={
"torchair_graph_config": {
"enabled": False,
},
"ascend_scheduler_config": {
"enabled": True,
}
},
) as vllm_model:
vllm_model.generate_greedy(prompts, max_tokens)
def test_sp_for_qwen3_moe() -> None:
example_prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(max_tokens=5,
temperature=0.0,
top_k=50,
top_p=0.9)
with VllmRunner(snapshot_download("Qwen/Qwen3-30B-A3B"),
dtype="auto",
tensor_parallel_size=2,
distributed_executor_backend="mp",
compilation_config={
"pass_config": {
"enable_sequence_parallelism": True
}
},
enable_expert_parallel=True,
enforce_eager=True) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import pytest
from tests.e2e.conftest import VllmRunner
MODELS = [
"Qwen/Qwen3-0.6B",
]
TENSOR_PARALLELS = [1]
PIPELINE_PARALLELS = [2]
DIST_EXECUTOR_BACKEND = ["mp", "ray"]
prompts = [
"Hello, my name is",
"The future of AI is",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
def test_models(model: str, tp_size: int, pp_size: int,
distributed_executor_backend: str) -> None:
with VllmRunner(model,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
distributed_executor_backend=distributed_executor_backend,
gpu_memory_utilization=0.7) as vllm_model:
vllm_model.generate_greedy(prompts, 64)

View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compare the with and without prefix caching on V1 scheduler or AscendScheduler."""
import pytest
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODELS = [
# for MHA
"Qwen/Qwen3-8B-Base",
# for MLA
"deepseek-ai/DeepSeek-V2-Lite-Chat"
]
# A prompt containing a large markdown table. The table is randomly generated by GPT-4.
LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """
| ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON |
| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK |
| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW |
| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ |
| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE |
| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY |
| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC |
| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK |
| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC|
| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ |
| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE |
| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA |
| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB |
| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK |
| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD |
| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ |
| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE |
| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA |
| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON |
| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK |
| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA |
| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ|
| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE |
| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO |
| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC |
| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK |
| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA |
| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
"""
INPUT_PROMPTS = [
LONG_PROMPT +
"Question: what is the age of John Doe? Your answer: The age of John Doe is ",
LONG_PROMPT +
"Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is "
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [50])
def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None:
with VllmRunner(model,
enforce_eager=True,
max_model_len=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.7) as vllm_model:
prefix_cache_output = vllm_model.generate_greedy(
INPUT_PROMPTS, max_tokens)
with VllmRunner(model,
enable_prefix_caching=False,
enforce_eager=True,
max_model_len=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.7) as vllm_model:
vllm_output = vllm_model.generate_greedy(INPUT_PROMPTS, max_tokens)
check_outputs_equal(
outputs_0_lst=vllm_output,
outputs_1_lst=prefix_cache_output,
name_0="vllm_output",
name_1="prefix_cache_output",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [50])
def test_prefix_cache_with_ascend_scheduler(model: str,
max_tokens: int) -> None:
with VllmRunner(model,
additional_config={
'ascend_scheduler_config': {
'enabled': True,
},
},
enforce_eager=True,
max_model_len=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.7) as vllm_model:
vllm_output = vllm_model.generate_greedy(INPUT_PROMPTS, max_tokens)
with VllmRunner(model,
additional_config={
'ascend_scheduler_config': {
'enabled': True,
'enable_prefix_caching': True,
},
},
enforce_eager=True,
max_model_len=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.7) as vllm_model:
prefix_cache_output = vllm_model.generate_greedy(
INPUT_PROMPTS, max_tokens)
with VllmRunner(model,
additional_config={
'ascend_scheduler_config': {
'enabled': True,
'enable_prefix_caching': True,
"enable_chunked_prefill": True,
},
},
enforce_eager=True,
max_model_len=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.7) as vllm_model:
chunk_prefill_prefix_cache_output = vllm_model.generate_greedy(
INPUT_PROMPTS, max_tokens)
check_outputs_equal(
outputs_0_lst=vllm_output,
outputs_1_lst=prefix_cache_output,
name_0="vllm_output",
name_1="prefix_cache_output",
)
check_outputs_equal(
outputs_0_lst=chunk_prefill_prefix_cache_output,
outputs_1_lst=prefix_cache_output,
name_0="chunk_prefill_prefix_cache_output",
name_1="prefix_cache_output",
)

View File

@@ -0,0 +1,104 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
"""
import os
from modelscope import snapshot_download # type: ignore
from tests.e2e.conftest import VllmRunner
def test_models_distributed_Qwen3_MOE_TP2():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
"Qwen/Qwen3-30B-A3B",
tensor_parallel_size=2,
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_MOE_TP2_WITH_EP():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
"Qwen/Qwen3-30B-A3B",
tensor_parallel_size=2,
enable_expert_parallel=True,
distributed_executor_backend="mp",
enforce_eager=False,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_MOE_W8A8():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download("vllm-ascend/Qwen3-30B-A3B-W8A8"),
max_model_len=8192,
tensor_parallel_size=2,
quantization="ascend",
enforce_eager=True,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH_AIV():
os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
example_prompts = [
"Hello, my name is",
]
dtype = "auto"
max_tokens = 5
with VllmRunner(
"Qwen/Qwen3-30B-A3B",
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=False,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
example_prompts = [
"Hello, my name is",
]
dtype = "auto"
max_tokens = 5
with VllmRunner(
"Qwen/Qwen3-30B-A3B",
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=False,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -0,0 +1,224 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/multicard/test_torchair_graph_mode.py`.
"""
import os
from typing import Dict
from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
def _deepseek_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=2,
use_v1_schduler=False,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
kwargs = {}
if not use_v1_schduler:
kwargs = {
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
}
additional_config.update(**kwargs)
with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
additional_config=additional_config,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
# NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
# DeepSeek-V3 with 2 hidden layers, thus the golden results seems
# inaccurate. This will only change if accuracy improves with the
# official weights of DeepSeek-V3.
golden_results = [
'Hello, my name is下载早点向前很有่อง',
'The president of the United States isSender)## physiological Albany',
'The capital of France is Rocky转角 hospitalizedinterval sparked',
'The future of AI is её asegο BIOS一扫',
]
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")
def test_e2e_deepseekv3_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_torchair_test_fixture(additional_config)
def test_e2e_deepseekv3_with_torchair_ms_mla():
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_multistream_mla": True,
},
}
_deepseek_torchair_test_fixture(additional_config)
def test_e2e_deepseekv3_with_torchair_v1scheduler():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_torchair_test_fixture(additional_config, use_v1_schduler=True)
def _pangu_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=2,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# torchair is only work without chunked-prefill now
kwargs = {
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
}
additional_config.update(**kwargs)
with VllmRunner(
"vllm-ascend/pangu-pro-moe-pruing",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
additional_config=additional_config,
enable_expert_parallel=True,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
# with 2 hidden layers, thus the golden results seems inaccurate.
# This will only change if accuracy changes with the official weights
# of PanguProMoE.
golden_results = [
'Hello, my name is Remempondeprecatedmiot忱',
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
'The capital of France is Rememvoud administrativ Remem投',
'The future of AI isotope Segnali Zoeken精细化 supus',
]
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")
def test_e2e_pangu_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_pangu_torchair_test_fixture(additional_config)
def _qwen_torchair_test_fixture(
model,
tp,
enable_expert_parallel,
):
# The current access control does not support 16 cards,
# so the MC2 operator in Qwen's graph mode cannot run.
# Once 16-card support is available,
# this e2e can be switched to graph mode.
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
}
with VllmRunner(
model,
dtype="half",
tensor_parallel_size=tp,
distributed_executor_backend="mp",
enforce_eager=True,
additional_config=additional_config,
enable_expert_parallel=enable_expert_parallel,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
# with 2 hidden layers, thus the golden results seems inaccurate.
# This will only change if accuracy changes with the official weights
# of PanguProMoE.
golden_results = [
'Hello, my name is Remempondeprecatedmiot忱',
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
'The capital of France is Rememvoud administrativ Remem投',
'The future of AI isotope Segnali Zoeken精细化 supus',
]
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
print(f"Generated text: {vllm_output[i][1]!r}")
def test_e2e_qwen2_with_torchair():
_qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False)
def test_e2e_qwen3_moe_with_torchair():
_qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)

View File

@@ -0,0 +1,141 @@
#!/bin/bash
export LCCL_DETERMINISTIC=1
export HCCL_DETERMINISTIC=true
export CLOSE_MATMUL_K_SHIFT=1
export VLLM_USE_V1=1
set -xe
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B-Instruct"
)
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Gen ranktable
RANKTABLE_PATH=${GIT_ROOT}/examples/disaggregate_prefill_v1/ranktable.json
if [ -f "$RANKTABLE_PATH" ]; then
rm "$RANKTABLE_PATH"
fi
cd ${GIT_ROOT}/examples/disaggregate_prefill_v1
LOCAL_HOST=`hostname -I|awk -F " " '{print$1}'`
bash gen_ranktable.sh --ips $LOCAL_HOST --network-card-name enp189s0f0 --prefill-device-cnt 1 --decode-device-cnt 1
cd -
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="$RANKTABLE_PATH"
# Waits for vLLM to start.
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/health > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Function to clean up previous instances
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
sleep 2
}
# Handle to get model-specific arguments for deepseek
get_model_args() {
local model_name=$1
local extra_args=""
if [[ "$model_name" == *"deepseek"* ]]; then
extra_args="--trust-remote-code"
fi
echo "$extra_args"
}
# Function to run tests for a specific model
run_tests_for_model() {
local model_name=$1
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Get model-specific arguments
local model_args=$(get_model_args "$model_name")
# Start prefill instance
PREFILL_PORT=8001
BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=0 VLLM_LLMDD_RPC_PORT=5559 vllm serve $model_name \
--port $PREFILL_PORT \
--seed 1024 \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.8 \
--kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Start decode instance
DECODE_PORT=8002
# Build the command with or without model-specific args
BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=1 VLLM_LLMDD_RPC_PORT=6000 vllm serve $model_name \
--port $DECODE_PORT \
--seed 1024 \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.8 \
--kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Wait for all instances to start
echo "Waiting for prefill instance on port $PORT to start..."
wait_for_server $PREFILL_PORT
echo "Waiting for decode instance on port $PORT to start..."
wait_for_server $DECODE_PORT
# Build the command for the proxy server with all the hosts and ports
PROXY_PORT=8192
PROXY_CMD="python ${GIT_ROOT}/examples/disaggregate_prefill_v1/toy_proxy_server.py --port $PROXY_PORT"
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
# Wait for the proxy to start
sleep 5
# Run lm eval for this model
echo "Running tests for $model_name"
PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/e2e/pd_disaggreate/test_edge_cases.py
# Clean up before running next model
cleanup_instances
sleep 3
}
# Run tests for each model
for model in "${MODELS[@]}"; do
run_tests_for_model "$model"
done
echo "All tests completed!"

View File

@@ -0,0 +1,136 @@
#!/bin/bash
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
function run_prefill_instance() {
local model_name=$1
local tp_size=$2
local prefill_port=$3
local register_port=$4
local prefill_device_ips=$5
local decode_device_ips=$6
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Start prefill instance
KV_CONFIG=$(jq -n \
--arg kv_connector "AscendSimpleConnector" \
--arg kv_buffer_device "npu" \
--arg kv_role "kv_producer" \
--argjson kv_parallel_size 8 \
--arg kv_port 11001 \
--argjson prefill_device_ips "$prefill_device_ips" \
--argjson decode_device_ips "$decode_device_ips" \
--argjson llmdatadist_comm_port 26000 \
--arg proxy_ip "0.0.0.0" \
--argjson proxy_port "$register_port" \
--argjson http_port "$prefill_port" \
'{
"kv_connector": $kv_connector,
"kv_buffer_device": $kv_buffer_device,
"kv_role": $kv_role,
"kv_parallel_size": $kv_parallel_size,
"kv_port": $kv_port,
"kv_connector_extra_config": {
"prefill_device_ips": $prefill_device_ips,
"decode_device_ips": $decode_device_ips,
"llmdatadist_comm_port": $llmdatadist_comm_port,
"proxy_ip": $proxy_ip,
"proxy_port": $proxy_port,
"http_port": $http_port
}
}')
# start prefill instance
ASCEND_RT_VISIBLE_DEVICES=0 vllm serve $model_name \
--host 0.0.0.0 \
--port $prefill_port \
--tensor-parallel-size $tp_size \
--served-model-name Deepseek \
--max-model-len 2000 \
--trust-remote-code \
--enforce-eager \
--kv-transfer-config "$KV_CONFIG"
}
function run_decode_instance() {
# Start decode instance
local model_name=$1
local tp_size=$2
local decode_port=$3
local register_port=$4
local prefill_device_ips=$5
local decode_device_ips=$6
KV_CONFIG=$(jq -n \
--arg kv_connector "AscendSimpleConnector" \
--arg kv_buffer_device "npu" \
--arg kv_role "kv_consumer" \
--argjson kv_parallel_size 8 \
--arg kv_port 21001 \
--argjson prefill_device_ips "$prefill_device_ips" \
--argjson decode_device_ips "$decode_device_ips" \
--argjson llmdatadist_comm_port 26000 \
--arg proxy_ip "0.0.0.0" \
--argjson proxy_port "$register_port" \
--argjson http_port "$decode_port" \
'{
"kv_connector": $kv_connector,
"kv_buffer_device": $kv_buffer_device,
"kv_role": $kv_role,
"kv_parallel_size": $kv_parallel_size,
"kv_port": $kv_port,
"kv_connector_extra_config": {
"prefill_device_ips": $prefill_device_ips,
"decode_device_ips": $decode_device_ips,
"llmdatadist_comm_port": $llmdatadist_comm_port,
"proxy_ip": $proxy_ip,
"proxy_port": $proxy_port,
"http_port": $http_port
}
}')
# start decode instance
ASCEND_RT_VISIBLE_DEVICES=1 vllm serve $model_name \
--host 0.0.0.0 \
--port $decode_port \
--tensor-parallel-size $tp_size \
--seed 1024 \
--served-model-name Deepseek \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--enforce-eager \
--kv-transfer-config "$KV_CONFIG"
}
function run_proxy_server() {
# Build the command for the proxy server with all the hosts and ports
register_port=$1
proxy_port=$2
PROXY_CMD="python examples/disaggregated_prefill/p2p_disaggrefated_prefill_proxy.py --http-port $proxy_port --register-port $register_port"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
}

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# This code is from: https://github.com/vllm-project/vllm/blob/main/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
import openai
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_PORT = os.getenv("PROXY_PORT", None)
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
raise ValueError(
"Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")
LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501
SHORT_PROMPT = "Red Hat is "
def test_edge_cases():
# Set the OpenAI API key and base URL
decode_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{DECODE_PORT}/v1",
)
prefill_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PREFILL_PORT}/v1",
)
proxy_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PROXY_PORT}/v1",
)
# Get the list of models
models = decode_client.models.list()
MODEL = models.data[0].id
# (1) Check that we can handle a very short prompt,
# less than the length of the block size.
completion = proxy_client.completions.create(model=MODEL,
prompt=SHORT_PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(model=MODEL,
prompt=SHORT_PROMPT,
temperature=0)
prefill_response = completion.choices[0].text
print(f"SMALL PROMPT: {proxy_response=}")
print(f"SMALL PROMPT: {prefill_response=}")
assert proxy_response == prefill_response
# (2) Check that we can handle a full prefix cache
# hit on the D worker but not on the P worker.
# (2a): prime the D worker.
completion = decode_client.completions.create(model=MODEL,
prompt=PROMPT,
temperature=0)
decode_response = completion.choices[0].text
# (2b): send via the P/D setup
completion = proxy_client.completions.create(model=MODEL,
prompt=PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
print(f"FULL CACHE HIT: {proxy_response=}")
assert proxy_response == decode_response
# (3) Check that we can handle a partial prefix cache
# hit on the D worker.
completion = proxy_client.completions.create(model=MODEL,
prompt=LONG_PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(model=MODEL,
prompt=LONG_PROMPT,
temperature=0)
prefill_response = completion.choices[0].text
print(f"PARTIAL CACHE HIT: {proxy_response=}")
assert proxy_response == prefill_response

View File

@@ -0,0 +1,109 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os
import signal
import subprocess
import time
import psutil
import requests
def kill_process_and_children(pid):
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
for child in children:
print(f"Killing child process {child.pid}")
child.kill()
print(f"Killing parent process {pid}")
parent.kill()
except psutil.NoSuchProcess:
pass
def kill_all_vllm_related():
current_pid = os.getpid()
for proc in psutil.process_iter(['pid', 'cmdline']):
try:
if proc.pid == current_pid:
continue
cmd = ' '.join(proc.info['cmdline'])
if "vllm" in cmd or "proxy" in cmd or "engine_worker" in cmd:
kill_process_and_children(proc.pid)
except Exception:
continue
PROXY_PORT = 10102
DECODE_PORT = 8002
SCRIPT_PATH = os.path.abspath("./tests/e2e/run_disagg_pd.sh")
def wait_for_port(port, timeout=30):
import socket
start = time.time()
while time.time() - start < timeout:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex(("127.0.0.1", port)) == 0:
return True
time.sleep(1)
raise TimeoutError(f"Port {port} not ready after {timeout}s")
def start_and_test_pipeline():
print("Launching bash script to run vLLM PD setup...")
proc = subprocess.Popen(["bash", SCRIPT_PATH])
try:
print("Waiting for proxy port to be available...")
wait_for_port(PROXY_PORT, 180)
wait_for_port(DECODE_PORT, 600)
# request
payload = {
"model": "Deepseek",
"prompt": "The future of AI is",
"max_tokens": 64,
"temperature": 0,
}
response = requests.post(
f"http://localhost:{PROXY_PORT}/v1/completions",
headers={"Content-Type": "application/json"},
json=payload,
timeout=10)
assert response.status_code == 200, f"HTTP failed: {response.status_code}"
result = response.json()
print("Response:", result)
assert "text" in result["choices"][0]
assert len(result["choices"][0]["text"].strip()) > 0
finally:
# clean up subprocesses
print("Cleaning up subprocess...")
proc.send_signal(signal.SIGINT)
try:
proc.wait(timeout=10)
except subprocess.TimeoutExpired:
proc.kill()
kill_all_vllm_related()
def test_disaggregated_pd_pipeline():
start_and_test_pipeline()

View File

@@ -0,0 +1,8 @@
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
Describe the basic components of a neural network and how it can be trained.
Write a short story about a robot that dreams for the first time.
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'

View File

@@ -0,0 +1,58 @@
#!/bin/bash
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
set -eo errexit
. $(dirname "$0")/common.sh
. $(dirname "$0")/pd_disaggreate/setup_pd.sh
export VLLM_USE_MODELSCOPE="True"
MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite"
# TODO: add tp case
TP_SIZE=1
# TODO: support multi-card
prefill_ip=$(/usr/local/Ascend/driver/tools/hccn_tool -i 0 -ip -g | grep "ipaddr" | awk -F: '{print $2}' | xargs)
PREFILL_DEVICE_IPS="[\"$prefill_ip\"]"
decode_ip=$(/usr/local/Ascend/driver/tools/hccn_tool -i 1 -ip -g | grep "ipaddr" | awk -F: '{print $2}' | xargs)
DECODE_DEVICE_IPS="[\"$decode_ip\"]"
_info "====> Start pd disaggregated test"
REGISTER_PORT=10101
PREOXY_PORT=10102
run_proxy_server $REGISTER_PORT $PREOXY_PORT
_info "Started pd disaggregated proxy server"
PREFILL_PROC_NAME="Prefill-instance"
PREFILL_PORT=8001
_info "Starting prefill instance"
run_prefill_instance $MODEL_NAME $TP_SIZE $PREFILL_PORT $REGISTER_PORT $PREFILL_DEVICE_IPS $DECODE_DEVICE_IPS &
_info "Waiting for prefill instance ready"
wait_url_ready $PREFILL_PROC_NAME "http://localhost:${PREFILL_PORT}/v1/completions"
DECODE_PROC_NAME="Decode-instance"
DECODE_PORT=8002
_info "Starting decode instance"
run_decode_instance $MODEL_NAME $TP_SIZE $DECODE_PORT $REGISTER_PORT $PREFILL_DEVICE_IPS $DECODE_DEVICE_IPS &
_info "Waiting for decode instance ready"
wait_url_ready $DECODE_PROC_NAME "http://localhost:${DECODE_PORT}/v1/completions"
_info "pd disaggregated system is ready for handling request"

33
tests/e2e/run_doctests.sh Executable file
View File

@@ -0,0 +1,33 @@
#!/bin/bash
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
set -eo errexit
. $(dirname "$0")/common.sh
export VLLM_USE_MODELSCOPE=true
export VLLM_LOGGING_LEVEL=ERROR
_info "====> Start Quickstart test"
. "${SCRIPT_DIR}/doctests/001-quickstart-test.sh"
_info "====> Start pip binary installation test"
. "${SCRIPT_DIR}/doctests/002-pip-binary-installation-test.sh"
_info "Doctest passed."

View File

View File

View File

@@ -0,0 +1,46 @@
import gc
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def bgmv_expand_cpu_impl(x: torch.Tensor, w: torch.Tensor,
indices: torch.Tensor, y: torch.tensor,
slice_offset: int, slice_size: int) -> torch.Tensor:
W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
y[:, slice_offset:slice_offset + slice_size] += z
return y
@torch.inference_mode()
def test_bgmv_expand():
B = 1
x = torch.randn([B, 16], dtype=torch.float)
w = torch.randn([64, 128, 16], dtype=torch.float16)
indices = torch.zeros([B], dtype=torch.int64)
y = torch.randn([B, 128 * 3], dtype=torch.float16)
x_npu = x.npu()
w_npu = w.npu()
indices_npu = indices.npu()
y_npu = y.npu()
y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0,
128)
# Compare the results.
torch.testing.assert_close(y_out_npu.cpu(),
y_out,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,45 @@
import gc
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def bgmv_shrink_cpu_impl(x: torch.Tensor, w: torch.Tensor,
indices: torch.Tensor, y: torch.tensor,
scaling: float) -> torch.Tensor:
W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
y[:, :] += z * scaling
return y
@torch.inference_mode()
def test_bgmv_shrink():
B = 1
x = torch.randn([B, 128], dtype=torch.float16)
w = torch.randn([64, 16, 128], dtype=torch.float16)
indices = torch.zeros([B], dtype=torch.int64)
y = torch.zeros([B, 16])
x_npu = x.npu()
w_npu = w.npu()
indices_npu = indices.npu()
y_npu = y.npu()
y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
# Compare the results.
torch.testing.assert_close(y_npu.cpu(),
y,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,284 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
"""Tests for the MOE layers.
Run `pytest tests/ops/test_fused_moe.py`.
"""
import gc
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch_npu
from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
TokenDispatcherWithAllGather
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
DEVICE = ["npu"]
def apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
) -> torch.Tensor:
w1 = w1.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
w2 = w2.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
return hidden_states
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weights = topk_weights.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_token_dispatcher_with_all_gather(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
device: str,
):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
expert_map = None
local_e = e
w1_local = w1
w2_local = w2
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.arange(local_e * 0,
local_e * (0 + 1),
device=device,
dtype=torch.int32)
expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
expert_map[e_ids] = torch.arange(local_e,
device=device,
dtype=torch.int32)
w1_local = w1[e_ids]
w2_local = w2[e_ids]
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
dispatcher_kwargs = {
"num_experts": e,
"top_k": topk,
"num_local_experts": local_e,
}
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
apply_router_weight_on_input = False
dispatch_output = dispatcher.token_dispatch(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
w2=w2_local,
group_list=group_list,
group_list_type=group_list_type)
combined_output = dispatcher.token_combine(hidden_states=expert_output,
bias=None)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
expert_map)
torch.testing.assert_close(combined_output,
torch_output,
atol=4e-2,
rtol=1)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("use_grouped_topk", [True, False])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("with_e_correction", [True, False])
@pytest.mark.parametrize("custom_routing", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts(
m: int,
n: int,
e: int,
topk: int,
scoring_func: str,
use_grouped_topk: bool,
renormalize: bool,
with_e_correction: bool,
custom_routing: bool,
dtype: torch.dtype,
device: str,
):
topk_group = 4 if use_grouped_topk else None
num_expert_group = e // 4 if use_grouped_topk else None
hidden_states = torch.randn(m, n, device=device, dtype=dtype)
router_logits = torch.randn(m, e, device=device, dtype=dtype)
e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
if with_e_correction else None)
custom_routing_function = None
if custom_routing:
custom_routing_function = MagicMock()
mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
mock_ids = torch.randint(0,
e, (m, topk),
device=device,
dtype=torch.int32)
custom_routing_function.return_value = (mock_weights, mock_ids)
with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
) as mock_native_grouped_topk:
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if use_grouped_topk:
mock_native_grouped_topk.assert_called_once()
else:
mock_native_grouped_topk.assert_not_called()
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
assert row_idx.shape == (m, topk)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
with pytest.raises(ValueError,
match="Unsupported scoring function: invalid"):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 8, device=device),
top_k=2,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_missing_group_params(device: str):
with pytest.raises(AssertionError):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 64, device=device),
top_k=2,
use_grouped_topk=True,
renormalize=False,
scoring_func="softmax")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,37 @@
import pytest
import torch
import torch_npu
@pytest.mark.parametrize(
'B',
[1, 16, 64, 128, 32768],
)
@pytest.mark.parametrize(
'D',
[8, 16, 32, 64, 128],
)
@pytest.mark.parametrize(
'top_k',
[1, 2, 4, 8],
)
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
(torch.float16, 1e-3, 1e-3),
(torch.bfloat16, 1e-3, 1e-3),
],
)
def test_quant_fpx_linear(B: int, D: int, top_k: int, dtype, atol, rtol):
x = torch.rand((B, D), dtype=dtype).to("npu")
# finished = torch.randint(1, size=(B,), dtype=torch.bool).to("npu")
finished = None
y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x,
finished,
k=top_k)
topk_weights = x.softmax(dim=-1)
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_ids = topk_ids.to(torch.int32)
torch.allclose(y, topk_weights, atol=atol, rtol=rtol)
torch.allclose(expert_idx, topk_ids, atol=atol, rtol=rtol)

View File

@@ -0,0 +1,175 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import gc
from types import SimpleNamespace
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import ( # isort: skip
FusedMoEConfig, FusedMoEParallelConfig)
from vllm_ascend.distributed.moe_comm_method import ( # isort: skip
AllGatherCommImpl, NativeAllGatherCommImpl)
@pytest.mark.parametrize("num_tokens", [16, 128])
@pytest.mark.parametrize("hidden_size", [64, 128])
@pytest.mark.parametrize("global_num_experts", [8, 16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("top_k_num", [2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("ep_rank", [0, 1])
@pytest.mark.parametrize("apply_a8_quantization", [False])
def test_all_gather_comm_impl(
num_tokens,
hidden_size,
global_num_experts,
num_local_experts,
top_k_num,
dtype,
ep_rank,
apply_a8_quantization,
mocker,
):
"""
Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
This test compares the outputs of the NPU-optimized AllGatherCommImpl
with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure
correctness across various configurations.
"""
if top_k_num > global_num_experts:
pytest.skip("top_k_num cannot be greater than global_num_experts")
if num_local_experts > global_num_experts:
pytest.skip(
"num_local_experts cannot be greater than global_num_experts")
device = torch.device("npu")
# mock get_tensor_model_parallel_rank to return ep_rank
mocker.patch(
"vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
return_value=ep_rank,
)
# make moe config
parallel_config = SimpleNamespace(
enable_expert_parallel=num_local_experts < global_num_experts)
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=max(2, global_num_experts // num_local_experts),
dp_size_=1,
vllm_parallel_config=parallel_config,
)
moe_config = FusedMoEConfig(
num_experts=global_num_experts,
experts_per_token=top_k_num,
hidden_dim=hidden_size,
num_local_experts=num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=dtype,
quant_config=None, # No quantization in this test
max_num_tokens=num_tokens,
)
# Instantiate implementations
native_impl = NativeAllGatherCommImpl(moe_config)
all_gather_impl = AllGatherCommImpl(moe_config)
# --- Input Data ---
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=dtype)
topk_ids = torch.randint(0,
global_num_experts, (num_tokens, top_k_num),
device=device,
dtype=torch.int32)
topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=1)
num_experts = global_num_experts
expert_map = None
if num_local_experts < global_num_experts:
# Create a map where some experts are local and some are not
expert_map = torch.full((global_num_experts, ), -1, device=device)
expert_map[ep_rank * num_local_experts:(ep_rank + 1) *
num_local_experts] = torch.arange(num_local_experts,
device=device)
num_experts = num_local_experts
# --- Run Native Implementation (Golden Reference) ---
native_hidden_states_out = hidden_states.clone()
(
native_permuted_hidden,
native_expert_tokens,
_,
_,
) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
num_experts, apply_a8_quantization)
# Simulate MLP output
native_mlp_output = torch.randn_like(native_permuted_hidden)
native_impl.unpermute(native_mlp_output, native_hidden_states_out)
# --- Run AllGather Implementation ---
all_gather_hidden_states_out = hidden_states.clone()
(
all_gather_permuted_hidden,
all_gather_expert_tokens,
_,
_,
) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
expert_map, num_experts, apply_a8_quantization)
# Use the same simulated MLP output for a fair comparison
all_gather_mlp_output = native_mlp_output.clone()
all_gather_impl.unpermute(all_gather_mlp_output,
all_gather_hidden_states_out)
# --- Assertions ---
# Define tolerance based on dtype
atol = 1e-3 if dtype == torch.float16 else 1e-2
rtol = 1e-3 if dtype == torch.float16 else 1e-2
# 1. Compare expert_tokens from pre_process
assert torch.allclose(native_expert_tokens.to(
all_gather_expert_tokens.device),
all_gather_expert_tokens,
atol=atol,
rtol=rtol), "Expert tokens do not match."
# 2. Compare permuted_hidden_states from pre_process
num_valid_tokens = native_expert_tokens.sum()
assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to(
all_gather_permuted_hidden.device),
all_gather_permuted_hidden[:num_valid_tokens],
atol=atol,
rtol=rtol), "Permuted hidden states do not match."
# 3. Compare final hidden_states from post_process
assert torch.allclose(native_hidden_states_out.to(
all_gather_hidden_states_out.device),
all_gather_hidden_states_out,
atol=atol,
rtol=rtol), "Final hidden states do not match."
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,351 @@
# Copyright 2023 The vLLM team.
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
import gc
from typing import Optional, Tuple, Union
import pytest
import torch
import torch.nn as nn
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
# Only Neox style true scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
HEAD_SIZES = [64, 64, 96, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 4096] # Arbitrary values for testing
NUM_TOKENS = [10, 21]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Set tolerance to 1 for quant ops
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
# test with leading dimension and merge seqlen and batch_size as num_tokens
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_quant_with_leading_dim(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
rope = rope.to(dtype=dtype)
num_tokens = batch_size * seq_len
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
qkv_tensor = torch.randn(num_tokens,
num_heads * head_size * 3,
dtype=dtype)
query, key, _ = qkv_tensor.split(
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
dim=-1,
)
ref_query, ref_key = rope.forward_native(positions, query, key)
query, key = torch.ops._C.rotary_embedding(
positions,
query,
key,
rope.head_size,
rope.cos_sin_cache,
rope.is_neox_style,
)
# Compare the results.
torch.testing.assert_close(query.view(ref_query.size()),
ref_query,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(key.view(ref_key.size()),
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
class ModelwithRotaryEmbedding(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
self.rope = RotaryEmbedding(
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
self.o_proj = nn.Linear(num_heads * head_size, hidden_size)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(3, dim=-1)
query, key = torch.ops._C.rotary_embedding(
positions,
q,
k,
self.rope.head_size,
self.rope.cos_sin_cache,
self.rope.is_neox_style,
)
query = query.view(q.shape)
key = key.view(k.shape)
o = self.o_proj(query)
return o
# The first graph seems will have some accuracy issue when directly run pytest on the ops folder,
# add a warmup graph replay for workaround
ACL_GRPAH_FIRST_RUN = True
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_capture_rotary_embedding_in_aclgraph(
is_neox_style: bool,
num_tokens: int,
num_heads: int,
head_size: int,
rotary_dim: int,
dtype: torch.dtype,
seed: int,
device: str,
max_position_embeddings: int = 8192,
base: int = 10000,
):
"""Test if the rotary embedding can be captured in aclgraph."""
torch.manual_seed(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
model = ModelwithRotaryEmbedding(
hidden_size=num_heads * head_size,
num_heads=num_heads,
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
# Validate if the rotary_embedding custom kernel is indeed inside the graph by
# string match
graph = str(gm.graph)
assert "_C.rotary_embedding" in graph
return gm
static_positions = torch.randint(0, max_position_embeddings,
(num_tokens, ))
static_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
compiled_model = torch.compile(model, backend=custom_op_checking_backend)
stream = torch.npu.Stream()
stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(stream):
# warmup the fx graph before capture
for i in range(3):
static_output = compiled_model(static_positions,
static_hidden_states,
offsets=None)
stream.wait_stream(torch.npu.current_stream())
aclgraph = torch.npu.NPUGraph()
with torch.npu.graph(aclgraph):
# Capture the model in aclgraph.
static_output = compiled_model(static_positions, static_hidden_states)
# Capture the model in aclgraph.
random_filled_positions = torch.randint(0,
max_position_embeddings,
(num_tokens, ),
device="npu")
random_filled_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
static_positions.copy_(random_filled_positions)
static_hidden_states.copy_(random_filled_hidden_states)
aclgraph.replay()
global ACL_GRPAH_FIRST_RUN
if ACL_GRPAH_FIRST_RUN:
ACL_GRPAH_FIRST_RUN = False
return
output_reference = model(static_positions, static_hidden_states)
torch.testing.assert_close(static_output,
output_reference,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,98 @@
import gc
from typing import Tuple
import pytest
import torch
import torch_npu # noqa: F401
import vllm_ascend.platform # noqa: F401
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
# Test parameters
DTYPES = [torch.int32]
#SHAPES = [(100,), (5, 20), (3, 4, 5)] # Various tensor shapes
#SHAPES = [(3, 4, 8), (3, 4, 5)] # Various tensor shapes
SHAPES = [(3, 4, 3)]
DEVICES = [f"npu:{0}"]
SEEDS = [0]
def get_masked_input_and_mask_ref(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Reference implementation for verification"""
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
masked_input = vocab_mask * (input_ - valid_offset)
return masked_input, ~vocab_mask
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
shape: Tuple[int, ...],
dtype: torch.dtype,
device: str,
seed: int,
) -> None:
# Set random seed
torch.manual_seed(seed)
torch.set_default_device(device)
# Generate random input tensor
input_tensor = torch.randint(0, 1000, shape, dtype=dtype)
# Test parameters
test_case = {
"org_start": 100,
"org_end": 200,
"padding": 0,
"added_start": 300,
"added_end": 400,
}
# Get reference result
ref_masked_input, ref_mask = get_masked_input_and_mask_ref(
input_tensor, test_case["org_start"], test_case["org_end"],
test_case["padding"], test_case["added_start"], test_case["added_end"])
# Get custom op result
print("input_tensor:", input_tensor)
custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
input_tensor, test_case["org_start"], test_case["org_end"],
test_case["padding"], test_case["added_start"], test_case["added_end"])
ref_masked_input = ref_masked_input.to(dtype)
print("custom_masked_input:", custom_masked_input)
print("ref_masked_input:", ref_masked_input)
print("custom_mask:", custom_mask)
print("ref_mask:", ref_mask)
# Compare results
torch.testing.assert_close(
custom_masked_input,
ref_masked_input,
rtol=1e-5,
atol=1e-5,
msg=f"Masked input mismatch for case: {test_case}")
torch.testing.assert_close(custom_mask,
ref_mask,
rtol=1e-5,
atol=1e-5,
msg=f"Mask mismatch for case: {test_case}")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,76 @@
from __future__ import annotations
import os
import pytest
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
@pytest.fixture
def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
def test_mtp_correctness(
sampling_config: SamplingParams,
model_name: str,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
enforce_eager=True) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
with VllmRunner(
model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
enforce_eager=True,
max_model_len=2000,
additional_config={"ascend_scheduler_config": {
"enabled": False
}}) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
import os
import pytest
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
@pytest.fixture
def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
def test_mtp_torchair_correctness(
sampling_config: SamplingParams,
model_name: str,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
enforce_eager=False,
additional_config={
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
},
}) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
with VllmRunner(model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
enforce_eager=False,
max_model_len=2000,
additional_config={
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
}
}) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))

View File

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import random
from typing import Any
import pytest
from vllm import LLM, SamplingParams
from tests.e2e.conftest import VllmRunner
@pytest.fixture
def test_prompts():
prompt_types = ["repeat", "sentence"]
num_prompts = 10
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
"""
elif kind == "sentence":
prompt = f"""
please give a ten-word sentence that
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
return prompts
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
@pytest.fixture
def model_name():
return "LLM-Research/Meta-Llama-3.1-8B-Instruct"
def eagle_model_name():
return "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"
def eagle3_model_name():
return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"
def test_ngram_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
pytest.skip("Not current support for the test.")
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
with VllmRunner(model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
enforce_eager=True) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
@pytest.mark.skipif(True, reason="oom in CI, fix me")
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
'''
if not use_eagle3:
pytest.skip("Not current support for the test.")
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(
model_name,
trust_remote_code=True,
enable_chunked_prefill=True,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
},
max_model_len=128,
enforce_eager=True,
) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))

View File

@@ -0,0 +1,75 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without aclgraph.
Run `pytest tests/compile/test_aclgraph.py`.
"""
import pytest
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODELS = [
"Qwen/Qwen3-0.6B",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models_with_aclgraph(
model: str,
max_tokens: int,
) -> None:
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=False,
) as runner:
vllm_aclgraph_outputs = runner.model.generate(prompts, sampling_params)
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=True,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
vllm_aclgraph_outputs_list = []
for output in vllm_aclgraph_outputs:
vllm_aclgraph_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_aclgraph_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_aclgraph_outputs",
)

View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B"
def test_concurrent_partial_prefill():
with VllmRunner(MODEL,
additional_config={
'ascend_scheduler_config': {
'enabled': True,
},
},
max_num_seqs=3,
max_num_batched_tokens=2048,
enforce_eager=True,
max_model_len=2048,
gpu_memory_utilization=0.7) as vllm_model:
outputs = vllm_model.model.generate(["Hello my name is Robert and I"] *
3)
assert len(outputs) == 3
for output in outputs:
assert len(output.outputs) == 1
def test_prefix_cache_stats_is_recorded():
with VllmRunner(MODEL,
additional_config={
'ascend_scheduler_config': {
'enabled': True,
},
},
max_num_seqs=3,
max_num_batched_tokens=2048,
enforce_eager=True,
max_model_len=2048,
gpu_memory_utilization=0.7) as vllm_model:
# 17 tokens will make sure first 16 tokens are cached in a block
input_tokens = {"prompt_token_ids": [101] * 129}
_ = vllm_model.model.generate([input_tokens])
outputs = vllm_model.model.generate([input_tokens])
assert outputs[0].num_cached_tokens == 128
@pytest.mark.parametrize("max_tokens",
[4]) # cannot align results when max_tokens > 4
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_chunked_prefill_with_ascend_scheduler(
max_tokens: int, chunked_prefill_token_size: int) -> None:
example_prompts = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
]
max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
with VllmRunner(MODEL,
additional_config={
'ascend_scheduler_config': {
'enabled': True,
'enable_chunked_prefill': True,
},
},
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=2048,
gpu_memory_utilization=0.7) as vllm_model:
chunked_prefill_output = vllm_model.generate_greedy(
example_prompts, max_tokens)
with VllmRunner(MODEL,
additional_config={
'ascend_scheduler_config': {
'enabled': True,
},
},
max_model_len=2048,
gpu_memory_utilization=0.7) as vllm_model:
vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=vllm_output,
outputs_1_lst=chunked_prefill_output,
name_0="vllm_output",
name_1="chunked_prefill_output",
)

View File

@@ -0,0 +1,96 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import torch
from vllm import SamplingParams
from vllm.utils import GiB_bytes
from tests.e2e.conftest import VllmRunner
from tests.e2e.utils import fork_new_process_for_each_test
from vllm_ascend.device_allocator.camem import CaMemAllocator
@fork_new_process_for_each_test
def test_basic_camem():
# some tensors from default memory pool
shape = (1024, 1024)
x = torch.empty(shape, device='npu:0')
x.zero_()
# some tensors from custom memory pool
allocator = CaMemAllocator.get_instance()
with allocator.use_memory_pool():
# custom memory pool
y = torch.empty(shape, device='npu:0')
y.zero_()
y += 1
z = torch.empty(shape, device='npu:0')
z.zero_()
z += 2
# they can be used together
output = x + y + z
assert torch.allclose(output, torch.ones_like(output) * 3)
free_bytes = torch.npu.mem_get_info()[0]
allocator.sleep()
free_bytes_after_sleep = torch.npu.mem_get_info()[0]
assert free_bytes_after_sleep > free_bytes
allocator.wake_up()
# they can be used together
output = x + y + z
assert torch.allclose(output, torch.ones_like(output) * 3)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@fork_new_process_for_each_test
def test_end_to_end():
free, total = torch.npu.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
with VllmRunner("Qwen/Qwen3-0.6B",
enforce_eager=True,
enable_sleep_mode=True) as runner:
output = runner.model.generate(prompt, sampling_params)
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
# which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
runner.model.sleep(level=1)
free_gpu_bytes_after_sleep, total = torch.npu.mem_get_info()
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage should be less than the model weights
# (0.5B model, 1GiB weights)
assert used_bytes < 1 * GiB_bytes
runner.model.wake_up()
output2 = runner.model.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text

View File

@@ -0,0 +1,81 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without aclgraph.
Run `pytest tests/compile/test_aclgraph.py`.
"""
import gc
import pytest
import torch
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [1])
def test_models(
model: str,
max_tokens: int,
) -> None:
prompts = ["The president of the United States is"]
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=0.0,
)
with VllmRunner(model, long_prefill_token_threshold=20,
enforce_eager=True) as vllm_model:
output1 = vllm_model.generate(prompts, sampling_params)
with VllmRunner(model,
enforce_eager=True,
additional_config={
'ascend_scheduler_config': {
'enabled': True
},
}) as vllm_model:
output2 = vllm_model.generate(prompts, sampling_params)
# Extract the generated token IDs for comparison
token_ids1 = output1[0][0][0]
token_ids2 = output2[0][0][0]
print(f"Token IDs 1: {token_ids1}")
print(f"Token IDs 2: {token_ids2}")
# Convert token IDs to tensors and calculate cosine similarity
# Take the length of a shorter sequence to ensure consistent dimensions
min_len = min(len(token_ids1), len(token_ids2))
tensor1 = torch.tensor(token_ids1[:min_len], dtype=torch.float32)
tensor2 = torch.tensor(token_ids2[:min_len], dtype=torch.float32)
# Calculate similarity using torch.cosine_similarity
similarity = torch.cosine_similarity(tensor1, tensor2, dim=0)
print(f"Token IDs cosine similarity: {similarity.item()}")
assert similarity > 0.95
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,49 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
from modelscope import snapshot_download # type: ignore[import-untyped]
from tests.e2e.conftest import HfRunner, VllmRunner
from tests.e2e.utils import check_embeddings_close
def test_embed_models_correctness():
queries = ['What is the capital of China?', 'Explain gravity']
model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B")
with VllmRunner(
model_name,
task="embed",
enforce_eager=True,
) as vllm_runner:
vllm_outputs = vllm_runner.encode(queries)
with HfRunner(
model_name,
dtype="float32",
is_sentence_transformer=True,
) as hf_runner:
hf_outputs = hf_runner.encode(queries)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)

View File

@@ -0,0 +1,150 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import jsonschema
import pytest
import regex as re
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
MODEL_NAME = "Qwen/Qwen3-0.6B"
GuidedDecodingBackend = ["xgrammar", "guidance", "outlines"]
@pytest.fixture(scope="module")
def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
@pytest.fixture(scope="module")
def sample_json_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}
@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
def test_guided_json_completion(guided_decoding_backend: str,
sample_json_schema):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=500,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
with VllmRunner(
MODEL_NAME,
seed=0,
guided_decoding_backend=guided_decoding_backend,
) as vllm_model:
prompts = [
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2
inputs = vllm_model.get_inputs(prompts)
outputs = vllm_model.model.generate(inputs,
sampling_params=sampling_params)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json,
schema=sample_json_schema)
@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
def test_guided_regex(guided_decoding_backend: str, sample_regex):
if guided_decoding_backend == "outlines":
pytest.skip("Outlines doesn't support regex-based guided decoding.")
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))
with VllmRunner(
MODEL_NAME,
seed=0,
guided_decoding_backend=guided_decoding_backend,
) as vllm_model:
prompts = [
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2
inputs = vllm_model.get_inputs(prompts)
outputs = vllm_model.model.generate(inputs,
sampling_params=sampling_params)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert re.fullmatch(".*", generated_text) is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
import vllm
from modelscope import snapshot_download # type: ignore
from vllm.lora.request import LoRARequest
from tests.e2e.conftest import VllmRunner
MODEL_PATH = "vllm-ascend/ilama-3.2-1B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT DISTINCT Country FROM singer WHERE Age > 20",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
query=
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
),
PROMPT_TEMPLATE.format(
query=
"What are all distinct countries where singers above age 20 are from?" # noqa: E501
),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_ilama_lora(ilama_lora_files):
with VllmRunner(snapshot_download(MODEL_PATH),
enable_lora=True,
dtype="half",
max_loras=4,
max_model_len=1024,
max_num_seqs=16,
enforce_eager=True) as vllm_model:
output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]

View File

@@ -0,0 +1,71 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import os
import time
from unittest.mock import patch
import torch
import vllm # noqa: F401
from vllm_ascend.utils import ProfileExecuteDuration
@patch.dict(os.environ, {"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "1"})
def test_execue_duration_enabled_discrepancy():
a = torch.randn(10000, 10000).npu()
b = torch.randn(10000, 10000).npu()
# warmup
torch.matmul(a, b)
torch.npu.synchronize()
cpu_start = time.perf_counter()
with ProfileExecuteDuration().capture_async("forward"):
torch.matmul(a, b)
torch.npu.synchronize()
cpu_duration = (time.perf_counter() - cpu_start) * 1000
npu_durations = ProfileExecuteDuration().pop_captured_sync()
assert npu_durations and 'forward' in npu_durations
assert not ProfileExecuteDuration._observations
# Assert discrepancy between CPU and NPU duration is within 50% roughly
diff = abs(cpu_duration - npu_durations['forward']) / max(
cpu_duration, npu_durations['forward'])
assert diff <= 0.5, (
f"CPU={cpu_duration:.2f}ms, NPU={npu_durations['forward']:.2f}ms")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
def test_execue_duration_disabled():
a = torch.randn(100, 100).npu()
b = torch.randn(100, 100).npu()
with ProfileExecuteDuration().capture_async("forward"):
torch.matmul(a, b)
torch.npu.synchronize()
npu_durations = ProfileExecuteDuration().pop_captured_sync()
assert not npu_durations
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,35 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from modelscope import snapshot_download # type: ignore[import-untyped]
from tests.e2e.conftest import VllmRunner
def test_quant_W8A8():
max_tokens = 5
example_prompts = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
]
with VllmRunner(
snapshot_download("vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"),
max_model_len=8192,
enforce_eager=True,
gpu_memory_utilization=0.7,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -0,0 +1,49 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
def test_models_topk() -> None:
example_prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(max_tokens=5,
temperature=0.0,
top_k=50,
top_p=0.9)
with VllmRunner("Qwen/Qwen3-0.6B",
max_model_len=8192,
gpu_memory_utilization=0.7) as runner:
runner.generate(example_prompts, sampling_params)
def test_models_prompt_logprobs() -> None:
example_prompts = [
"Hello, my name is",
]
with VllmRunner("Qwen/Qwen3-0.6B",
max_model_len=8192,
gpu_memory_utilization=0.7) as runner:
runner.generate_greedy_logprobs(example_prompts,
max_tokens=5,
num_logprobs=1)

View File

@@ -0,0 +1,89 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/test_offline_inference.py`.
"""
import os
import pytest
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
@pytest.mark.skip(reason="fix me")
def test_multimodal_vl(prompt_template):
image = ImageAsset("cherry_blossom") \
.pil_image.convert("RGB")
img_questions = [
"What is the content of this image?",
"Describe the content of this image in detail.",
"What's in the image?",
"Where is this image taken?",
]
images = [image] * len(img_questions)
prompts = prompt_template(img_questions)
with VllmRunner("Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=4096,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
enforce_eager=True) as vllm_model:
vllm_model.generate_greedy(prompts=prompts,
images=images,
max_tokens=64)
def test_multimodal_audio():
audio_prompt = "".join([
f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
for idx in range(2)
])
question = "What sport and what nursery rhyme are referenced?"
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
mm_data = {
"audio": [
asset.audio_and_sample_rate for asset in
[AudioAsset("mary_had_lamb"),
AudioAsset("winning_call")]
]
}
inputs = {"prompt": prompt, "multi_modal_data": mm_data}
sampling_params = SamplingParams(temperature=0.2,
max_tokens=10,
stop_token_ids=None)
with VllmRunner("Qwen/Qwen2-Audio-7B-Instruct",
max_model_len=4096,
max_num_seqs=5,
dtype="bfloat16",
limit_mm_per_prompt={"audio": 2},
gpu_memory_utilization=0.9) as runner:
runner.generate(inputs, sampling_params=sampling_params)

106
tests/e2e/utils.py Normal file
View File

@@ -0,0 +1,106 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import os
import signal
from collections.abc import Sequence
from typing import Callable
import torch
import torch.nn.functional as F
from typing_extensions import ParamSpec
_P = ParamSpec("_P")
def fork_new_process_for_each_test(
f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
pid = os.fork()
print(f"Fork a new process to run a test {pid}")
if pid == 0:
try:
f(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception:
import traceback
traceback.print_exc()
os._exit(1)
else:
os._exit(0)
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
# ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
# kill all child processes
os.killpg(pgid, signal.SIGTERM)
# restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
assert _exitcode == 0, (f"function {f} failed when called with"
f" args {args} and kwargs {kwargs}")
return wrapper
def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
tensor = torch.tensor(tensor)
tensor = tensor[..., :dimensions]
tensor = F.normalize(tensor, p=2, dim=1)
return tensor
def check_embeddings_close(
*,
embeddings_0_lst: Sequence[list[float]],
embeddings_1_lst: Sequence[list[float]],
name_0: str,
name_1: str,
tol: float = 1e-3,
) -> None:
assert len(embeddings_0_lst) == len(embeddings_1_lst)
for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
zip(embeddings_0_lst, embeddings_1_lst)):
assert len(embeddings_0) == len(embeddings_1), (
f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")
sim = F.cosine_similarity(torch.tensor(embeddings_0),
torch.tensor(embeddings_1),
dim=0)
fail_msg = (f"Test{prompt_idx}:"
f"\nCosine similarity: \t{sim:.4f}"
f"\n{name_0}:\t{embeddings_0[:16]!r}"
f"\n{name_1}:\t{embeddings_1[:16]!r}")
assert sim >= 1 - tol, fail_msg

0
tests/ut/__init__.py Normal file
View File

View File

@@ -0,0 +1,133 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
class TestAttentionMaskBuilder(TestBase):
def test_init_attention_mask_builder(self):
# generate attention_mask_builder with float16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.float16)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(1024, 1024))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
# generate attention_mask_builder with bfloat16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(2048, 2048))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(1, dtype=torch.bfloat16))
def test_get_mask_scale_factor(self):
# supported data types
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
# mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
# Otherwise raise ValueError
with self.assertRaises(ValueError):
AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
def test_get_attn_mask(self):
# if the len is less than max_seq_len, the attn_mask_cache will not be updated
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
attn_mask = attention_mask_builder.get_attn_mask(
max_seq_len=512, dtype=torch.float16, device=torch.device("cpu"))
self.assertEqual(attn_mask.shape, (512, 512))
self.assertEqual(attn_mask[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(1024, 1024))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
# if the len is greater than max_seq_len, the attn_mask_cache will be updated
attn_mask = attention_mask_builder.get_attn_mask(
max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
self.assertEqual(attn_mask.shape, (2048, 2048))
self.assertEqual(attn_mask[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(2048, 2048))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
def test_get_splitfuse_attn_mask(self):
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.float16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (6, 100))
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 3000, 2000]),
position=torch.tensor([7, 8, 9, 2999, 1999]),
dtype=torch.float16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (5, 3000))
self.assertEqual(attention_mask_builder._seq_len_cached, 3000)
# splitfuse_attn_mask now only supports data types: torch.float16 and torch.bfloat16
# otherwise raise ValueError
with self.assertRaises(ValueError):
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.int8,
device=torch.device("cpu"),
)
def test_mask_value_cleanliness(self):
attention_mask_builder = AttentionMaskBuilder(max_seq_len=6,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
torch.tensor(1, dtype=torch.bfloat16))
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([6]),
position=torch.tensor([3, 4, 5]),
dtype=torch.bfloat16,
device=torch.device("cpu"),
)
self.assertEqual(
attn_mask[-2][-1],
torch.tensor(-10000, dtype=torch.bfloat16,
device=attn_mask.device))
self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
torch.tensor(1, dtype=torch.bfloat16))

View File

@@ -0,0 +1,578 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata,
CommonAttentionState)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
class TestAscendAttentionBackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendAttentionBackend.get_name(), "ASCEND")
def test_get_impl_cls(self):
self.assertEqual(AscendAttentionBackend.get_impl_cls(),
AscendAttentionBackendImpl)
def test_get_metadata_cls(self):
self.assertEqual(AscendAttentionBackend.get_metadata_cls(),
AscendMetadata)
def test_get_state_cls(self):
self.assertEqual(AscendAttentionBackend.get_state_cls(),
CommonAttentionState)
def test_get_builder_cls(self):
self.assertEqual(AscendAttentionBackend.get_builder_cls(),
AscendAttentionMetadataBuilder)
@patch('vllm_ascend.attention.attention_v1.is_310p')
def test_get_kv_cache_shape_310p(self, mock_is_310p):
mock_is_310p.return_value = True
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
def test_get_kv_cache_shape_not_310p(self, mock_is_310p):
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
self.assertEqual(result, (2, 10, 20, 30, 40))
def test_get_bsh_kv_cache_shape(self):
result = AscendAttentionBackend.get_bsh_kv_cache_shape(10, 20, 30, 40)
self.assertEqual(result, (2, 10, 20, 30 * 40))
def test_swap_blocks(self):
src_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
dst_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
src_to_dst = torch.tensor([[0, 1], [2, 3]])
AscendAttentionBackend.swap_blocks(src_kv_cache, dst_kv_cache,
src_to_dst)
self.assertTrue(torch.all(dst_kv_cache[0][1] == src_kv_cache[0][0]))
self.assertTrue(torch.all(dst_kv_cache[1][3] == src_kv_cache[1][2]))
def test_copy_blocks(self):
kv_caches = [torch.zeros((10, 20)), torch.zeros((10, 20))]
src_to_dists = torch.tensor([[0, 1], [2, 3]])
AscendAttentionBackend.copy_blocks(kv_caches, src_to_dists)
self.assertTrue(torch.all(kv_caches[0][1] == kv_caches[0][0]))
self.assertTrue(torch.all(kv_caches[1][3] == kv_caches[1][2]))
class TestAscendAttentionMetadataBuilder(TestBase):
def setUp(self):
self.mock_vllm_config = MagicMock()
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64
self.mock_device = 'cpu:0'
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
self.mock_device)
def test_reorder_batch(self):
mock_input_batch = MagicMock()
mock_scheduler_output = MagicMock()
result = self.builder.reorder_batch(mock_input_batch,
mock_scheduler_output)
self.assertFalse(result)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@patch('vllm_ascend.utils.nd_to_nz_2d')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
mock_npu_format_cast,
mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
seq_lens_cpu=torch.tensor([5, 6]),
num_reqs=2,
num_actual_tokens=10,
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache)
mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_2d.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor
self.builder.build(common_attn_metadata, mock_model)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@patch('vllm_ascend.utils.nd_to_nz_spec')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
@patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
def test_build_chunked_prefill(self, mock_ascend_attention_state,
mock_is_310p, mock_nd_to_nz_spec,
mock_npu_format_cast, mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
mock_ascend_attention_state = MagicMock()
mock_ascend_attention_state.PrefillNoCache = 0
mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_spec.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor
self.builder.build(common_attn_metadata, mock_model)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
mock_model = MagicMock()
self.builder.build(common_attn_metadata, mock_model)
class TestAscendAttentionBackendImpl(TestBase):
def setUp(self):
self.layer = MagicMock()
self.layer.layer_name = "test_layer"
self.layer._k_scale_float = 1.0
self.layer._v_scale_float = 1.0
self.attention_type = MagicMock()
self.attention_type.DECODER = "decoder"
self.attention_type.ENCODER = "encoder"
self.attn_metadata = MagicMock()
self.attn_metadata.return_value = "1"
self.layer_no_quant = MagicMock(
spec=['layer_name', '_k_scale_float', '_v_scale_float'])
self.layer_no_quant.layer_name = "test_layer"
self.layer_no_quant._k_scale_float = 1.0
self.layer_no_quant._v_scale_float = 1.0
self.impl = AscendAttentionBackendImpl(
num_heads=8,
head_size=64,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=self.attention_type.DECODER,
kv_sharing_target_layer_name=None)
self.impl_192 = AscendAttentionBackendImpl(
num_heads=8,
head_size=192,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=self.attention_type.DECODER,
kv_sharing_target_layer_name=None)
self.impl_error = AscendAttentionBackendImpl(
num_heads=8,
head_size=192,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=None,
kv_sharing_target_layer_name=None)
self.impl_swa = AscendAttentionBackendImpl(
num_heads=8,
head_size=64,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=1024,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=self.attention_type.DECODER,
kv_sharing_target_layer_name=None)
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
def test_forward_trace_flag_true(self, mock_unified_attention):
"""Test forward pass when trace_flag is True"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 0, 0, 8, 64)
metadata = self.attn_metadata
layer = self.layer
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=True)
mock_unified_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_paged_attention_splitfuse')
def test_forward_with_quant_method(self, mock_paged_attention):
"""Test forward pass when layer has quant_method"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
kv_cache = [k_cache, v_cache]
ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
metadata = MagicMock()
metadata.num_actual_tokens = torch.randn(10, 8 * 64)
metadata.block_tables = torch.randn(10, 8 * 64)
metadata.seq_lens = torch.randn(10, 8 * 64)
metadata.attn_mask = torch.randn(10, 8 * 64)
metadata.query_lens = torch.randn(10, 8 * 64)
layer = self.layer
layer.quant_method = MagicMock()
layer.quant_method.apply.return_value = ret_value
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
layer.quant_method.apply.assert_called_once()
assert output.shape == (10, 8 * 64)
def test_forward_no_attn_metadata(self):
"""Test forward pass when attn_metadata is None"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 0, 0, 8, 64)
layer = self.layer_no_quant
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
None,
trace_flag=False)
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache(self, mock_flash_attention,
mock_reshape_cache):
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
# layer.quant_method.apply.return_value = metadata
print(self.layer_no_quant._v_scale_float)
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_reshape_cache.assert_called_once()
mock_flash_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache_swa(self, mock_flash_attention,
mock_reshape_cache):
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
# layer.quant_method.apply.return_value = metadata
print(self.layer_no_quant._v_scale_float)
output = self.impl_swa.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_reshape_cache.assert_called_once()
mock_flash_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention_qlens')
def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
mock_npu_reshape_and_cache):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_flash_attention_qlens.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
def test_forward_decode_only(self, mock_paged_attention,
mock_npu_reshape_and_cache):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
mock_npu_reshape_and_cache):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10] * 10)
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 100
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
output = self.impl_swa.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
print(output.shape)
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
def test_forward_head_size_192(self, mock_vanilla_prefill,
mock_npu_reshape_and_cache, mock_is_310p):
"""Test forward pass when head_size is 192"""
self.impl.head_size = 192
query = torch.randn(10, 8 * 192)
key = torch.randn(10, 8 * 192)
value = torch.randn(10, 8 * 192)
kv_cache = torch.empty(2, 5, 128, 8, 192)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_vanilla_prefill.return_value = MagicMock()
output = self.impl_192.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_vanilla_prefill.assert_called_once()
assert output.shape == (10, 8 * 192)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
def test_forward_normal_v1_situation(self, mock_paged_attention,
mock_npu_reshape_and_cache):
"""Test forward pass in normal V1 situation"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu.npu_format_cast')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
mock_npu_reshape_and_cache,
mock_npu_format_cast):
"""Test forward pass on 310P device"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_npu_format_cast.return_value = metadata.attn_mask
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_raise_error(self, mock_paged_attention):
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
with self.assertRaises(NotImplementedError):
self.impl_error.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)

View File

@@ -0,0 +1,631 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
AscendMLADecodeMetadata,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata)
class TestAscendMLABackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendMLABackend.get_name(), "ASCEND_MLA")
def test_get_metadata_cls(self):
self.assertEqual(AscendMLABackend.get_metadata_cls(),
AscendMLAMetadata)
def test_get_builder_cls(self):
self.assertEqual(AscendMLABackend.get_builder_cls(),
AscendMLAMetadataBuilder)
def test_get_kv_cache_shape(self):
result = AscendMLABackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))
def test_get_impl_cls(self):
result = AscendMLABackend.get_impl_cls()
self.assertEqual(result, AscendMLAImpl)
class TestAscendMLAPrefillMetadata(TestBase):
def test_ascend_mla_prefill_metadata_default(self):
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
query_lens = [1, 2]
seq_lens = [2, 2]
context_lens = torch.tensor([1, 2])
input_positions = torch.tensor([0, 1, 0, 1])
query_start_loc = torch.tensor([0, 1, 3])
block_table = torch.tensor([[0, 1], [2, 3]])
max_query_len = 2
max_seq_lens = 2
metadata = AscendMLAPrefillMetadata(attn_mask=attn_mask,
query_lens=query_lens,
seq_lens=seq_lens,
context_lens=context_lens,
input_positions=input_positions,
query_start_loc=query_start_loc,
block_table=block_table,
max_query_len=max_query_len,
max_seq_lens=max_seq_lens)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.context_lens, context_lens)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertIs(metadata.block_table, block_table)
self.assertEqual(metadata.max_query_len, max_query_len)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertIsNone(metadata.chunked_context)
def test_ascend_mla_prefill_metadata_with_chunked_context(self):
cu_seq_lens = torch.tensor([0, 2, 4])
starts = torch.tensor([0, 2])
seq_tot = [2, 2]
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
starts=starts,
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens)
metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
query_lens=[1, 2],
seq_lens=[2, 2],
context_lens=torch.tensor([1, 2]),
input_positions=torch.tensor([0, 1, 0, 1]),
query_start_loc=torch.tensor([0, 1, 3]),
block_table=torch.tensor([[0, 1], [2, 3]]),
max_query_len=2,
max_seq_lens=2,
chunked_context=chunked_context)
self.assertIsNotNone(metadata.chunked_context)
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
self.assertIs(metadata.chunked_context.starts, starts)
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
class TestAscendMLADecodeMetadata(TestBase):
def test_ascend_mla_decode_metadata_default(self):
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
seq_lens = torch.tensor([[2], [3]])
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None
metadata = AscendMLADecodeMetadata(input_positions, block_table,
seq_lens, max_seq_lens,
seq_lens_list, attn_mask)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
self.assertIs(metadata.seq_lens, seq_lens)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(attn_mask)
class TestAscendMLAMetadata(TestBase):
def test_ascend_mla_metadata_default(self):
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
seq_lens = [30, 50]
block_tables = torch.randint(0, 100, (100, 4))
num_decodes = 4
num_decode_tokens = 8
num_prefills = 8
num_input_tokens = 2
query_lens = None
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill
decode = None
prefill = None
metadata = AscendMLAMetadata(num_actual_tokens, slot_mapping,
query_start_loc, seq_lens, block_tables,
num_decodes, num_decode_tokens,
num_prefills, num_input_tokens,
query_lens, head_dim, attn_mask,
attn_state, decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.block_tables, block_tables)
self.assertEqual(metadata.num_decodes, num_decodes)
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
self.assertEqual(metadata.num_prefills, num_prefills)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.head_dim, head_dim)
self.assertEqual(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)
self.assertEqual(metadata.decode, decode)
self.assertEqual(metadata.prefill, prefill)
class TestAscendMLAMetadataBuilder(TestBase):
def test_ascend_mla_metadata_builder_default(self):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
ascend_config = MagicMock()
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
def test_reorder_batch(self):
ascend_config = MagicMock()
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
builder.decode_threshold = 1
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
scheduler_output.scheduled_spec_decode_tokens = {
0: [],
1: [1],
2: [],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)
class TestAscendMLAImpl(TestBase):
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
mock_tp):
mock_tp.world_size = 2
vllm_config = MagicMock()
speculative_config = MagicMock()
model_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
model_config.dtype = torch.float16
vllm_config.model_config = model_config
get_current_vllm_config.return_value = vllm_config
num_heads = 256
head_size = 1024
scale = 0.1
num_kv_heads = 8
kv_cache_dtype = "auto"
kv_a_layernorm = MagicMock()
kv_a_layernorm.weight = torch.randn(96)
kv_a_layernorm.variance_epsilon = 1e-6
kwargs = {
"q_lora_rank": 64,
"kv_lora_rank": 32,
"qk_nope_head_dim": 64,
"qk_rope_head_dim": 32,
"qk_head_dim": 96,
"v_head_dim": 128,
"rotary_emb": MagicMock(),
"q_proj": MagicMock(),
"kv_b_proj": MagicMock(),
"o_proj": MagicMock(),
"kv_a_proj_with_mqa": MagicMock(),
"kv_a_layernorm": kv_a_layernorm,
}
self.impl = AscendMLAImpl(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=None,
kv_sharing_target_layer_name=None,
**kwargs)
def test_init(self):
self.assertEqual(self.impl.num_heads, 256)
self.assertEqual(self.impl.head_size, 1024)
self.assertEqual(self.impl.scale, 0.1)
self.assertEqual(self.impl.num_kv_heads, 8)
self.assertEqual(self.impl.kv_cache_dtype, "auto")
self.assertEqual(self.impl.q_lora_rank, 64)
self.assertEqual(self.impl.kv_lora_rank, 32)
self.assertEqual(self.impl.qk_nope_head_dim, 64)
self.assertEqual(self.impl.qk_rope_head_dim, 32)
self.assertEqual(self.impl.qk_head_dim, 96)
self.assertEqual(self.impl.v_head_dim, 128)
self.assertIsNotNone(self.impl.rotary_emb)
self.assertIsNotNone(self.impl.q_proj)
self.assertIsNotNone(self.impl.kv_b_proj)
self.assertIsNotNone(self.impl.o_proj)
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2)
def test_v_up_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads,
self.impl.kv_lora_rank)
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
self.impl.W_UV = torch.randn(self.impl.num_heads,
self.impl.kv_lora_rank,
self.impl.v_head_dim)
result = self.impl._v_up_proj(x)
self.assertEqual(result.shape[0], batch_size)
self.assertEqual(result.shape[1],
self.impl.num_heads * self.impl.v_head_dim)
def test_q_proj_and_k_up_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
q_proj_output = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
self.impl.q_proj.return_value = (q_proj_output, )
if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
self.impl.W_UK_T = torch.randn(self.impl.num_heads,
self.impl.qk_nope_head_dim,
self.impl.kv_lora_rank)
result = self.impl._q_proj_and_k_up_proj(x)
ql_nope, q_pe = result
self.assertEqual(ql_nope.shape[0], batch_size)
self.assertEqual(ql_nope.shape[1], self.impl.num_heads)
self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank)
self.assertEqual(q_pe.shape[0], batch_size)
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
def test_process_weights_after_loading(self):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
apply = MagicMock()
quant_method.apply = apply
layer.quant_method = quant_method
shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
self.impl.v_head_dim)
shape_1 = self.impl.kv_lora_rank
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
self.impl.process_weights_after_loading(torch.bfloat16)
self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim)
self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim)
def test_compute_prefill_context_none(self):
batch_size = 4
kv_cache = torch.randn(10, 1, 1, 192)
query = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
metadata = MagicMock()
metadata.prefill = None
prefix_out = torch.randn(2, 16, 128)
prefix_lse = torch.randn(2, 16, 8)
q_pe = query[..., self.impl.qk_nope_head_dim:]
q_nope = query[..., :self.impl.qk_nope_head_dim]
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
32, metadata, prefix_out,
prefix_lse)
self.assertTrue(torch.equal(prefix_out, out))
self.assertTrue(torch.equal(prefix_lse, lse))
@patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla")
def test_compute_prefill_context(self, mock_ring, mock_load):
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
latent_kv_dim = self.impl.kv_lora_rank
num_blocks, block_size = 100, 20
query = torch.randn(S, N, D)
q_nope = query[..., :self.impl.qk_nope_head_dim]
q_pe = query[..., self.impl.qk_nope_head_dim:]
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1]
prefix_out = torch.randn(S, N, 128)
prefix_lse = torch.randn(S, N)
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]
prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx
prefill_meta.query_lens = [8]
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
meta = MagicMock()
meta.prefill = prefill_meta
self.impl.prefill_mask = torch.triu(
torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype), 1)
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
32, meta, prefix_out,
prefix_lse)
mock_load.assert_called_once()
mock_ring.assert_called_once()
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_decode_without_graph(self,
mock_npu_fused_infer_attention_score,
mock_up_proj):
num_tokens = 100
block_size = 4
q_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim)
q_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
k_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim)
k_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
metadata = MagicMock()
metadata.decode = MagicMock()
metadata.decode.block_table = MagicMock()
metadata.decode.seq_lens = 10
mock_npu_fused_infer_attention_score.return_value = [
torch.randn(num_tokens, self.impl.num_heads,
self.impl.kv_lora_rank), None
]
mock_up_proj.return_value = torch.randn(num_tokens,
self.impl.num_heads,
self.impl.v_head_dim)
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
block_size, metadata)
self.assertEqual(result.shape[0], num_tokens)
self.assertEqual(result.shape[1], self.impl.num_heads)
self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once()
mock_npu_fused_infer_attention_score.assert_called_once()
@patch("vllm_ascend.attention.mla_v1.npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch):
magic_npu_fetch.return_value = MagicMock()
batch_size = 4
seq_len = 8
hidden_size = 1024
hidden_states = torch.randn(batch_size * seq_len, hidden_size)
kv_cache = MagicMock()
attn_metadata = MagicMock()
attn_metadata.num_decodes = 2
attn_metadata.num_prefills = 2
attn_metadata.num_decode_tokens = 2
attn_metadata.num_actual_tokens = 4
num_prefill_tokens = 2
attn_metadata.slot_mapping = torch.arange(4)
attn_metadata.decode.cos = torch.randn(2, 64)
attn_metadata.decode.sin = torch.randn(2, 64)
attn_metadata.prefill.cos = torch.randn(2, 64)
attn_metadata.prefill.sin = torch.randn(2, 64)
self.impl.q_a_proj = MagicMock()
self.impl.q_a_layernorm = MagicMock()
self.impl.q_a_layernorm.return_value = torch.randn(
attn_metadata.num_actual_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa = MagicMock()
self.impl.kv_a_proj_with_mqa.return_value = [
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim + self.impl.kv_lora_rank)
]
self.impl.q_proj = MagicMock()
self.impl.q_proj.return_value = [
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.qk_head_dim)
]
self.impl.kv_b_proj = MagicMock()
self.impl.kv_b_proj.return_value = [
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.v_head_dim + self.impl.qk_nope_head_dim)
]
self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
self.impl.exec_kv_decode = MagicMock()
self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
self.impl.exec_kv_prefill = MagicMock()
self.impl.exec_kv_prefill.return_value = [
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim),
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.kv_lora_rank)
]
self.impl._q_proj_and_k_up_proj = MagicMock()
self.impl._q_proj_and_k_up_proj.return_value = [
MagicMock(), MagicMock()
]
self.impl.num_kv_heads = self.impl.num_heads
decode_res, prefill_res = self.impl._mla_preprocess(
hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
self.assertIsNotNone(decode_res)
self.assertIsNotNone(prefill_res)
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv_prefill(self, mock_kv_rmsnorm_rope_cache):
B = 2
N = self.impl.num_kv_heads
D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
kv_no_split = torch.randn(B, N, D)
self.impl.enable_kv_nz = None
self.impl.kv_a_layernorm.weight = MagicMock()
self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
cos = MagicMock()
sin = MagicMock()
slots = MagicMock()
kv_cache = [MagicMock(), MagicMock()]
mock_kv_rmsnorm_rope_cache.return_value = [
None, None,
torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
torch.randn(B, N, 1, self.impl.kv_lora_rank)
]
k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin,
kv_cache, slots)
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
B = 2
N = self.impl.num_kv_heads
D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
kv_no_split = torch.randn(B, N, D)
self.impl.enable_kv_nz = None
self.impl.kv_a_layernorm.weight = MagicMock()
self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
cos = MagicMock()
sin = MagicMock()
slots = MagicMock()
kv_cache = [MagicMock(), MagicMock()]
mock_kv_rmsnorm_rope_cache.return_value = [
torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
torch.randn(B, N, 1, self.impl.kv_lora_rank), None, None
]
k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin,
kv_cache, slots)
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
@patch("torch.npu.stream")
@patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
mock_get_multistream_comm_context,
mock_npu_stream):
B = 2
N = self.impl.num_kv_heads
BS = 100
HD = self.impl.v_head_dim
self.impl.kv_lora_rank = 256
self.impl.spec_token_num = 1
self.impl._v_up_proj = MagicMock()
self.impl._v_up_proj.return_value = torch.randn(B, N, HD)
q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
k_nope = torch.randn(BS, N, self.impl.kv_lora_rank)
k_pe = torch.randn(BS, N, self.impl.qk_rope_head_dim)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
attn_metadata.decode = MagicMock()
attn_metadata.decode.actual_seq_lengths_q = MagicMock()
attn_metadata.decode.seq_lens_list = MagicMock()
self.impl.enable_kv_nz = True
mock_npu_fused_infer_attention_score.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank), None
]
mock_get_multistream_comm_context.return_value = None
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
attn_metadata)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], HD)
self.impl.enable_kv_nz = False
attn_metadata.attn_state = None
mock_return_value = MagicMock()
mock_get_multistream_comm_context.return_value = mock_return_value
mock_return_value.before_comm_event = MagicMock()
mock_return_value.comm_stream = MagicMock()
mock_npu_stream.return_value = MagicMock()
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
attn_metadata)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], HD)

44
tests/ut/base.py Normal file
View File

@@ -0,0 +1,44 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import unittest
import pytest
from vllm_ascend.utils import adapt_patch, register_ascend_customop
class TestBase(unittest.TestCase):
def __init__(self, *args, **kwargs):
# adapt patch by default.
adapt_patch(True)
adapt_patch()
register_ascend_customop()
super().setUp()
super(TestBase, self).__init__(*args, **kwargs)
class PytestBase:
"""Base class for pytest-based tests.
because pytest mocker and parametrize usage are not compatible with unittest.
so we need to use a separate base class for pytest tests.
"""
@pytest.fixture(autouse=True)
def setup(self):
adapt_patch(True)
adapt_patch()
register_ascend_customop()

26
tests/ut/conftest.py Normal file
View File

@@ -0,0 +1,26 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from vllm_ascend.utils import adapt_patch # noqa E402
from vllm_ascend.utils import register_ascend_customop
adapt_patch()
adapt_patch(True)
# register Ascend CustomOp here because uts will use this
register_ascend_customop()

View File

@@ -0,0 +1,167 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm.config import SchedulerConfig
from tests.ut.base import TestBase
from vllm_ascend.core.schedule_config import AscendSchedulerConfig
class TestAscendSchedulerConfig(TestBase):
def setUp(self):
self.basic_scheduler_config = SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_multimodal_model=False,
send_delta_data=False,
scheduler_delay_factor=0,
)
def test_initialize_from_config_with_default(self):
# No additional config given, check the default value here.
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config, {})
self.assertEqual(ascend_config.enable_chunked_prefill, False)
self.assertEqual(ascend_config.policy, "fcfs")
self.assertEqual(ascend_config.num_scheduler_steps, 1)
self.assertEqual(ascend_config.scheduler_cls,
"vllm_ascend.core.scheduler.AscendScheduler")
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
self.assertEqual(ascend_config.encoder_cache_size, 8192)
def test_initialize_from_config_with_override(self):
# test override
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
enable_chunked_prefill=False,
policy="fcfs",
num_scheduler_steps=1,
scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler",
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertEqual(ascend_config.enable_chunked_prefill, False)
self.assertEqual(ascend_config.policy, "fcfs")
self.assertEqual(ascend_config.num_scheduler_steps, 1)
self.assertEqual(ascend_config.scheduler_cls,
"vllm_ascend.core.scheduler.AscendScheduler")
self.assertEqual(ascend_config.max_num_batched_tokens, 2048)
self.assertEqual(ascend_config.encoder_cache_size, 2048)
def test_not_implemented_policy(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
policy="custom_policy",
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertIn(
"currently AscendScheduler only supports fcfs policy",
str(context.exception),
)
def test_not_implemented_multimodal(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
SchedulerConfig(is_multimodal_model=True), {})
self.assertIn("currently AscendScheduler only supports LLM models",
str(context.exception))
def test_not_implemented_multi_step(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
num_scheduler_steps=2,
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertIn(
"currently AscendScheduler doesn't support multi-step",
str(context.exception),
)
def test_not_implemented_send_delta_data(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
send_delta_data=True,
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertIn(
"currently AscendScheduler doesn't support send_delta_data",
str(context.exception),
)
def test_not_implemented_delay_factor(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
delay_factor=1,
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertIn(
"currently AscendScheduler doesn't support scheduler_delay_factor",
str(context.exception),
)
def test_no_override(self):
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config, {})
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
self.assertEqual(ascend_config.encoder_cache_size, 8192)
def test_valid_config_with_chunked_prefill(self):
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
enable_chunked_prefill=True,
max_num_batched_tokens=2048,
max_model_len=4096,
),
)
self.assertEqual(ascend_config.max_num_batched_tokens, 2048)
self.assertEqual(ascend_config.max_model_len, 4096)
self.assertTrue(ascend_config.enable_chunked_prefill)
def test_invalid_config_without_chunked_prefill(self):
with self.assertRaises(ValueError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
enable_chunked_prefill=False,
max_num_batched_tokens=2048,
max_model_len=4096,
),
)
self.assertIn(
"Ascend scheduler is enabled without chunked prefill feature",
str(context.exception),
)
self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
self.assertIn("max_model_len (4096)", str(context.exception))

View File

@@ -0,0 +1,898 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import MagicMock, patch
import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from tests.ut.base import TestBase
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
from vllm.v1.outputs import DraftTokenIds
else:
DraftTokenIds = None
EOS_TOKEN_ID = 50256
MODEL = "Qwen3-0.6B"
ENABLE_PREFIX_CACHING = None
PROMPT_LOGPROBS = None
ENABLE_CHUNKED_PREFILL = False
MAX_NUM_BATCHED_TOKENS = 10000
LONG_PREFILL_TOKEN_THRESHOLD = 0
NUM_SPECULATIVE_TOKENS = None
MAX_NUM_SEQS = 16
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
block_size: int = 3,
hash_fn=hash,
):
init_none_hash(hash_fn)
prompt_logprobs = PROMPT_LOGPROBS
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
request = Request(request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
block_hasher=get_request_block_hasher(
block_size, hash_fn))
else:
request = Request(request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
block_hasher=get_request_block_hasher(
block_size, hash_fn))
requests.append(request)
return requests
def make_output(scheduler):
req_ids = [req.request_id for req in scheduler.running]
req_id_to_index = {
req.request_id: i
for i, req in enumerate(scheduler.running)
}
sampled_token_ids = [[1000]] * len(scheduler.running)
logprobs = None
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
modelrunner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=logprobs,
prompt_logprobs_dict={},
pooler_output=[],
)
else:
modelrunner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs,
prompt_logprobs_dict={},
pooler_output=[],
)
return modelrunner_output
class TestAscendScheduler(TestBase):
@patch("vllm.config.ModelConfig.__post_init__", MagicMock())
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
def create_scheduler(self, mock_compute_encoder_budget):
mock_compute_encoder_budget.return_value = [10, 20]
use_kv_connector = False
block_size = 16
scheduler_config = SchedulerConfig(
max_num_seqs=16,
max_model_len=MAX_NUM_BATCHED_TOKENS,
long_prefill_token_threshold=LONG_PREFILL_TOKEN_THRESHOLD,
disable_chunked_mm_input=False,
enable_chunked_prefill=ENABLE_CHUNKED_PREFILL,
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
)
scheduler_config.max_num_encoder_input_tokens = 10000
scheduler_config.encoder_cache_size = 10000
scheduler_config.chunked_prefill_enabled = False
model_config = ModelConfig(
model=MODEL,
task="auto",
tokenizer=MODEL,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
max_model_len=MAX_NUM_BATCHED_TOKENS,
)
model_config.pooler_config = MagicMock()
model_config.multimodal_config = MagicMock()
model_config.hf_config = MagicMock()
model_config.hf_config.is_encoder_decoder = False
# Cache config, optionally force APC
kwargs_cache: Dict[str,
Any] = ({} if ENABLE_PREFIX_CACHING is None else {
'enable_prefix_caching':
ENABLE_PREFIX_CACHING
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if NUM_SPECULATIVE_TOKENS is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1,
torch.float32, False))
],
)
cache_config.num_gpu_blocks = 10000
scheduler = AscendScheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=MagicMock(spec=StructuredOutputManager),
)
should_advance = MagicMock()
should_advance.return_value = False
scheduler.structured_output_manager.should_advance = should_advance
return scheduler
def test_add_requests(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
self.assertIn(request.request_id, scheduler.requests)
self.assertEqual(len(scheduler.waiting), i + 1)
def test_finish_request(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_ABORTED)
self.assertNotIn(request.request_id, scheduler.requests)
self.assertEqual(len(scheduler.waiting), 9 - i)
def test_get_num_unfinished_requests(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_STOPPED)
self.assertEqual(scheduler.get_num_unfinished_requests(),
len(requests) - i - 1)
def test_schedule(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_schedule_enable_prefix_caching(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
global ENABLE_PREFIX_CACHING
ENABLE_PREFIX_CACHING = True
global PROMPT_LOGPROBS
PROMPT_LOGPROBS = 5
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_stop_via_update_from_output(self):
"""Test stopping behavior through update_from_output"""
global NUM_SPECULATIVE_TOKENS
NUM_SPECULATIVE_TOKENS = 1
scheduler = self.create_scheduler()
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [
10, 11
]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [
10, 11
]], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID])
self.assertEqual(list(requests[1].output_token_ids), [10, 11])
# Test case 2: Stop on custom stop token
NUM_SPECULATIVE_TOKENS = 2
scheduler = self.create_scheduler()
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
self.assertEqual(requests[0].stop_reason, 42)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [10, 42])
self.assertEqual(list(requests[1].output_token_ids), [13, 14])
# Test case 3: Stop on max tokens
NUM_SPECULATIVE_TOKENS = 2
scheduler = self.create_scheduler()
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status,
RequestStatus.FINISHED_LENGTH_CAPPED)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [10, 11])
self.assertEqual(list(requests[1].output_token_ids), [13])
# Test case 4: Ignore EOS flag
scheduler = self.create_scheduler()
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
self.assertEqual(len(scheduler.running), 1)
self.assertFalse(requests[0].is_finished())
self.assertEqual(list(requests[0].output_token_ids),
[EOS_TOKEN_ID, 10, 11])
def test_schedule_concurrent_batches(self):
global MAX_NUM_BATCHED_TOKENS
global ENABLE_PREFIX_CACHING
global ENABLE_CHUNKED_PREFILL
global MAX_NUM_SEQS
global PROMPT_LOGPROBS
ENABLE_PREFIX_CACHING = None
MAX_NUM_BATCHED_TOKENS = 1024
MAX_NUM_SEQS = 2
ENABLE_CHUNKED_PREFILL = True
PROMPT_LOGPROBS = None
enable_prefix_caching_list = [None, True]
prompt_logprobs_list = [None, 5]
for i in range(len(enable_prefix_caching_list)):
ENABLE_PREFIX_CACHING = enable_prefix_caching_list[i]
PROMPT_LOGPROBS = prompt_logprobs_list[i]
scheduler = self.create_scheduler()
requests = create_requests(
num_requests=2,
num_tokens=512,
)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1)
self.assertEqual(
scheduler_output0.num_scheduled_tokens[requests[0].request_id],
512)
# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1)
self.assertEqual(
scheduler_output1.num_scheduled_tokens[requests[1].request_id],
512)
# Model output of the first request.
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output0,
model_runner_output)
# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler.schedule()
# Model output of the second request.
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output1,
model_runner_output)
def test_schedule_spec_decoding_stats(self):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
[[1, 2], [3]], [[1]], [[]],
[[1, 2, 3], [4, 5, 6]]]
output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
[[1, 2, 5], [3, 4]],
[[1, 2]], [[5]],
[[1, 2, 7], [4, 8]]]
expected_list: List[Tuple[int, int,
int, List[int]]] = [(1, 3, 3, [1, 1, 1]),
(1, 3, 1, [1, 0, 0]),
(2, 3, 3, [2, 1]),
(1, 1, 1, [1]),
(0, 0, 0, [0]),
(2, 6, 3, [2, 1, 0])]
global NUM_SPECULATIVE_TOKENS
for idx in range(len(spec_tokens_list)):
spec_tokens = spec_tokens_list[idx]
output_tokens = output_tokens_list[idx]
expected = expected_list[idx]
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
NUM_SPECULATIVE_TOKENS = num_spec_tokens
scheduler = self.create_scheduler()
requests = create_requests(num_requests=len(spec_tokens),
num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.total_num_scheduled_tokens, len(requests))
for i in range(len(requests)):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
spec_token_ids=spec_tokens,
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
self.assertEqual(running_req.num_computed_tokens, 1)
# The prompt token and the sampled token
self.assertEqual(running_req.num_tokens, 2)
# The prompt token, the sampled token, and the speculated tokens
self.assertEqual(running_req.num_tokens_with_spec,
2 + len(spec_tokens[i]))
# No draft or accepted tokens counted yet
self.assertTrue(
not engine_core_outputs
or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats
is None))
# Schedule the speculated tokens for validation
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), 0)
# The sampled token and speculated tokens
self.assertEqual(
output.total_num_scheduled_tokens,
len(requests) + sum(len(ids) for ids in spec_tokens))
for i in range(len(requests)):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id],
1 + len(spec_tokens[i]))
if spec_tokens[i]:
self.assertEqual(
len(output.scheduled_spec_decode_tokens[req_id]),
len(spec_tokens[i]))
else:
self.assertNotIn(req_id,
output.scheduled_spec_decode_tokens)
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0:
self.assertIsNone(scheduler_stats.spec_decoding_stats)
else:
self.assertIsNotNone(scheduler_stats.spec_decoding_stats)
stats = scheduler_stats.spec_decoding_stats
self.assertEqual(stats.num_drafts, expected[0])
self.assertEqual(stats.num_draft_tokens, expected[1])
self.assertEqual(stats.num_accepted_tokens, expected[2])
self.assertEqual(stats.num_accepted_tokens_per_pos,
expected[3])
def assert_scheduler_empty(self, scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
scheduler = self.create_scheduler()
self.assertEqual(len(scheduler.requests), 0)
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), 0)
self.assertEqual(len(scheduler.finished_req_ids), 0)
# EncoderCacheManager.
self.assertEqual(len(scheduler.encoder_cache_manager.freed), 0)
self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0)
# KVCache Manager.
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks), 0)
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block), 0)
num_free_blocks = (scheduler.kv_cache_manager.block_pool.
free_block_queue.num_free_blocks)
self.assertEqual(
num_free_blocks,
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
self.assertEqual(block.ref_cnt, 0)
def test_memory_leak(self):
"""Test that we do not have a memory leak."""
scheduler = self.create_scheduler()
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
self.assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,188 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.ut.base import PytestBase
from vllm_ascend.device_allocator.camem import (AllocationData, CaMemAllocator,
create_and_map,
find_loaded_library,
get_pluggable_allocator,
unmap_and_release)
def dummy_malloc(args):
pass
def dummy_free(ptr):
return (0, 0, 0, 0)
class TestCaMem(PytestBase):
def test_find_loaded_library_success_and_not_found(self):
path = find_loaded_library("libc")
assert path is not None, "Expected to find libc library"
assert path.endswith(".so.6") or ".so" in path
assert "libc" in path
path = find_loaded_library("non_existent_library")
assert path is None, "Expected to not find non-existent library"
@pytest.mark.parametrize("handle", [
(1, 2, 3),
("device", 99),
(None, ),
])
def test_create_and_map_calls_python_create_and_map(self, handle):
with patch("vllm_ascend.device_allocator.camem.python_create_and_map"
) as mock_create:
create_and_map(handle)
mock_create.assert_called_once_with(*handle)
@pytest.mark.parametrize("handle", [
(42, "bar"),
("foo", ),
])
def test_unmap_and_release_calls_python_unmap_and_release(self, handle):
with patch(
"vllm_ascend.device_allocator.camem.python_unmap_and_release"
) as mock_release:
unmap_and_release(handle)
mock_release.assert_called_once_with(*handle)
@patch("vllm_ascend.device_allocator.camem.init_module")
@patch(
"vllm_ascend.device_allocator.camem.torch.npu.memory.NPUPluggableAllocator"
)
def test_get_pluggable_allocator(self, mock_allocator_class,
mock_init_module):
mock_allocator_instance = MagicMock()
mock_allocator_class.return_value = mock_allocator_instance
def side_effect_malloc_and_free(malloc_fn, free_fn):
malloc_fn((1, 2, 3))
free_fn(123)
mock_init_module.side_effect = side_effect_malloc_and_free
allocator = get_pluggable_allocator(dummy_malloc, dummy_free)
mock_init_module.assert_called_once_with(dummy_malloc, dummy_free)
assert allocator == mock_allocator_instance
def test_singleton_behavior(self):
instance1 = CaMemAllocator.get_instance()
instance2 = CaMemAllocator.get_instance()
assert instance1 is instance2
def test_python_malloc_and_free_callback(self):
allocator = CaMemAllocator.get_instance()
# mock allocation_handle
handle = (1, 100, 1234, 0)
allocator.current_tag = "test_tag"
allocator.python_malloc_callback(handle)
# check pointer_to_data store data
ptr = handle[2]
assert ptr in allocator.pointer_to_data
data = allocator.pointer_to_data[ptr]
assert data.handle == handle
assert data.tag == "test_tag"
# check free callback with cpu_backup_tensor
data.cpu_backup_tensor = torch.zeros(1)
result_handle = allocator.python_free_callback(ptr)
assert result_handle == handle
assert ptr not in allocator.pointer_to_data
assert data.cpu_backup_tensor is None
@patch("vllm_ascend.device_allocator.camem.unmap_and_release")
@patch("vllm_ascend.device_allocator.camem.memcpy")
def test_sleep_offload_and_discard(self, mock_memcpy, mock_unmap):
allocator = CaMemAllocator.get_instance()
# prepare allocation one tag matchone not match
handle1 = (1, 10, 1000, 0)
data1 = AllocationData(handle1, "tag1")
handle2 = (2, 20, 2000, 0)
data2 = AllocationData(handle2, "tag2")
allocator.pointer_to_data = {
1000: data1,
2000: data2,
}
# mock is_pin_memory_available, return False as some machine only has cpu
with patch(
"vllm_ascend.device_allocator.camem.NPUPlatform.is_pin_memory_available",
return_value=False):
allocator.sleep(offload_tags="tag1")
# only offload tag1, other tag2 call unmap_and_release
assert data1.cpu_backup_tensor is not None
assert data2.cpu_backup_tensor is None
mock_unmap.assert_any_call(handle1)
mock_unmap.assert_any_call(handle2)
assert mock_unmap.call_count == 2
assert mock_memcpy.called
@patch("vllm_ascend.device_allocator.camem.create_and_map")
@patch("vllm_ascend.device_allocator.camem.memcpy")
def test_wake_up_loads_and_clears_cpu_backup(self, mock_memcpy,
mock_create_and_map):
allocator = CaMemAllocator.get_instance()
handle = (1, 10, 1000, 0)
tensor = torch.zeros(5, dtype=torch.uint8)
data = AllocationData(handle, "tag1", cpu_backup_tensor=tensor)
allocator.pointer_to_data = {1000: data}
allocator.wake_up(tags=["tag1"])
mock_create_and_map.assert_called_once_with(handle)
assert data.cpu_backup_tensor is None
assert mock_memcpy.called
def test_use_memory_pool_context_manager(self):
allocator = CaMemAllocator.get_instance()
old_tag = allocator.current_tag
# mock use_memory_pool_with_allocator
mock_ctx = MagicMock()
mock_ctx.__enter__.return_value = "data"
mock_ctx.__exit__.return_value = None
with patch(
"vllm_ascend.device_allocator.camem.use_memory_pool_with_allocator",
return_value=mock_ctx):
with allocator.use_memory_pool(tag="my_tag"):
assert allocator.current_tag == "my_tag"
# restore old tag after context manager exits
assert allocator.current_tag == old_tag
def test_get_current_usage(self):
allocator = CaMemAllocator.get_instance()
allocator.pointer_to_data = {
1: AllocationData((0, 100, 1, 0), "tag"),
2: AllocationData((0, 200, 2, 0), "tag"),
}
usage = allocator.get_current_usage()
assert usage == 300

View File

@@ -0,0 +1,84 @@
import os
from unittest.mock import MagicMock, patch
from vllm.distributed.utils import StatelessProcessGroup
from tests.ut.base import TestBase
from vllm_ascend.distributed.device_communicators.pyhccl import \
PyHcclCommunicator
class MockHcclLib:
pass
class MockUniqueId:
pass
class TestPyHcclCommunicator(TestBase):
@patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"})
def test_world_size_1_return_early(self):
comm = PyHcclCommunicator(
group=StatelessProcessGroup(0, 1, None, None),
device="npu:0",
)
self.assertTrue(comm.disabled)
self.assertFalse(comm.available)
@patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "2"})
def test_load_hccl_fail(self):
comm = PyHcclCommunicator(group=StatelessProcessGroup(
0, 2, None, None),
device="npu:0",
library_path="/not/exist/path/libhccl.so")
self.assertTrue(comm.disabled)
@patch(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
MockHcclLib)
@patch(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
MockUniqueId)
@patch("torch.npu.device")
@patch("vllm_ascend.utils.current_stream",
return_value=MagicMock(npu_stream=5678))
def test_stateless_group(self, *_):
group = StatelessProcessGroup(rank=3,
world_size=4,
store=None,
socket=None)
comm = PyHcclCommunicator(group=group, device=3)
self.assertEqual(comm.rank, 3)
self.assertEqual(comm.world_size, 4)
@patch.dict(os.environ, {"RANK": "1", "WORLD_SIZE": "2"})
@patch(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
MockHcclLib)
@patch(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
MockUniqueId)
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="nccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.distributed.broadcast")
@patch("torch.npu.device")
@patch("vllm_ascend.utils.current_stream",
return_value=MagicMock(npu_stream=1234))
def test_multi_gpu_pg_torch(
self,
*_,
):
fake_pg = MagicMock()
comm = PyHcclCommunicator(group=fake_pg, device="npu:1")
self.assertEqual(comm.rank, 1)
self.assertEqual(comm.world_size, 2)
self.assertFalse(comm.available)
self.assertTrue(comm.disabled)

View File

@@ -0,0 +1,173 @@
from unittest.mock import MagicMock, patch
import torch
from torch.distributed import ReduceOp
from tests.ut.base import TestBase
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
Function, HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t,
hcclDataType_t, hcclDataTypeEnum, hcclRedOp_t, hcclRedOpTypeEnum,
hcclResult_t, hcclUniqueId)
class TestHcclUniqueId(TestBase):
def test_construct(self):
uid = hcclUniqueId()
uid.internal[0] = 12
self.assertEqual(len(uid.internal), 4108)
self.assertEqual(uid.internal[0], 12)
class TestHcclDataTypeEnum(TestBase):
def test_torch_dtype_mapping(self):
expected = {
torch.int8: hcclDataTypeEnum.hcclInt8,
torch.uint8: hcclDataTypeEnum.hcclUint8,
torch.int32: hcclDataTypeEnum.hcclInt32,
torch.int64: hcclDataTypeEnum.hcclInt64,
torch.float16: hcclDataTypeEnum.hcclFloat16,
torch.float32: hcclDataTypeEnum.hcclFloat32,
torch.float64: hcclDataTypeEnum.hcclFloat64,
torch.bfloat16: hcclDataTypeEnum.hcclBfloat16,
}
for torch_dtype, expected_enum in expected.items():
with self.subTest(torch_dtype=torch_dtype):
self.assertEqual(hcclDataTypeEnum.from_torch(torch_dtype),
expected_enum)
def test_unsupported_dtype_raises(self):
with self.assertRaises(ValueError):
hcclDataTypeEnum.from_torch(torch.complex64)
class TestHcclRedOpTypeEnum(TestBase):
def test_torch_reduce_op_mapping(self):
expected = {
ReduceOp.SUM: hcclRedOpTypeEnum.hcclSum,
ReduceOp.PRODUCT: hcclRedOpTypeEnum.hcclProd,
ReduceOp.MAX: hcclRedOpTypeEnum.hcclMax,
ReduceOp.MIN: hcclRedOpTypeEnum.hcclMin,
}
for torch_op, expected_enum in expected.items():
with self.subTest(torch_op=torch_op):
self.assertEqual(hcclRedOpTypeEnum.from_torch(torch_op),
expected_enum)
def test_unsupported_op_raises(self):
unsupported_op = "NOT_EXIST"
with self.assertRaises(ValueError):
hcclRedOpTypeEnum.from_torch(unsupported_op)
class TestFunction(TestBase):
def test_construct_with_valid_args(self):
func = Function(name="foo", restype=int, argtypes=[int, str, float])
self.assertEqual(func.name, "foo")
self.assertIs(func.restype, int)
self.assertEqual(func.argtypes, [int, str, float])
class TestHCLLLibrary(TestBase):
def test_init_with_nonexistent_so(self):
fake_path = "/definitely/not/exist/libhccl.so"
with self.assertRaises(OSError):
HCCLLibrary(fake_path)
def test_hccl_get_error_string(self):
lib = MagicMock(sepc=HCCLLibrary)
mock_fn = MagicMock()
mock_fn.return_value = "HCCL internal error"
lib.hcclGetErrorString = mock_fn
result = hcclResult_t(1)
msg = lib.hcclGetErrorString(result)
self.assertEqual(msg, "HCCL internal error")
mock_fn.assert_called_once()
def test_hccl_check(self):
lib = HCCLLibrary.__new__(HCCLLibrary)
mock_fn = MagicMock()
mock_fn.return_value = "fake error"
lib.hcclGetErrorString = mock_fn
result = hcclResult_t(123)
with self.assertRaises(RuntimeError) as cm:
lib.HCCL_CHECK(result)
self.assertEqual(str(cm.exception), "HCCL error: fake error")
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hccl_get_uniqueId(self, mock_HCCL_CHECK):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclGetRootInfo": MagicMock(return_value=0)}
unique_id = lib.hcclGetUniqueId()
self.assertIsInstance(unique_id, hcclUniqueId)
lib._funcs["HcclGetRootInfo"].assert_called_once()
mock_HCCL_CHECK.assert_called_once_with(0)
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hccl_comm_initRank(self, mock_hccl_check):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclCommInitRootInfo": MagicMock(return_value=0)}
world_size = 4
unique_id = hcclUniqueId()
rank = 1
comm = lib.hcclCommInitRank(world_size, unique_id, rank)
self.assertIsInstance(comm, hcclComm_t)
lib._funcs["HcclCommInitRootInfo"].assert_called_once()
mock_hccl_check.assert_called_once_with(0)
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hccl_all_reduce(self, mock_hccl_check):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclAllReduce": MagicMock(return_value=0)}
sendbuff = buffer_type()
recvbuff = buffer_type()
count = 10
datatype = hcclDataType_t(1)
op = hcclRedOp_t(0)
comm = hcclComm_t()
stream = aclrtStream_t()
lib.hcclAllReduce(sendbuff, recvbuff, count, datatype, op, comm,
stream)
lib._funcs["HcclAllReduce"].assert_called_once_with(
sendbuff, recvbuff, count, datatype, op, comm, stream)
mock_hccl_check.assert_called_once_with(0)
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hccl_broad_cast(self, mock_hccl_check):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclBroadcast": MagicMock(return_value=0)}
buff = buffer_type()
count = 10
datatype = 1
root = 0
comm = hcclComm_t()
stream = aclrtStream_t()
lib.hcclBroadcast(buff, count, datatype, root, comm, stream)
lib._funcs["HcclBroadcast"].assert_called_once_with(
buff, count, datatype, root, comm, stream)
mock_hccl_check.assert_called_once_with(0)
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hcclCommDestroy_success(self, mock_hccl_check):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclCommDestroy": MagicMock(return_value=0)}
comm = hcclComm_t()
lib.hcclCommDestroy(comm)
lib._funcs["HcclCommDestroy"].assert_called_once_with(comm)
mock_hccl_check.assert_called_once_with(0)

View File

@@ -0,0 +1,89 @@
import unittest
from unittest.mock import MagicMock, patch
import torch
import torch.distributed as dist
from vllm_ascend.distributed.communicator import NPUCommunicator
class TestNPUCommunicator(unittest.TestCase):
@patch("vllm.config.get_current_vllm_config", return_value=None)
@patch("torch.npu.current_device", return_value=MagicMock())
@patch("torch.npu.set_device", return_value=MagicMock())
@patch("torch.distributed.get_process_group_ranks",
return_value={
0: 0,
1: 1
})
@patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="hccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.npu.device")
def test_all_to_all_with_sizes(self, *_):
def patched_all_to_all(output_tensor_list,
input_tensor_list,
group=None,
async_op=False):
output_tensor_list[:] = ([
torch.tensor([10, 20]),
torch.tensor([50, 60])
])
torch.distributed.all_to_all = patched_all_to_all
scatter_sizes = [2, 2]
gather_sizes = [2, 2]
input_ = torch.tensor([10, 20, 30, 40])
comm = NPUCommunicator(cpu_group=dist.group.WORLD)
output = comm.all_to_all(input_,
scatter_sizes=scatter_sizes,
gather_sizes=gather_sizes)
assert output.tolist() == [10, 20, 50, 60]
@patch("vllm.config.get_current_vllm_config", return_value=None)
@patch("torch.npu.current_device", return_value=MagicMock())
@patch("torch.npu.set_device", return_value=MagicMock())
@patch("torch.distributed.get_process_group_ranks",
return_value={
0: 0,
1: 1
})
@patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="hccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.npu.device")
def test_all_to_all_without_sizes(self, *_):
def patched_all_to_all(output_tensor_list,
input_tensor_list,
group=None,
async_op=False):
output_tensor_list[:] = ([
torch.tensor([[10, 20]]),
torch.tensor([[50, 60]])
])
torch.distributed.all_to_all = patched_all_to_all
input_ = torch.tensor([[10, 20], [30, 40]])
comm = NPUCommunicator(cpu_group=dist.group.WORLD)
output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)
assert output.tolist() == [[10, 20], [50, 60]]

View File

@@ -0,0 +1,139 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import importlib
import pytest
import torch
from pytest_mock import MockerFixture
from tests.ut.base import PytestBase
from vllm_ascend.distributed.tensor_parallel import (
_gather_along_first_dim, _gather_along_last_dim,
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
all_to_all_hp2sp, all_to_all_sp2hp)
class TestDistributedCommunication(PytestBase):
@pytest.fixture(autouse=True)
def context(self, mocker: MockerFixture):
mocker.patch("torch.npu.current_device", return_value="cpu")
mocker.patch("torch.distributed.get_world_size", return_value=4)
mocker.patch("torch.distributed.get_rank", return_value=0)
@pytest.mark.parametrize("world_size, test_tensor, expected",
[(1, torch.randn(8, 16), (8, 16)),
(4, torch.randn(8, 16), (32, 16))])
def test_gather_along_first_dim(self, test_tensor, expected, world_size,
mocker: MockerFixture):
"""test _gather_along_first_dim"""
mocker.patch("torch.distributed.get_world_size",
return_value=world_size)
result = _gather_along_first_dim(test_tensor, mocker.MagicMock())
assert result.shape == expected
@pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [
(torch.randn(8, 16), [5, 10, 15, 2], (32, 16)),
])
def test_gather_along_first_dim_unequal_split(self, test_tensor, expected,
output_split_sizes,
mocker: MockerFixture):
"""test _gather_along_first_dim"""
result = _gather_along_first_dim(test_tensor, mocker.MagicMock(),
output_split_sizes)
assert result.shape == expected
@pytest.mark.parametrize("world_size, test_tensor, expected",
[(1, torch.randn(8, 16, 32), (8, 16, 32)),
(4, torch.randn(8, 16, 32), (8, 16, 32 * 4))])
def test_gather_along_last_dim(self, test_tensor, expected, world_size,
mocker: MockerFixture):
"""test _gather_along_last_dim"""
mocker.patch("torch.distributed.get_world_size",
return_value=world_size)
result = _gather_along_last_dim(test_tensor, mocker.MagicMock())
assert result.shape == expected
@pytest.mark.parametrize("input_shape,expected_shape", [
((32, 16), (8, 16)),
((40, 10), (10, 10)),
])
def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape,
mocker: MockerFixture):
input_tensor = torch.randn(*input_shape)
result = _reduce_scatter_along_first_dim(input_tensor,
mocker.MagicMock())
assert result.shape == expected_shape
@pytest.mark.parametrize("input_shape,expected_shape", [
((8, 16, 32), (8, 16, 8)),
])
def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape,
mocker: MockerFixture):
input_tensor = torch.randn(*input_shape)
result = _reduce_scatter_along_last_dim(input_tensor,
mocker.MagicMock())
assert result.shape == expected_shape
@pytest.mark.parametrize("func,input_shape,expected_shape", [
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
(8, 16, 128)),
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
(8, 16, 8)),
("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
])
def test_wrapper_functions(self, func, input_shape, expected_shape,
mocker: MockerFixture):
"""test wrapper funcs"""
mod = importlib.import_module(
'vllm_ascend.distributed.tensor_parallel')
globals = mod.__dict__
test_func = globals[func]
input_tensor = torch.randn(*input_shape)
result = test_func(input_tensor, mocker.MagicMock())
assert result.shape == expected_shape
@pytest.mark.parametrize(
"input_shape,output_shape",
[
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP]
])
def test_all_to_all_sp2hp(self, input_shape, output_shape,
mocker: MockerFixture):
input_tensor = torch.randn(*input_shape)
result = all_to_all_sp2hp(input_tensor, mocker.MagicMock())
assert result.shape == output_shape
@pytest.mark.parametrize(
"input_shape,output_shape",
[
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H]
])
def test_all_to_all_hp2sp(self, input_shape, output_shape,
mocker: MockerFixture):
input_tensor = torch.randn(*input_shape)
result = all_to_all_hp2sp(input_tensor, mocker.MagicMock())
assert result.shape == output_shape

View File

@@ -0,0 +1,44 @@
from unittest.mock import MagicMock, patch
import pytest
from vllm.config import ParallelConfig
from vllm_ascend.distributed.parallel_state import (
_LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
get_mc2_group, init_ascend_model_parallel)
@pytest.fixture
def parallel_config():
return ParallelConfig(data_parallel_size=2,
tensor_parallel_size=2,
pipeline_parallel_size=2)
@pytest.fixture
def mock_distributed():
with patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.get_world_size', return_value=8), \
patch('torch.distributed.get_backend', return_value='nccl'), \
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
mock_group.return_value.local_rank = 0
mock_group.return_value.device_group = MagicMock()
yield
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_ascend_config = MagicMock()
mock_ascend_config.lmhead_tensor_parallel_size = 2
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
init_ascend_model_parallel(parallel_config)
mc2_group = get_mc2_group()
assert mc2_group is not None
lmheadtp_group = get_lmhead_tp_group()
assert lmheadtp_group is not None
destroy_ascend_model_parallel()
assert _MC2 is None
assert _LMTP is None

View File

@@ -0,0 +1,28 @@
{
"_name_or_path": "facebook/opt-125m",
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": true,
"dropout": 0.1,
"eos_token_id": 2,
"ffn_dim": 3072,
"hidden_size": 768,
"init_std": 0.02,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.21.0.dev0",
"use_cache": true,
"vocab_size": 50272,
"word_embed_proj_dim": 768
}

View File

@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
import types
from tests.ut.kv_connector.utils import (create_request, create_scheduler,
create_vllm_config)
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import (
LLMDataDistCMgrConnectorMetadata, LLMDataDistCMgrConnectorWorker, LLMRole)
def test_basic_inferface():
"""Unit test for basic LLMDataDistCMgrConnector interface functionality."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_id = request.request_id
scheduler.add_request(request)
# Remote Prefill, triggers LLMDataDistCMgrConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, LLMDataDistCMgrConnectorMetadata)
assert len(kv_connector_metadata.requests) == 1
assert request_id in kv_connector_metadata.requests
req_meta = kv_connector_metadata.requests[request_id]
for block_id, block in zip(
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id]):
assert block_id == block.block_id
def test_read_agent_metadata():
rank_table = {
"version":
"1.2",
"server_count":
"2",
"prefill_device_list": [{
"server_id": "192.168.1.1",
"device_id": "0",
"device_ip": "10.30.0.1",
"cluster_id": "0",
}, {
"server_id": "192.168.1.1",
"device_id": "1",
"device_ip": "10.30.0.2",
"cluster_id": "1",
}, {
"server_id": "192.168.1.2",
"device_id": "0",
"device_ip": "10.30.0.3",
"cluster_id": "2",
}, {
"server_id": "192.168.1.2",
"device_id": "1",
"device_ip": "10.30.0.4",
"cluster_id": "3",
}]
}
def get_device_ip(worker_local_ip, worker_tp_rank, worker_visible_devices):
old_visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
worker = types.SimpleNamespace()
worker.local_ip = worker_local_ip
worker.tp_rank = worker_tp_rank
worker.llm_datadist_role = LLMRole.PROMPT
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = worker_visible_devices
agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
worker, rank_table)
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = old_visible_devices
return agent_metadata.device_ip
assert get_device_ip("192.168.1.1", 0, "0") == "10.30.0.1"
assert get_device_ip("192.168.1.1", 0, "1") == "10.30.0.2"
assert get_device_ip("192.168.1.2", 0, "0") == "10.30.0.3"
assert get_device_ip("192.168.1.2", 0, "1") == "10.30.0.4"
assert get_device_ip("192.168.1.1", 0, "0,1") == "10.30.0.1"
assert get_device_ip("192.168.1.1", 1, "0,1") == "10.30.0.2"
assert get_device_ip("192.168.1.1", 0, "") == "10.30.0.1"
assert get_device_ip("192.168.1.1", 1, "") == "10.30.0.2"

View File

@@ -0,0 +1,998 @@
import os
import queue
import socket
import sys
import threading
import time
import types
import unittest
from collections import defaultdict, deque
from unittest.mock import MagicMock, patch
import msgspec
import zmq
from vllm.utils import make_zmq_path
fake_engine = types.ModuleType("mooncake.engine")
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
sys.modules["mooncake.engine"] = fake_engine
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector,
MooncakeConnectorMetadata, MooncakeConnectorScheduler,
MooncakeConnectorWorker, ReqMeta, ensure_zmq_recv, ensure_zmq_send,
group_concurrent_contiguous, string_to_int64_hash, zmq_ctx)
GET_META_MSG = b"get_meta_msg"
DONE_RECVING_MSG = b"done_recving_msg"
class TestKVCacheTaskTrackerInit(unittest.TestCase):
def test_init_basic_properties(self):
tracker = KVCacheTaskTracker()
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
self.assertIsInstance(tracker.finished_requests, set)
self.assertIsInstance(tracker.delayed_free_requests, deque)
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
def setUp(self):
self.tracker = KVCacheTaskTracker()
self.tracker.finished_requests = set()
self.tracker.done_task_lock = threading.Lock()
def test_empty_requests(self):
result = self.tracker.get_and_clear_finished_requests()
self.assertEqual(result, set())
self.assertEqual(len(self.tracker.finished_requests), 0)
def test_single_request(self):
self.tracker.finished_requests = {"req_123"}
result = self.tracker.get_and_clear_finished_requests()
self.assertEqual(result, {"req_123"})
self.assertEqual(len(self.tracker.finished_requests), 0)
def test_multiple_requests(self):
self.tracker.finished_requests = {"req_1", "req_2", "req_3"}
result = self.tracker.get_and_clear_finished_requests()
self.assertSetEqual(result, {"req_1", "req_2", "req_3"})
self.assertEqual(len(self.tracker.finished_requests), 0)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_concurrent_access(self, mock_logger):
from concurrent.futures import ThreadPoolExecutor
self.tracker.finished_requests = {"req_1", "req_2"}
with ThreadPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(self.tracker.get_and_clear_finished_requests)
for _ in range(3)
]
results = [f.result() for f in futures]
self.assertEqual(sum(1 for r in results if r), 1)
self.assertEqual(len(self.tracker.finished_requests), 0)
class TestKVCacheSendingThreadInit(unittest.TestCase):
def setUp(self):
self.common_args = {
'tp_rank': 1,
'decode_tp_size': 4,
'local_engine_id': 'engine_1',
'side_channel_host': 'localhost',
'side_channel_port': 5555,
'metadata': MagicMock(),
'ready_event': threading.Event()
}
self.threads = []
def tearDown(self):
for thread in self.threads:
if hasattr(thread, 'task_tracker') and hasattr(
thread.task_tracker, 'socket'):
thread.task_tracker.socket.close()
if hasattr(thread, 'is_alive') and thread.is_alive():
thread.join(timeout=0.1)
def test_thread_daemon_property(self):
thread = KVCacheSendingThread(**self.common_args)
self.threads.append(thread)
self.assertTrue(thread.daemon)
def test_thread_name_format(self):
thread = KVCacheSendingThread(**self.common_args)
self.threads.append(thread)
self.assertEqual(thread.name, "KVCacheSendingThread")
def test_ready_event_reference(self):
custom_event = threading.Event()
args = self.common_args.copy()
args['ready_event'] = custom_event
thread = KVCacheSendingThread(**args)
self.threads.append(thread)
self.assertIs(thread.ready_event, custom_event)
class TestGetAndClearFinishedRequests(unittest.TestCase):
def setUp(self):
self.common_args = {
'tp_rank': 1,
'decode_tp_size': 4,
'local_engine_id': 'engine_1',
'side_channel_host': 'localhost',
'side_channel_port': 5555,
'metadata': {
"test": "metadata"
},
'ready_event': threading.Event()
}
self.thread = KVCacheSendingThread(**self.common_args)
@patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests')
def test_get_and_clear_finished_requests(self, mock_get_clear):
expected_requests = {'req1', 'req2'}
mock_get_clear.return_value = expected_requests
result = self.thread.get_and_clear_finished_requests()
mock_get_clear.assert_called_once()
self.assertEqual(result, expected_requests)
class TestKVCacheSendingThread(unittest.TestCase):
def test_run_handles_get_meta_and_done_recv_msgs(self):
ready_event = threading.Event()
metadata = MooncakeAgentMetadata(
engine_id="engine1",
te_rpc_port=9090,
kv_caches_base_addr=[12345678],
num_blocks=2,
)
host = "127.0.0.1"
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
free_port = s.getsockname()[1]
thread = KVCacheSendingThread(
tp_rank=0,
decode_tp_size=1,
local_engine_id="engine1",
side_channel_host=host,
side_channel_port=free_port,
metadata=metadata,
ready_event=ready_event,
)
thread.start()
self.assertTrue(ready_event.wait(timeout=3),
"Server thread startup timeout")
context = zmq.Context() # type: ignore
sock = context.socket(zmq.DEALER) # type: ignore
sock.connect(f"tcp://{host}:{free_port}")
encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata)
sock.send_multipart([b"", encoder.encode((GET_META_MSG, ))])
frames = sock.recv_multipart()
self.assertEqual(frames[0], b"")
meta = decoder.decode(frames[1])
self.assertEqual(meta.engine_id, "engine1")
self.assertEqual(meta.kv_caches_base_addr, [12345678])
self.assertEqual(meta.num_blocks, 2)
req_id = "request_42"
sock.send_multipart(
[b"", encoder.encode((DONE_RECVING_MSG, req_id, 0))])
frames = sock.recv_multipart()
self.assertEqual(frames[0], b"")
self.assertEqual(frames[1], b"ACK")
self.assertIn(req_id, thread.task_tracker.finished_requests)
sock.close()
context.term()
class TestKVCacheRecvingThreadBasic(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
def test_add_request(self):
test_req = {
"request_id": "req1",
"local_block_ids": [1, 2],
"remote_block_ids": [3, 4],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
}
self.thread.add_request(**test_req)
queued = self.thread.request_queue.get_nowait()
self.assertEqual(queued["request_id"], "req1")
self.assertEqual(queued["remote_host"], "localhost")
@patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests')
def test_get_finished_requests(self, mock_tracker):
mock_tracker.return_value = {"req1", "req2"}
result = self.thread.get_and_clear_finished_requests()
self.assertEqual(result, {"req1", "req2"})
class TestSocketManagement(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
self.thread.remote_sockets = defaultdict(deque)
self.thread.remote_poller = MagicMock()
@patch('vllm_ascend.distributed.mooncake_connector.zmq.Context')
@patch('vllm_ascend.distributed.mooncake_connector.make_zmq_socket')
def test_get_remote_socket(self, mock_make_socket, mock_context):
mock_sock = MagicMock()
mock_make_socket.return_value = mock_sock
test_host = "test_host"
test_port = 12345
sock = self.thread._get_remote_socket(test_host, test_port)
self.assertEqual(sock, mock_sock)
mock_make_socket.assert_called_once()
args, kwargs = mock_make_socket.call_args
self.assertEqual(kwargs.get('path'), 'tcp://test_host:12345')
self.assertEqual(kwargs.get('socket_type'), zmq.REQ) # type: ignore
self.assertFalse(kwargs.get('bind', True))
self.thread.remote_poller.register.assert_called_with(
mock_sock, zmq.POLLIN) # type: ignore
def test_return_socket_to_pool(self):
mock_sock = MagicMock()
test_host = "test_host"
test_port = 12345
test_path = make_zmq_path("tcp", test_host, test_port)
self.thread._return_remote_socket(mock_sock, test_host, test_port)
self.assertEqual(len(self.thread.remote_sockets[test_path]), 1)
self.assertEqual(self.thread.remote_sockets[test_path][0], mock_sock)
self.thread.remote_poller.register.assert_not_called()
class TestCoreFunctionality(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.mock_queue = MagicMock()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
self.thread.request_queue = self.mock_queue
self.test_req = {
"request_id": "req1",
"local_block_ids": [1, 2],
"remote_block_ids": [3, 4],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
"remote_transfer_port": 7777
}
self.thread.task_tracker = MagicMock()
self.engine.batch_transfer_sync_read.return_value = 0
self.thread.remote_te_port = {"remote_engine": {6666: 7777}}
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
def test_handle_request(self, mock_send, mock_transfer):
self.thread._handle_request(self.test_req)
mock_transfer.assert_called_once_with(self.test_req)
mock_send.assert_called_once_with("req1", "localhost", 6666)
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
"req1")
self.mock_queue.task_done.assert_called_once()
@patch.object(KVCacheRecvingThread, '_get_remote_metadata')
def test_transfer_kv_cache(self, mock_get_meta):
self.thread.kv_caches_base_addr["remote_engine"] = {
6666: [0x3000, 0x4000]
}
self.thread._transfer_kv_cache(self.test_req)
self.engine.batch_transfer_sync_read.assert_called_once()
call_args, call_kwargs = self.engine.batch_transfer_sync_read.call_args
self.assertEqual(call_args[0], "localhost:7777")
self.assertIsInstance(call_args[1], list)
self.assertIsInstance(call_args[2], list)
self.assertIsInstance(call_args[3], list)
self.assertEqual(len(call_args[1]), len(call_args[2]))
self.assertEqual(len(call_args[1]), len(call_args[3]))
mock_get_meta.assert_not_called()
def test_transfer_kv_cache_failure(self):
self.engine.batch_transfer_sync_read.return_value = -1
self.thread.kv_caches_base_addr["remote_engine"] = {
6666: [0x3000, 0x4000]
}
with self.assertRaises(RuntimeError):
self.thread._transfer_kv_cache(self.test_req)
class TestMetadataHandling(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
self.test_metadata = MooncakeAgentMetadata(
engine_id="remote_engine",
te_rpc_port=9090,
kv_caches_base_addr=[0x3000, 0x4000],
num_blocks=2)
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv')
def test_get_remote_metadata_success(self, mock_recv, mock_send):
mock_recv.return_value = msgspec.msgpack.encode(self.test_metadata)
with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \
patch.object(self.thread, '_return_remote_socket') as mock_return_socket:
mock_socket = MagicMock()
mock_get_socket.return_value = mock_socket
self.thread._get_remote_metadata("host1", 5555)
mock_get_socket.assert_called_once_with("host1", 5555)
mock_return_socket.assert_called_once_with(mock_socket, "host1",
5555)
mock_send.assert_called_once_with(
mock_socket, self.thread.encoder.encode((GET_META_MSG, "")))
mock_recv.assert_called_once_with(mock_socket,
self.thread.remote_poller)
self.assertEqual(
self.thread.kv_caches_base_addr["remote_engine"][5555],
[0x3000, 0x4000])
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv',
side_effect=Exception("Network error"))
def test_get_remote_metadata_failure(self, mock_recv, mock_send):
with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \
patch.object(self.thread, '_return_remote_socket') as mock_return_socket:
mock_socket = MagicMock()
mock_get_socket.return_value = mock_socket
with self.assertRaises(Exception) as context:
self.thread._get_remote_metadata("host1", 5555)
self.assertEqual(str(context.exception), "Network error")
mock_return_socket.assert_called_once()
class TestMainThreadLoop(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
self.thread.request_queue = queue.Queue()
@patch.object(KVCacheRecvingThread, '_handle_request')
def test_run_loop_normal(self, mock_handle):
test_request = {
"request_id": "req1",
"local_block_ids": [1, 2],
"remote_block_ids": [3, 4],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
"remote_transfer_port": 7777
}
self.thread.request_queue.put(test_request)
self.thread.request_queue.put(None)
self.thread.start()
time.sleep(0.1)
self.thread.join(timeout=1.0)
self.assertTrue(self.thread.ready_event.is_set())
mock_handle.assert_called_once_with(test_request)
self.assertTrue(self.thread.request_queue.empty())
class MockVllmConfig:
def __init__(self):
self.model_config = MagicMock()
self.parallel_config = MagicMock()
self.cache_config = MagicMock()
self.kv_transfer_config = MagicMock()
self.model_config.use_mla = True
self.parallel_config.tensor_parallel_size = 2
self.parallel_config.data_parallel_rank_local = 0
self.parallel_config.data_parallel_size_local = 1
self.cache_config.block_size = 16
self.kv_transfer_config.kv_port = 5000
self.kv_transfer_config.kv_role = 'kv_producer'
self.kv_transfer_config.get_from_extra_config = MagicMock()
self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: {
"prefill": {
"tp_size": 2,
"dp_size": 1
},
"decode": {
"tp_size": 2,
"dp_size": 1
}
}.get(k, d)
class MockRequest:
def __init__(self,
request_id,
prompt_token_ids=None,
kv_transfer_params=None,
status=None):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids or [1, 2, 3, 4]
self.kv_transfer_params = kv_transfer_params or {}
self.status = status or "running"
self.output_token_ids = [101, 102]
class TestKVCacheTaskTracker(unittest.TestCase):
def setUp(self):
self.tracker = KVCacheTaskTracker()
def test_update_done_task_count(self):
self.assertEqual(len(self.tracker.finished_requests), 0)
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
current_time = time.time()
self.tracker.add_delayed_request("req_1", current_time)
result = self.tracker.delayed_free_requests
self.assertEqual(len(result), 1)
self.assertEqual(result[0], ("req_1", current_time))
self.tracker.update_done_task_count("req_1")
result_finished = self.tracker.finished_requests
result_delayed = self.tracker.delayed_free_requests
self.assertEqual(result_finished, {"req_1"})
self.assertEqual(len(result_delayed), 0)
def test_retrieve_expired_requests(self):
current_time = time.time()
self.tracker.add_delayed_request("req_1", current_time - 600)
self.tracker.add_delayed_request("req_2", current_time)
result = self.tracker._retrieve_expired_requests()
self.assertEqual(result, {
"req_1",
})
result_delay = self.tracker.delayed_free_requests
self.assertEqual(len(result_delay), 1)
self.assertEqual(result_delay[0], ("req_2", current_time))
def test_duplicate_task_update(self):
self.tracker.update_done_task_count("req1")
self.tracker.update_done_task_count("req1")
self.tracker.update_done_task_count("req1")
finished = self.tracker.get_and_clear_finished_requests()
self.assertEqual(finished, {"req1"})
class TestMooncakeConnectorMetadata(unittest.TestCase):
def test_add_new_req(self):
meta = MooncakeConnectorMetadata()
self.assertEqual(len(meta.requests), 0)
self.assertEqual(len(meta.requests_to_send), 0)
meta.add_new_req(request_id="req1",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_port": 5000
})
self.assertEqual(len(meta.requests), 1)
req_meta = meta.requests["req1"]
self.assertIsInstance(req_meta, ReqMeta)
self.assertEqual(req_meta.local_block_ids, [1, 2, 3])
self.assertEqual(req_meta.remote_block_ids, [4, 5, 6])
self.assertEqual(req_meta.remote_engine_id, "remote_engine")
self.assertEqual(req_meta.remote_host, "localhost")
self.assertEqual(req_meta.remote_port, 5000)
class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
def setUp(self):
config = MockVllmConfig()
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
def test_get_num_new_matched_tokens(self):
request = MockRequest("req1")
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
request, 0)
self.assertEqual(tokens, 0)
self.assertFalse(async_flag)
request.kv_transfer_params = {"do_remote_prefill": True}
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
request, 0)
self.assertEqual(tokens, 3)
self.assertTrue(async_flag)
def test_build_connector_meta(self):
request = MockRequest("req1")
blocks_mock = MagicMock()
blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6]
self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6])
request.kv_transfer_params = {
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
"remote_host": "localhost",
"remote_port": 5000
}
meta = self.scheduler.build_connector_meta(MagicMock())
self.assertIsInstance(meta, MooncakeConnectorMetadata)
self.assertEqual(len(meta.requests), 1)
self.assertEqual(meta.requests["req1"].local_block_ids, [4, 5, 6])
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
def test_get_finished_count(self):
count = self.scheduler.get_finished_count()
self.assertEqual(count, 2)
class TestHelperFunctions(unittest.TestCase):
def test_group_concurrent_contiguous(self):
src: list[int] = [1, 2, 3, 5, 6]
dst: list[int] = [10, 11, 12, 14, 15]
src_groups, dst_groups = group_concurrent_contiguous(src, dst)
self.assertEqual(len(src_groups), 2)
self.assertEqual(src_groups[0], [1, 2, 3])
self.assertEqual(src_groups[1], [5, 6])
self.assertEqual(dst_groups[0], [10, 11, 12])
self.assertEqual(dst_groups[1], [14, 15])
def test_group_concurrent_contiguous_empty(self):
src: list[int] = []
dst: list[int] = []
src_groups, dst_groups = group_concurrent_contiguous(src, dst)
self.assertEqual(src_groups, [])
self.assertEqual(dst_groups, [])
def test_string_to_int64_hash(self):
hash1 = string_to_int64_hash("test_string")
hash2 = string_to_int64_hash("test_string")
self.assertEqual(hash1, hash2)
hash3 = string_to_int64_hash("different_string")
self.assertNotEqual(hash1, hash3)
class TestMooncakeConnectorForScheduler(unittest.TestCase):
def test_scheduler_role(self):
config = MockVllmConfig()
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler)
self.assertIsNone(connector.connector_worker)
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_scheduler_methods(self, mock_method):
config = MockVllmConfig()
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.get_num_new_matched_tokens(request, 0)
mock_method.assert_called_once_with(request, 0)
class MockKVCacheBlocks:
def get_unhashed_block_ids(self):
return [4, 5, 6]
class MockSchedulerOutput:
pass
class MockForwardContext:
pass
class TestMooncakeConnector(unittest.TestCase):
def setUp(self):
self.config = MockVllmConfig()
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
def test_scheduler_initialization(self):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler)
self.assertIsNone(connector.connector_worker)
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_get_num_new_matched_tokens(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.get_num_new_matched_tokens(request, 0)
mock_method.assert_called_once_with(request, 0)
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
def test_update_state_after_alloc(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
blocks = MockKVCacheBlocks()
connector.update_state_after_alloc(request, blocks, 3)
mock_method.assert_called_once_with(request, blocks, 3)
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
def test_build_connector_meta(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
scheduler_output = MockSchedulerOutput()
connector.build_connector_meta(scheduler_output)
mock_method.assert_called_once_with(scheduler_output)
@patch.object(MooncakeConnectorScheduler, "request_finished")
def test_request_finished(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.request_finished(request, [1, 2, 3])
mock_method.assert_called_once_with(request, [1, 2, 3])
class TestMooncakeConnectorScheduler(unittest.TestCase):
def setUp(self):
self.config = MockVllmConfig()
self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")
def test_get_num_new_matched_tokens_no_remote_prefill(self):
request = MockRequest("req1")
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
request, 0)
self.assertEqual(tokens, 0)
self.assertFalse(async_flag)
def test_get_num_new_matched_tokens_with_remote_prefill(self):
request = MockRequest("req1",
kv_transfer_params={"do_remote_prefill": True})
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
request, 0)
self.assertEqual(tokens, 3)
self.assertTrue(async_flag)
def test_update_state_after_alloc_no_remote_prefill(self):
request = MockRequest("req1")
blocks = MagicMock()
self.scheduler.update_state_after_alloc(request, blocks, 0)
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
def test_update_state_after_alloc_with_remote_prefill(self):
request = MockRequest("req1",
kv_transfer_params={
"do_remote_prefill": True,
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
"remote_host": "localhost",
"remote_port": 5000
})
blocks = MockKVCacheBlocks()
self.scheduler.update_state_after_alloc(request, blocks, 3)
self.assertEqual(len(self.scheduler._reqs_need_recv), 1)
self.assertEqual(self.scheduler._reqs_need_recv["req1"][0], request)
self.assertEqual(self.scheduler._reqs_need_recv["req1"][1], [4, 5, 6])
def test_request_finished_no_remote_decode(self):
request = MockRequest("req1")
delay_free, params = self.scheduler.request_finished(
request, [1, 2, 3])
self.assertFalse(delay_free)
self.assertIsNone(params)
class TestUtils(unittest.TestCase):
def test_string_to_int64_hash(self):
h1 = string_to_int64_hash("hello")
h2 = string_to_int64_hash("hello")
h3 = string_to_int64_hash("world")
self.assertEqual(h1, h2)
self.assertNotEqual(h1, h3)
self.assertIsInstance(h1, int)
def test_group_concurrent_contiguous(self):
src: list[int] = [1, 2, 3, 5, 6]
dst: list[int] = [10, 11, 12, 20, 21]
src_g, dst_g = group_concurrent_contiguous(src, dst)
self.assertEqual(src_g, [[1, 2, 3], [5, 6]])
self.assertEqual(dst_g, [[10, 11, 12], [20, 21]])
def test_group_empty(self):
src_g, dst_g = group_concurrent_contiguous([], [])
self.assertEqual(src_g, [])
self.assertEqual(dst_g, [])
def test_zmq_ctx_invalid_type(self):
with self.assertRaises(ValueError):
with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"):
pass
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
def test_zmq_ctx_ok(self, mock_make_socket):
mock_socket = MagicMock()
mock_make_socket.return_value = mock_socket
with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore
self.assertEqual(s, mock_socket)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_send_success(self, mock_logger):
mock_socket = MagicMock()
ensure_zmq_send(mock_socket, b"hello")
mock_socket.send.assert_called_once_with(b"hello")
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_send_retry_and_fail(self, mock_logger):
mock_socket = MagicMock()
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
"send failed")
with self.assertRaises(RuntimeError):
ensure_zmq_send(mock_socket, b"hello", max_retries=2)
self.assertEqual(mock_socket.send.call_count, 2)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_recv_success(self, mock_logger):
mock_socket = MagicMock()
mock_socket.recv.return_value = b"response"
mock_poller = MagicMock()
mock_poller.poll.return_value = [
(mock_socket, zmq.POLLIN) # type: ignore
]
data = ensure_zmq_recv(mock_socket, mock_poller)
self.assertEqual(data, b"response")
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger):
mock_socket = MagicMock()
mock_poller = MagicMock()
mock_poller.poll.return_value = []
with self.assertRaises(RuntimeError):
ensure_zmq_recv(mock_socket,
mock_poller,
timeout=0.01,
max_retries=2)
class MockMooncakeAgentMetadata:
def __init__(self, **kwargs):
pass
class MockMooncakeConnectorMetadata:
def __init__(self):
self.requests = {}
class MockKVCacheSendingThread(threading.Thread):
def __init__(self, *args, **kwargs):
super().__init__()
self.daemon = True
self._finished_requests = set()
def get_and_clear_finished_requests(self):
return self._finished_requests
def start(self):
pass
class MockKVCacheRecvingThread(threading.Thread):
def __init__(self, *args, **kwargs):
super().__init__()
self.daemon = True
self._finished_requests = set()
self.add_request = MagicMock()
def get_and_clear_finished_requests(self):
return self._finished_requests
def start(self):
pass
class MockTensor:
def __init__(self, *args, **kwargs):
self.size = MagicMock(return_value=(10, 16, 8, 16))
self.element_size = MagicMock(return_value=4)
self.shape = (10, 16, 8, 16)
self.data_ptr = MagicMock(return_value=0x1000)
mock_envs_ascend = MagicMock()
mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
mock_logger = MagicMock()
class MockTransferEngine:
def initialize(self, *args, **kwargs):
return 0
def register_memory(self, *args, **kwargs):
return 1
class MockEnvsAscend:
MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
PHYSICAL_DEVICES = "10,11"
def mock_get_tensor_model_parallel_rank():
return 0
def mock_get_tp_group():
return MagicMock()
def mock_get_ip():
return "127.0.0.1"
def mock_string_to_int64_hash(s):
return hash(s)
class TestMooncakeConnectorWorker(unittest.TestCase):
def setUp(self):
self.envs_ascend_mock = MockEnvsAscend()
self.mock_transfer_engine = MagicMock()
self.mock_transfer_engine.get_rpc_port.return_value = 9090
self.mock_transfer_engine.initialize.return_value = 0
self.mock_transfer_engine.register_memory.return_value = 0
self.patches = [
patch('os.getenv', return_value="10,11"),
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
patch('torch.Tensor.element_size', return_value=4),
patch('torch.Tensor.data_ptr', return_value=0x1000),
patch('math.prod', return_value=128),
patch('random.Random'),
patch(
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
mock_get_tensor_model_parallel_rank),
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
mock_get_tp_group),
patch('vllm_ascend.distributed.mooncake_connector.get_ip',
mock_get_ip),
patch(
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
mock_string_to_int64_hash),
patch('vllm_ascend.distributed.mooncake_connector.TransferEngine',
return_value=self.mock_transfer_engine),
patch(
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
MagicMock()),
patch(
'vllm_ascend.distributed.mooncake_connector.KVCacheRecvingThread',
MagicMock()),
patch('vllm_ascend.distributed.mooncake_connector.logger',
MagicMock()),
patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
MagicMock()),
patch.dict('sys.modules',
{'vllm_ascend.envs': self.envs_ascend_mock}),
]
for p in self.patches:
p.start() # type: ignore
self.vllm_config = MockVllmConfig()
self.engine_id = "test_engine"
self.kv_caches = {"layer1": (MagicMock(), MagicMock())}
def tearDown(self):
for p in self.patches:
p.stop() # type: ignore
def test_register_kv_caches_producer(self):
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
worker.register_kv_caches(self.kv_caches)
self.assertEqual(len(worker.kv_caches), 1)
self.assertIsNotNone(worker.kv_send_thread)
self.assertIsNone(worker.kv_recv_thread)
def test_register_kv_caches_consumer(self):
self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer'
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
worker.register_kv_caches(self.kv_caches)
self.assertIsNone(worker.kv_send_thread)
self.assertIsNotNone(worker.kv_recv_thread)
def test_register_kv_caches_mla_case(self):
mla_cache1 = MagicMock()
mla_cache1.size.return_value = (10, 16, 1, 16)
mla_cache2 = MagicMock()
mla_cache2.size.return_value = (10, 16, 1, 8)
mla_caches = {"layer1": (mla_cache1, mla_cache2)}
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
worker.register_kv_caches(mla_caches)
self.assertTrue(worker.use_mla)
self.assertEqual(len(worker.block_len), 2)
def test_device_id_selection_with_physical_devices(self):
# Test with physical devices set
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
# Default tp_rank is 0, so device_id should be 10
self.assertEqual(worker.device_id, 10)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,169 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a Remote Decode request."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
# Ensure the request is finished after 1 tokens.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs[0].outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None
# Request freed in Scheduler and blocks should be freed
assert request_id in scheduler.finished_req_ids
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
# ... but blocks should not be freed.
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (2c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP (3): Finished sending.
# (3a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_id])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm we do not have any memory leaks after req lifecycle.
assert_scheduler_empty(scheduler)
def test_prefix_cache_lifecycle():
"""Test that remote decode params still works with a prefix cache hit."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote_a = create_request(request_id=1, num_tokens=NUM_TOKENS)
scheduler.add_request(request_remote_a)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote_a],
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
#####################
# Actual Test: confirm we send all blocks.
# Step (1): Send the KV Transfer.
NUM_EXTERNAL_FULL_BLOCKS -= 1
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
assert (len(
kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
1))
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_remote.request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,239 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import RequestStatus
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a remote prefill."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
START_FREE_BLOCK_QUEUE_SIZE = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
block_size=BLOCK_SIZE)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1):
# (1a): schedule()
scheduler_output = scheduler.schedule()
# Nothing running and empty scheduler output.
assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0
# Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
assert (request.num_computed_tokens == 0)
# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
assert (block_pool.free_block_queue.num_free_blocks
< START_FREE_BLOCK_QUEUE_SIZE)
assert len(block_pool.cached_block_hash_to_block) == 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block._block_hash is None
# (1b): forward()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert not engine_core_outputs or not engine_core_outputs[0].outputs
# STEP (2):
# (2a): schedule(): nothing happens!
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 0
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving=[request_id])
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# STEP (3):
# (3a): schedule(): this should actually schedule.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
# Confirm the block are actually allocated.
num_hashed_blocks = 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
num_hashed_blocks += (1 if block._block_hash is not None else 0)
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
# Confirm the rest of the prompt is scheduled in this step.
scheduled_req = scheduler_output.scheduled_new_reqs[0]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
num_computed_tokens = scheduled_req.num_computed_tokens
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
# (3b): execute_model()
model_runner_output = create_model_runner_output([request])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
scheduler.schedule()
assert_scheduler_empty(scheduler)
def test_no_spurious_prefix_caching():
"""
With P/D, blocks can be allocated but uncomputed for
multiple engine steps. This test confirms that we do
not accidentally have cache hits against uncomputed
blocks.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 and a half full external blocks.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request(
request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True,
)
# Schedule the remote prefill request. This should not
# cause any blocks to be cached.
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert len(scheduler.waiting) == 1
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_remote.request_id]
# Remote blocks should not be cached.
for block in remote_blocks:
assert block.ref_cnt == 1
assert block._block_hash is None
def test_full_block_prompt():
"""Test that we handle a prompt that is the full block size."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Initialize a recv.
scheduler_output = scheduler.schedule()
# All blocks should be allocated.
num_blocks = len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
scheduler.update_from_output(scheduler_output, model_runner_output)
# # STEP (2): Recv.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving=[request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# # STEP (3): Run as usual.
scheduler_output = scheduler.schedule()
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
num_blocks = len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
NUM_TOKENS - 1)
assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
model_runner_output = create_model_runner_output([request])
scheduler.update_from_output(scheduler_output, model_runner_output)
# # Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
scheduler.schedule()
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.utils import vllm_version_is
EOS_TOKEN_ID = 50256
os.environ["VLLM_USE_V1"] = "1"
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
def create_vllm_config(
max_num_seqs: int = 16,
max_num_batched_tokens: int = 1024,
block_size: int = 128,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
)
fake_weight_path = os.path.join(os.path.dirname(__file__), "..",
"fake_weight")
model_config = ModelConfig(
model=fake_weight_path,
skip_tokenizer_init=True,
)
# Cache config, optionally force APC
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="LLMDataDistCMgrConnector",
kv_role="kv_both",
kv_connector_module_path=
"vllm_ascend.distributed.llmdatadist_c_mgr_connector")
return VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"))
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float16,
False))
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
_none_hash_initialized = False
def create_request(
request_id: int,
num_tokens: int = 10,
max_tokens: int = 128,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
) -> Request:
"""Make dummy request for testing."""
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
kv_transfer_params: Optional[dict[str, Any]] = None
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = dict(do_remote_prefill=False,
do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = dict(do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(
range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
if use_all_1s_for_prompt_tokens:
prompt_token_ids = [1] * num_tokens
else:
prompt_token_ids = [i * request_id for i in range(num_tokens)]
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
pooling_params=[],
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
else:
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=[],
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
req.kv_transfer_params = kv_transfer_params
return req
def create_model_runner_output(
reqs: list[Request],
finished_sending: Optional[list[str]] = None,
finished_recving: Optional[list[str]] = None,
use_eos: bool = False,
) -> ModelRunnerOutput:
"""Make dummy model runner output for testing."""
# Make request data.
req_ids = [req.request_id for req in reqs]
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
# Make sampled tokens.
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]
# Make output data structure.
extra_args = {}
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
finished_recving=finished_recving)
extra_args = {"kv_connector_output": kv_connector_output}
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
**extra_args,
)
else:
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
**extra_args,
)
return model_runner_output

View File

View File

@@ -0,0 +1,195 @@
import pytest
import torch
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.models.deepseek_mtp import (
CustomDeepSeekMTP, CustomDeepSeekMultiTokenPredictor,
CustomDeepSeekMultiTokenPredictorLayer)
class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
@pytest.fixture
def setup_mtp_layer(self, mocker: MockerFixture):
config = PretrainedConfig(vocab_size=1000,
hidden_size=768,
rms_norm_eps=1e-5)
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
return_value=None)
mocker.patch(
"vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__",
return_value=None)
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
return mtp_layer
def test_init(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
assert isinstance(mtp_layer, CustomDeepSeekMultiTokenPredictorLayer)
def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch.object(mtp_layer,
'eh_proj',
return_value=torch.randn(2, 3, 768))
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
torch.randn(2, 3, 768))
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
kv_cache = torch.randn(2, 3, 768)
previous_hidden_states = torch.randn(2, 3, 768)
inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])
output = mtp_layer(input_ids, positions, kv_cache, None,
previous_hidden_states, inputs_embeds, 0)
assert output.shape == (2, 3, 768)
class TestCustomDeepSeekMultiTokenPredictor(PytestBase):
@pytest.fixture
def setup_predictor(self, mocker: MockerFixture):
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_model_config = mocker.MagicMock(spec=ModelConfig)
mock_hf_config = mocker.MagicMock()
mock_hf_config.num_hidden_layers = 12
mock_hf_config.num_nextn_predict_layers = 3
mock_hf_config.vocab_size = 30000
mock_model_config.hf_config = mock_hf_config
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = CacheConfig()
mock_vllm_config.quant_config = mocker.MagicMock()
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
predictor = CustomDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
return predictor
def test_init(self, mocker: MockerFixture, setup_predictor):
predictor = setup_predictor
assert predictor.num_mtp_layers == 3
assert isinstance(predictor, CustomDeepSeekMultiTokenPredictor)
@pytest.mark.parametrize(
'kv_caches, inputs_embeds',
[(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
inputs_embeds):
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
# todo: need or not?
# predictor.num_mtp_layers = 1
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
return_value=torch.tensor([[1.0, 2.0, 3.0]]))
output = predictor.forward(input_ids, positions, kv_caches, None, None,
inputs_embeds, 0)
mock_layer.assert_called_once()
assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))
def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
return_value=None)
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
result_logits = predictor.compute_logits(hidden_states=hidden_states,
sampling_metadata=None)
predictor.logits_processor.assert_called_once()
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
class TestCustomDeepSeekMTP(PytestBase):
@pytest.fixture
def setup_mtp(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.model_config.hf_config.num_hidden_layers = 12
vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
vllm_config.cache_config = mocker.MagicMock()
vllm_config.quant_config = mocker.MagicMock()
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
return_value=None)
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
return mtp
def test_init(self, mocker: MockerFixture, setup_mtp):
mtp = setup_mtp
assert isinstance(mtp, CustomDeepSeekMTP)
def test_forward(self, mocker: MockerFixture, setup_mtp):
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
spec_step_idx = 0
setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))

View File

@@ -0,0 +1,295 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
import torch
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
@pytest.fixture
def base_config():
config = PretrainedConfig(
hidden_size=128,
num_attention_heads=8,
num_hidden_layers=2,
intermediate_size=256,
hidden_act="silu",
rms_norm_eps=1e-6,
rope_theta=10000.0,
max_position_embeddings=2048,
n_routed_experts=4,
n_shared_experts=1,
moe_intermediate_size=256,
num_experts_per_tok=2,
routed_scaling_factor=1.0,
first_k_dense_replace=0,
moe_layer_freq=1,
kv_lora_rank=16,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
topk_method="noaux_tc",
scoring_func="softmax",
norm_topk_prob=True,
n_group=1,
topk_group=1,
vocab_size=10000,
)
return config
@pytest.fixture
def vllm_config(base_config):
model_config = SimpleNamespace(
hf_config=base_config,
tensor_parallel_size=1,
dtype=torch.float32,
use_mla=False,
quant_config=None,
max_model_len=2048,
)
cache_config = CacheConfig()
vllm_config = Mock()
vllm_config.model_config = model_config
vllm_config.cache_config = cache_config
vllm_config.quant_config = None
return vllm_config
@pytest.fixture
def mock_distributed():
tp_group = Mock(spec=GroupCoordinator)
tp_group.rank_in_group = 0
tp_group.world_size = 1
tp_group.device_group = Mock()
dp_group = Mock(spec=GroupCoordinator)
dp_group.rank_in_group = 0
dp_group.world_size = 1
ep_group = Mock(spec=GroupCoordinator)
ep_group.rank_in_group = 0
ep_group.world_size = 1
pp_group = Mock(spec=GroupCoordinator)
pp_group.rank_in_group = 0
pp_group.world_size = 1
mock_vllm_config = Mock()
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \
patch("vllm_ascend.models.deepseek_v2.get_ep_group", return_value=ep_group), \
patch("vllm_ascend.models.deepseek_v2.get_dp_group", return_value=dp_group), \
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
patch("torch.npu.current_device", return_value=0):
yield
@pytest.fixture
def mock_forward_context():
forward_context = Mock(in_profile_run=False, with_prefill=False)
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
return_value=forward_context):
yield
def test_custom_deepseek_v2_silu_and_mul():
torch.set_default_device("cpu")
silu = CustomDeepseekV2SiluAndMul()
assert silu.weight_scale is None
x = torch.randn(2, 4)
output = silu.forward_oot(x)
assert output.shape == (2, 2)
weight_scale = Mock(return_value=torch.tensor(0.1))
silu = CustomDeepseekV2SiluAndMul(weight_scale=weight_scale)
quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
dynamic_scale = torch.randn(2, 1)
with patch("torch_npu.npu_dequant_swiglu_quant",
return_value=torch.randn(2, 4)):
output = silu.forward_oot((quant_x, dynamic_scale))
assert output.shape == (2, 4)
def test_custom_deepseek_v2_merged_replicated_linear(mock_distributed):
linear = CustomDeepseekV2MergedReplicatedLinear(input_size=128,
output_sizes=[64, 64],
bias=False,
quant_config=None)
assert linear.output_sizes == [64, 64]
param = Mock()
param.data = torch.zeros(128, 128)
param.output_dim = 1
param.is_gguf_weight = False
param.is_gguf_weight_type = False
loaded_weight = torch.randn(128, 64)
linear.weight_loader(param, loaded_weight, loaded_shard_id=0)
with pytest.raises(AssertionError):
linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0)
@pytest.mark.parametrize("cls", [
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2RowParallelLinear
])
def test_row_parallel_linear(cls, mock_distributed):
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
linear.quant_method = Mock()
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
input_ = torch.randn(2, 4, 128)
with patch("vllm_ascend.models.deepseek_v2.split_tensor_along_last_dim",
return_value=[torch.randn(2, 4, 64)]):
linear.input_is_parallel = False
output = linear(input_, is_prefill=True)
assert output[0].shape == (2, 4, 64)
linear.input_is_parallel = True
output = linear(input_, is_prefill=False)
assert output[0].shape == (2, 4, 64)
def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
mlp = CustomDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=None)
assert isinstance(mlp.act_fn, CustomDeepseekV2SiluAndMul)
x = torch.randn(2, 4, 128)
output = mlp(x)
assert output.shape == (2, 4, 128)
with patch("vllm_ascend.models.deepseek_v2.QuantizationConfig"
) as mock_quant_config:
mock_quant_config.name = "w8a8dynamic"
with pytest.raises(NotImplementedError):
CustomDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=mock_quant_config,
force_replicate=False)
with pytest.raises(ValueError):
CustomDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="relu",
quant_config=None)
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = CustomDeepseekV2MoE(config=base_config,
quant_config=None,
prefix="mlp")
assert moe.top_k == 2
x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=16,
kv_lora_rank=16,
cache_config=CacheConfig(),
quant_config=None,
prefix="layers.0.self_attn")
assert attn.debug_layer_idx == 0
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(attn.mla_attn,
"__call__",
return_value=torch.randn(2, 4, 128)):
with pytest.raises(AssertionError):
attn(positions, x)
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=None,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
# 创建一个简单的配置对象
class SimpleConfig:
def __init__(self):
self.vocab_size = 10000
self.hidden_size = 128
config = SimpleConfig()
# 直接创建lmhead和logits_processor
lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
logits_processor = LogitsProcessor(config.vocab_size)
# 创建模拟输出
mock_output = torch.randn(2, 4, config.hidden_size)
mock_logits = torch.randn(2, 4, config.vocab_size)
# 直接测试logits_processor
with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
with patch.object(logits_processor,
"_gather_logits",
return_value=mock_logits):
logits = logits_processor(lmhead, mock_output)
assert logits.shape == (2, 4, config.vocab_size)

View File

@@ -0,0 +1,424 @@
import pytest
import torch
import torch.nn.functional as F
from pytest_mock import MockerFixture
from tests.ut.base import PytestBase
from vllm_ascend.models.qwen2_5_vl import (
AscendQwen2_5_VisionAttention, AscendQwen2_5_VisionBlock,
AscendQwen2_5_VisionPatchEmbed, AscendQwen2_5_VisionRotaryEmbedding,
AscendQwen2_5_VisionTransformer, AscendQwen2_5_VLForConditionalGeneration)
class TestAscendQwen2_5_VisionAttention(PytestBase):
def init_attention(
self,
mocker,
embed_dim=1000,
num_heads=10,
projection_size=100,
quant_config=None,
prefix="",
):
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.Qwen2_5_VisionAttention.__init__")
attention = AscendQwen2_5_VisionAttention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
args, kwargs = mocker_attn.call_args
assert args == (embed_dim, num_heads, projection_size, None, "")
assert not kwargs
attention.num_attention_heads_per_partition = num_heads
return attention
def test_attn_init_should_normal(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 10
projection_size = 100
quant_config = None
prefix = ""
vit = self.init_attention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
mocker=mocker,
)
assert vit.embed_dim == 1000
assert vit.hidden_size_per_attention_head == 10
def test_attn_init_should_raise_error(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 7
projection_size = 100
quant_config = None
prefix = ""
with pytest.raises(AssertionError):
# projection_size should divided by num heads
self.init_attention(
mocker=mocker,
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
def test_split_qkv(self, mocker: MockerFixture):
attention = self.init_attention(mocker=mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
q, k, v = attention.split_qkv(torch.rand((100, 10, 300)))
assert q.shape == (100, 10, 10, 10)
assert k.shape == (100, 10, 10, 10)
assert v.shape == (100, 10, 10, 10)
def test_attn_forward(self, mocker: MockerFixture):
attention = self.init_attention(mocker=mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
qkv = lambda x: (x, 0) # noqa
split_qkv = lambda x: [ #noqa
torch.rand((100, 3, 10, 128)) for i in range(3)
] # noqa
npu_rotary_mul = lambda q, cos, sin: q # noqa
_npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa
proj = lambda x: (x, 0) # noqa
mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv)
mocker_split_qkv = mocker.patch.object(
attention,
"split_qkv",
side_effect=split_qkv,
)
mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
side_effect=npu_rotary_mul)
mocker_npu_flash_attention_unpad = mocker.patch(
"torch_npu._npu_flash_attention_unpad",
side_effect=_npu_flash_attention_unpad,
)
mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj)
attention.__dict__["qkv"] = mocker_qkv
attention.__dict__["split_qkv"] = mocker_split_qkv
attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul
attention.__dict__["_npu_flash_attention_unpad"] = (
mocker_npu_flash_attention_unpad)
attention.__dict__["proj"] = mocker_proj
output = attention.forward(
x=x,
cu_seqlens=cu_seqlens,
cos=cos,
sin=sin,
)
qkv_args, qkv_kwargs = mocker_qkv.call_args
assert qkv_args == (x, )
assert not qkv_kwargs
split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args
assert split_qkv_args == (x, )
assert not split_qkv_kwargs
npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args
assert npu_rotary_mul_args[1:] == (cos, sin)
assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128])
assert not npu_rotary_mul_kwargs
assert output.shape == torch.Size([100, 3, 1280])
class TestAscendQwen2_5_VisionBlock(PytestBase):
def init_vision_block(
self,
mocker,
dim=100,
num_heads=10,
mlp_hidden_dim=100,
):
mocker_vit = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionBlock.__init__",
return_value=None,
)
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionAttention.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
vision_block = AscendQwen2_5_VisionBlock(
dim=dim,
num_heads=num_heads,
mlp_hidden_dim=mlp_hidden_dim,
)
args, kwargs = mocker_vit.call_args
assert args == (dim, num_heads, mlp_hidden_dim, F.silu, None, None, "")
assert not kwargs
args1, kwargs1 = mocker_attn.call_args
assert not args1
assert kwargs1 == {
"embed_dim": dim,
"num_heads": num_heads,
"projection_size": dim,
"quant_config": None,
"prefix": ".attn",
}
return vision_block
def test_init_vision_block_should_normal(
self,
mocker: MockerFixture,
):
vision_block = self.init_vision_block(mocker)
assert isinstance(vision_block, AscendQwen2_5_VisionBlock)
def test_vision_block_forward(self, mocker: MockerFixture):
x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
vision_block = self.init_vision_block(mocker)
mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x)
mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
vision_block.__dict__["attn"] = mocker_attn
vision_block.__dict__["mlp"] = mocker_mlp
output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)
_, attn_kwargs = mocker_attn.call_args
assert attn_kwargs == {
"cu_seqlens": cu_seqlens,
"cos": cos,
"sin": sin,
}
assert torch.all(x * 3 == output)
class TestAscendQwen2_5_VisionPatchEmbed(PytestBase):
def test_forward(self):
patch_embed = AscendQwen2_5_VisionPatchEmbed()
ret = patch_embed(torch.rand((120, 1176)))
assert ret.shape == (120, 1152)
class TestAscendQwen2_5_VisionRotaryEmbedding(PytestBase):
def init_rotary_embedding(
self,
mocker,
dim=128,
):
mocker_ebed = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
rotary_embedding = AscendQwen2_5_VisionRotaryEmbedding(dim=dim, )
args, kwargs = mocker_ebed.call_args
assert args == (dim, 10000.0)
assert not kwargs
return rotary_embedding
def test_init_rotary_embedding_should_normal(self, mocker: MockerFixture):
rotary_embedding = self.init_rotary_embedding(mocker)
assert isinstance(rotary_embedding,
AscendQwen2_5_VisionRotaryEmbedding)
class TestAscendQwen2_5_VisionTransformer(PytestBase):
input_data = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
def init_vision_transformer(
self,
mocker,
):
norm_eps = 1e-6
vision_config = mocker.MagicMock()
vision_config.patch_size = 16
vision_config.temporal_patch_size = 2
vision_config.in_channels = 3
vision_config.hidden_act = "gelu"
vision_config.depth = 0
vision_config.num_heads = 10
vision_config.hidden_size = 300
mocker.patch(
"vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_rank",
return_value=0,
)
mocker.patch("vllm.distributed.utils.divide", return_value=100)
mocker.patch(
"vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size",
return_value=2,
)
mocker.patch(
"vllm.model_executor.layers.linear.divide",
return_value=2,
)
mocker.patch(
"vllm.model_executor.layers.linear.get_tensor_model_parallel_rank",
return_value=0)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_world_size",
return_value=2,
)
vision_transformer = AscendQwen2_5_VisionTransformer(
vision_config,
norm_eps,
)
assert not vision_transformer.interleaved
return vision_transformer
def test_init_vision_transformer(self, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
assert isinstance(vision_transformer, AscendQwen2_5_VisionTransformer)
@pytest.mark.parametrize(
"interleaved, expected",
[
(
False,
torch.tensor([
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
]),
),
(
True,
torch.tensor([
input_data[0, 0].cos(),
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[0, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
input_data[1, 1].cos(),
]),
),
],
)
def test_cal_cos_sin(self, interleaved, expected, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
vision_transformer.__dict__["interleaved"] = interleaved
vision_transformer.__dict__["hidden_size_per_attention_head"] = 2
vision_transformer.hidden_size_per_attention_head = 4
cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
assert cos_new.shape == (1, 32, 1, 2)
def test_forward(self, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
x = torch.randn(1, 3, 224, 224)
grid_thw = torch.tensor([[1, 4, 4]])
mocker_patch_embed = mocker.patch.object(
vision_transformer,
"patch_embed",
side_effect=lambda _: torch.randn(16, 512), # noqa
)
mocker_rot_pos_emb = mocker.patch.object(
vision_transformer,
"rot_pos_emb",
side_effect=lambda _: torch.randn(16, 64), # noqa
)
mocker_get_window_index = mocker.patch.object(
vision_transformer,
"get_window_index",
side_effect=lambda _: (torch.arange(8), [4, 8, 12, 16]), # noqa
)
mocker_cal_cos_sin = mocker.patch.object(
vision_transformer,
"cal_cos_sin",
side_effect=lambda _:
(torch.randn(16, 32), torch.randn(16, 32)), # noqa
)
mocker_merger = mocker.patch.object(
vision_transformer,
"merger",
side_effect=lambda _: torch.randn(16, 256), # noqa
)
vision_transformer.__dict__["vision_blocks"] = [
lambda *args, **kwargs: torch.randn(16, 1, 512) # noqa
]
vision_transformer.__dict__["patch_embed"] = mocker_patch_embed
vision_transformer.__dict__["rot_pos_emb"] = mocker_rot_pos_emb
vision_transformer.__dict__[
"get_window_index"] = mocker_get_window_index
vision_transformer.__dict__["cal_cos_sin"] = mocker_cal_cos_sin
vision_transformer.__dict__["merger"] = mocker_merger
vision_transformer.__dict__["fullatt_block_indexes"] = [0, 2]
vision_transformer.__dict__["spatial_merge_unit"] = 2
ret = vision_transformer.forward(x, grid_thw)
assert ret.shape == (8, 256)
mocker_patch_embed.assert_called_with(x)
mocker_rot_pos_emb.assert_called_with(grid_thw)
mocker_get_window_index.assert_called_with(grid_thw)
mocker_cal_cos_sin.assert_called_once()
mocker_merger.assert_called_once()
class TestAscendQwen2_5_VLForConditionalGeneration(PytestBase):
def test_init_vl_for_conditional_generation(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.vision_config = "vision_config"
vllm_config.rms_norm_eps = 1e-5
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker_vl = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.__init__",
return_value=None,
)
mocker_vit = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionTransformer.__init__",
return_value=None,
)
vl_for_conditional_generation = AscendQwen2_5_VLForConditionalGeneration(
vllm_config=vllm_config)
args, kwargs = mocker_vl.call_args
assert not args
assert kwargs == {"vllm_config": vllm_config, "prefix": ""}
mocker_vit.assert_called_once()
assert isinstance(
vl_for_conditional_generation,
AscendQwen2_5_VLForConditionalGeneration,
)

View File

@@ -0,0 +1,422 @@
import pytest
import torch
import torch.nn.functional as F
from pytest_mock import MockerFixture
from vllm.model_executor.models.qwen2_5_vl import \
Qwen2_5_VLForConditionalGeneration
from tests.ut.base import PytestBase
from vllm_ascend.models.qwen2_5_vl_without_padding import (
AscendQwen2_5_VisionAttention_Without_Padding,
AscendQwen2_5_VisionBlock_Without_Padding,
AscendQwen2_5_VisionPatchEmbed_Without_Padding,
AscendQwen2_5_VisionTransformer_Without_Padding,
AscendQwen2_5_VLForConditionalGeneration_Without_Padding)
class TestAscendQwen2_5_VisionAttention_Without_Padding(PytestBase):
def init_attention(
self,
mocker,
embed_dim=1000,
num_heads=10,
projection_size=100,
quant_config=None,
prefix="",
):
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.Qwen2_5_VisionAttention.__init__"
)
attention = AscendQwen2_5_VisionAttention_Without_Padding(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
args, kwargs = mocker_attn.call_args
assert args == (embed_dim, num_heads, projection_size, None, "")
assert not kwargs
attention.num_attention_heads_per_partition = num_heads
return attention
def test_vit_init_should_normal(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 10
projection_size = 100
quant_config = None
prefix = ""
vit = self.init_attention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
mocker=mocker,
)
assert vit.embed_dim == 1000
assert vit.hidden_size_per_attention_head == 10
def test_vit_init_should_raise_error(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 7
projection_size = 100
quant_config = None
prefix = ""
with pytest.raises(AssertionError):
# projection_size should divided by num heads
self.init_attention(
mocker=mocker,
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
def test_vit_forward(self, mocker: MockerFixture):
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
attention = self.init_attention(mocker=mocker)
x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
qkv = lambda x: (x, 0) # noqa
split_qkv = lambda x: [ #noqa
torch.rand((100, 3, 10, 128)) for i in range(3)
] # noqa
npu_rotary_mul = lambda q, cos, sin: q # noqa
_npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa
proj = lambda x: (x, 0) # noqa
mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv)
mocker_split_qkv = mocker.patch.object(
attention,
"split_qkv",
side_effect=split_qkv,
)
mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
side_effect=npu_rotary_mul)
mocker_npu_flash_attention_unpad = mocker.patch(
"torch_npu._npu_flash_attention_unpad",
side_effect=_npu_flash_attention_unpad,
)
mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj)
attention.__dict__["qkv"] = mocker_qkv
attention.__dict__["split_qkv"] = mocker_split_qkv
attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul
attention.__dict__["_npu_flash_attention_unpad"] = (
mocker_npu_flash_attention_unpad)
attention.__dict__["proj"] = mocker_proj
output = attention.forward(
x=x,
cu_seqlens=cu_seqlens,
cos=cos,
sin=sin,
)
qkv_args, qkv_kwargs = mocker_qkv.call_args
assert qkv_args == (x, )
assert not qkv_kwargs
split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args
assert split_qkv_args == (x, )
assert not split_qkv_kwargs
npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args
assert npu_rotary_mul_args[1:] == (cos, sin)
assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128])
assert not npu_rotary_mul_kwargs
assert output.shape == torch.Size([100, 3, 1280])
class TestAscendQwen2_5_VisionBlock_Without_Padding(PytestBase):
def init_vision_block(
self,
mocker,
dim=100,
num_heads=10,
mlp_hidden_dim=100,
):
mocker_vit = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionBlock.__init__",
return_value=None,
)
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionAttention_Without_Padding.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
vision_block = AscendQwen2_5_VisionBlock_Without_Padding(
dim=dim,
num_heads=num_heads,
mlp_hidden_dim=mlp_hidden_dim,
)
args, kwargs = mocker_vit.call_args
assert args == (dim, num_heads, mlp_hidden_dim, F.silu, None, None, "")
assert not kwargs
args1, kwargs1 = mocker_attn.call_args
assert not args1
assert kwargs1 == {
"embed_dim": dim,
"num_heads": num_heads,
"projection_size": dim,
"quant_config": None,
"prefix": ".attn",
}
return vision_block
def test_init_vision_block_should_normal(
self,
mocker: MockerFixture,
):
vision_block = self.init_vision_block(mocker)
assert isinstance(vision_block,
AscendQwen2_5_VisionBlock_Without_Padding)
def test_vision_block_forward(self, mocker: MockerFixture):
x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
vision_block = self.init_vision_block(mocker)
mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x)
mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
vision_block.__dict__["attn"] = mocker_attn
vision_block.__dict__["mlp"] = mocker_mlp
output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)
_, attn_kwargs = mocker_attn.call_args
assert attn_kwargs == {
"cu_seqlens": cu_seqlens,
"cos": cos,
"sin": sin,
}
assert torch.all(x * 3 == output)
class TestAscendQwen2_5_VisionPatchEmbed_Without_Padding(PytestBase):
def test_forward(self):
patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding()
ret = patch_embed(torch.rand((120, 1176)))
assert ret.shape == (120, 1152)
class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
input_data = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
def init_vision_transformer(
self,
mocker,
):
norm_eps = 1e-6
vision_config = mocker.MagicMock()
vision_config.patch_size = 16
vision_config.temporal_patch_size = 2
vision_config.in_channels = 3
vision_config.hidden_act = "gelu"
vision_config.depth = 0
vision_config.hidden_size = 1280
vision_config.num_heads = 16
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker_vit = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer.__init__",
return_value=None,
)
mocker_vision_rotary_embedding = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionRotaryEmbedding.__init__",
return_value=None,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionBlock_Without_Padding.__init__",
return_value=None,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionPatchEmbed_Without_Padding.__init__",
return_value=None,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.parallel_state.get_tensor_model_parallel_world_size",
return_value=1,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.parallel_state.get_tensor_model_parallel_rank",
return_value=0,
)
mocker.patch("vllm.distributed.utils.divide", return_value=100)
vision_transformer = AscendQwen2_5_VisionTransformer_Without_Padding(
vision_config,
norm_eps,
)
args, kwargs = mocker_vit.call_args
assert args == (vision_config, norm_eps, None, "")
assert not kwargs
mocker_vision_rotary_embedding.assert_called_once()
return vision_transformer
def test_init_vision_transformer(self, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
assert isinstance(vision_transformer,
AscendQwen2_5_VisionTransformer_Without_Padding)
@pytest.mark.parametrize(
"interleaved, expected",
[
(
False,
torch.tensor([
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
]),
),
(
True,
torch.tensor([
input_data[0, 0].cos(),
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[0, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
input_data[1, 1].cos(),
]),
),
],
)
def test_cal_cos_sin(self, interleaved, expected, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
vision_transformer.__dict__["interleaved"] = interleaved
vision_transformer.__dict__["hidden_size_per_attention_head"] = 2
vision_transformer.hidden_size_per_attention_head = 4
cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
assert cos_new.shape == (1, 4, 1, 2)
assert torch.allclose(cos_new.view(-1), expected)
def test_forward(self, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
x = torch.randn(1, 3, 224, 224)
grid_thw = torch.tensor([[1, 4, 4]])
mocker_patch_embed = mocker.patch.object(
vision_transformer,
"patch_embed",
side_effect=lambda _: torch.randn(16, 512), # noqa
)
mocker_rot_pos_emb = mocker.patch.object(
vision_transformer,
"rot_pos_emb",
side_effect=lambda _: torch.randn(16, 64), # noqa
)
mocker_get_window_index = mocker.patch.object(
vision_transformer,
"get_window_index",
side_effect=lambda _: (torch.arange(8), [4, 8, 12, 16]), # noqa
)
mocker_cal_cos_sin = mocker.patch.object(
vision_transformer,
"cal_cos_sin",
side_effect=lambda _:
(torch.randn(16, 32), torch.randn(16, 32)), # noqa
)
mocker_merger = mocker.patch.object(
vision_transformer,
"merger",
side_effect=lambda _: torch.randn(16, 256), # noqa
)
vision_transformer.__dict__["vision_blocks"] = [
lambda *args, **kwargs: torch.randn(16, 1, 512) # noqa
]
vision_transformer.__dict__["patch_embed"] = mocker_patch_embed
vision_transformer.__dict__["rot_pos_emb"] = mocker_rot_pos_emb
vision_transformer.__dict__[
"get_window_index"] = mocker_get_window_index
vision_transformer.__dict__["cal_cos_sin"] = mocker_cal_cos_sin
vision_transformer.__dict__["merger"] = mocker_merger
vision_transformer.__dict__["fullatt_block_indexes"] = [0, 2]
vision_transformer.__dict__["spatial_merge_unit"] = 2
ret = vision_transformer.forward(x, grid_thw)
assert ret.shape == (8, 256)
mocker_patch_embed.assert_called_with(x)
mocker_rot_pos_emb.assert_called_with(grid_thw)
mocker_get_window_index.assert_called_with(grid_thw)
mocker_cal_cos_sin.assert_called_once()
mocker_merger.assert_called_once()
class TestAscendQwen2_5_VLForConditionalGeneration_Without_Padding(PytestBase):
def test_init_vl_for_conditional_generation(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.vision_config = "vision_config"
vllm_config.rms_norm_eps = 1e-5
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker_vl = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.__init__",
return_value=None,
)
mocker_vit = mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionTransformer_Without_Padding.__init__",
return_value=None,
)
vl_for_conditional_generation = AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
vllm_config=vllm_config)
args, kwargs = mocker_vl.call_args
assert not args
assert kwargs == {"vllm_config": vllm_config, "prefix": ""}
mocker_vit.assert_called_once()
assert isinstance(
vl_for_conditional_generation,
AscendQwen2_5_VLForConditionalGeneration_Without_Padding,
)
def test_overridden_methods(self):
self.assert_method_overridden(
AscendQwen2_5_VLForConditionalGeneration_Without_Padding,
Qwen2_5_VLForConditionalGeneration,
"_process_image_input",
)
self.assert_method_overridden(
AscendQwen2_5_VLForConditionalGeneration_Without_Padding,
Qwen2_5_VLForConditionalGeneration,
"_process_video_input",
)
@staticmethod
def assert_method_overridden(subclass, parent, method_name: str):
"""assert subclass override parent method"""
parent_func = parent.__dict__.get(method_name)
child_func = subclass.__dict__.get(method_name)
assert child_func is not None, f"{subclass.__name__} should defined {method_name}"
assert child_func is not parent_func, f"{method_name} should override in {subclass.__name__}"

View File

@@ -0,0 +1,200 @@
import pytest
import torch
from pytest_mock import MockerFixture
from vllm.model_executor.layers.activation import QuickGELU
from tests.ut.base import PytestBase
from vllm_ascend.models.qwen2_vl import (AscendQwen2VisionAttention,
AscendQwen2VisionBlock)
class TestAscendQwen2VisionAttention(PytestBase):
def init_attention(
self,
mocker,
embed_dim=1000,
num_heads=10,
projection_size=100,
quant_config=None,
prefix="",
):
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_vl.Qwen2VisionAttention.__init__")
attention = AscendQwen2VisionAttention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
args, kwargs = mocker_attn.call_args
assert args == (embed_dim, num_heads, projection_size, None, "")
assert not kwargs
attention.num_attention_heads_per_partition = num_heads
return attention
def test_attn_init_should_normal(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 10
projection_size = 100
quant_config = None
prefix = ""
vit = self.init_attention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
mocker=mocker,
)
assert vit.hidden_size_per_attention_head == 10
def test_attn_init_should_raise_error(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 7
projection_size = 100
quant_config = None
prefix = ""
with pytest.raises(AssertionError):
# projection_size should divided by num heads
self.init_attention(
mocker=mocker,
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
def test_attn_forward(self, mocker: MockerFixture):
attention = self.init_attention(mocker=mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
qkv = lambda x: (x, 0) # noqa
split_qkv = lambda x: [ #noqa
torch.rand((100, 3, 10, 128)) for i in range(3)
] # noqa
npu_rotary_mul = lambda q, cos, sin: q # noqa
_npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa
proj = lambda x: (x, 0) # noqa
mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv)
mocker_split_qkv = mocker.patch.object(
attention,
"split_qkv",
side_effect=split_qkv,
)
mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
side_effect=npu_rotary_mul)
mocker_npu_flash_attention_unpad = mocker.patch(
"torch_npu._npu_flash_attention_unpad",
side_effect=_npu_flash_attention_unpad,
)
mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj)
attention.__dict__["qkv"] = mocker_qkv
attention.__dict__["split_qkv"] = mocker_split_qkv
attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul
attention.__dict__["_npu_flash_attention_unpad"] = (
mocker_npu_flash_attention_unpad)
attention.__dict__["proj"] = mocker_proj
output = attention.forward(
x=x,
cu_seqlens=cu_seqlens,
cos=cos,
sin=sin,
)
qkv_args, qkv_kwargs = mocker_qkv.call_args
assert qkv_args == (x, )
assert not qkv_kwargs
split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args
assert split_qkv_args == (x, )
assert not split_qkv_kwargs
npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args
assert npu_rotary_mul_args[1:] == (cos, sin)
assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128])
assert not npu_rotary_mul_kwargs
assert output.shape == torch.Size([100, 3, 1280])
class TestAscendQwen2VisionBlock(PytestBase):
def init_vision_block(
self,
mocker,
dim=100,
num_heads=10,
mlp_ratio=0.5,
):
mocker_vit = mocker.patch(
"vllm.model_executor.models.qwen2_vl.Qwen2VisionBlock.__init__",
return_value=None,
)
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_vl.AscendQwen2VisionAttention.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
vision_block = AscendQwen2VisionBlock(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
)
args, kwargs = mocker_vit.call_args
assert args == (dim, num_heads, mlp_ratio, QuickGELU, None, None, "")
assert not kwargs
args1, kwargs1 = mocker_attn.call_args
assert not args1
assert kwargs1 == {
"embed_dim": dim,
"num_heads": num_heads,
"projection_size": dim,
"quant_config": None,
"prefix": ".attn",
}
return vision_block
def test_init_vision_block_should_normal(
self,
mocker: MockerFixture,
):
vision_block = self.init_vision_block(mocker)
assert isinstance(vision_block, AscendQwen2VisionBlock)
def test_vision_block_forward(self, mocker: MockerFixture):
x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
vision_block = self.init_vision_block(mocker)
mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x)
mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
vision_block.__dict__["attn"] = mocker_attn
vision_block.__dict__["mlp"] = mocker_mlp
output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)
_, attn_kwargs = mocker_attn.call_args
assert attn_kwargs == {
"cu_seqlens": cu_seqlens,
"cos": cos,
"sin": sin,
}
assert torch.all(x * 3 == output)

View File

@@ -0,0 +1,98 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
import unittest
import pytest
import torch
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
class TestCustomQwen3MoeForCausalLM:
def test_class_inheritance(self):
assert issubclass(CustomQwen3MoeForCausalLM, Qwen3MoeForCausalLM)
@pytest.mark.parametrize("key, expected", [
("qkv_proj", ["q_proj", "k_proj", "v_proj"]),
("gate_up_proj", ["gate_proj", "up_proj"]),
("experts",
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]),
])
def test_packed_modules_mapping(self, key, expected):
assert CustomQwen3MoeForCausalLM.packed_modules_mapping[
key] == expected
def test_packed_modules_mapping_structure(self):
expected_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
"experts": [
"experts.0.gate_proj", "experts.0.up_proj",
"experts.0.down_proj"
]
}
assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
class DummyRMSNorm:
def __init__(self, dim: int, eps: float = 1e-6):
self.dim = dim
self.eps = eps
def __call__(self, x):
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
denom = (mean_sq + self.eps).sqrt()
return x / denom
class TestCustomQwen3MoeAttention(unittest.TestCase):
def setUp(self):
self.batch = 2
self.seq_len = 3
self.q_size = 8
self.kv_size = 8
self.head_dim = 4
self.rms_eps = 1e-6
total_dim = self.q_size + 2 * self.kv_size
self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
dtype=torch.float32).reshape(
self.batch, self.seq_len, total_dim)
def test_constant_input_normalization(self):
ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
dtype=torch.float32)
q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
expected_q = torch.full((1, 1, self.q_size), norm_val)
expected_k = torch.full((1, 1, self.kv_size), norm_val)
expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
self.assertTrue(torch.equal(v, expected_v))

View File

@@ -0,0 +1,32 @@
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
MSEventKey)
class Testbase(TestBase):
def test_ms_event_key(self):
self.assertEqual(MSEventKey.ATTN_COM_FINISH.value, 0)
self.assertEqual(MSEventKey.ATTN_AR_FINISH.value, 1)
self.assertEqual(MSEventKey.FFN_COM_FINISH.value, 2)
self.assertEqual(MSEventKey.FFN_AR_FINISH.value, 3)
self.assertEqual(MSEventKey.MOE_BEFORE_COMM.value, 4)
self.assertEqual(MSEventKey.MOE_AFTER_COMM.value, 5)
self.assertEqual(MSEventKey.MOE_SE_COMM_FINISH.value, 6)
self.assertEqual(MSEventKey.MOE_SE_COMP_FINISH.value, 7)
self.assertEqual(MSEventKey.MOE_GATE_FINISH.value, 8)
def test_ms_attention_metadata_split_config_default(self):
config = MSAttentionMetadataSplitConfig()
self.assertEqual(config.num_micro_batches, 2)
self.assertEqual(config.min_total_tokens_to_split, 256)
self.assertEqual(config.min_prefill_tokens_to_split, 64)
def test_ms_attention_metadata_split_config_custom(self):
config = MSAttentionMetadataSplitConfig(
num_micro_batches=4,
min_total_tokens_to_split=512,
min_prefill_tokens_to_split=128)
self.assertEqual(config.num_micro_batches, 4)
self.assertEqual(config.min_total_tokens_to_split, 512)
self.assertEqual(config.min_prefill_tokens_to_split, 128)

View File

@@ -0,0 +1,47 @@
import pytest
from pytest_mock import MockFixture
from tests.ut.base import PytestBase
from vllm_ascend.multistream.decorator import set_multistream_support
class Context:
def __init__(self, attn_metadata=None):
self.attn_metadata = attn_metadata
class TestDecorator(PytestBase):
@pytest.mark.parametrize(
'layer_context, microbatch_context, expected_metadata', [
((-1, None, None), -1, {
"original": True
}),
((-1, None, None), 0, {
"original": True
}),
((0, None, None), -1, {
"original": True
}),
((0, None, [{
"new": True
}]), 0, {
"new": True
}),
])
def test_decorator(self, mocker: MockFixture, layer_context,
microbatch_context, expected_metadata):
def context_func():
return Context(attn_metadata={"original": True})
mocker.patch(
'vllm_ascend.multistream.decorator.get_multistream_layer_context',
return_value=layer_context)
mocker.patch(
'vllm_ascend.multistream.decorator.get_multistream_microbatch_context',
return_value=microbatch_context)
context = set_multistream_support()(context_func)()
assert context.attn_metadata == expected_metadata

View File

@@ -0,0 +1,198 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.ut.base import PytestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
MultiStreamPreTransformerLayer)
from vllm_ascend.multistream.metadata import MultiStreamMetadata
# === fixture: mock tensor input ===
@pytest.fixture
def input_tensors():
return [torch.randn(2, 128), torch.randn(2, 128)]
# === mock get_forward_context ===
class DummyContext:
def __init__(self, attn_metadata):
self.attn_metadata = attn_metadata
class TestMultiStreamPreTransformerLayer(PytestBase):
# === test when multistream_metadata is None ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_multistream_metadata(self, mock_set_ctx, mock_get_ctx,
input_tensors):
mock_get_ctx.return_value = DummyContext(attn_metadata="dummy_meta")
layer = MultiStreamPreTransformerLayer(multistream_metadata=None)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == "dummy_meta"
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when attn_metadata is None ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_attn_metadata(self, mock_set_ctx, mock_get_ctx,
input_tensors):
mock_get_ctx.return_value = DummyContext(attn_metadata=None)
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out is None
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when do_ms=False (no split needed) ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
dummy_attn = "original_attn"
mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.split_micro_batch.return_value = (False, "same_attn",
input_tensors, None)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == "same_attn"
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when do_ms=True (split occurred) ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
dummy_attn = "original_attn"
mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
split_inputs = [[t[:1], t[1:]] for t in input_tensors]
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.start_layer = 2
dummy_metadata.split_micro_batch.return_value = (True,
["attn1", "attn2"],
split_inputs, None)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == ["attn1", "attn2"]
assert input_out == split_inputs
mock_set_ctx.assert_called_once_with(2, dummy_metadata,
["attn1", "attn2"])
class TestMultiStreamPostTransformerLayer(PytestBase):
def test_post_forward_metadata_none(self, input_tensors):
layer = MultiStreamPostTransformerLayer(multistream_metadata=None)
output = layer.forward(input_tensors)
assert output == input_tensors
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.ms_config = None
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors)
assert output == input_tensors
@patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
@patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
def test_post_forward_normal_flow(self, mock_reset_ctx, mock_get_ctx,
input_tensors):
A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
calculate_stream=MagicMock(),
communicate_stream=MagicMock(),
start_layer=0,
end_layer=1,
event_keys=[],
multistream_config=None,
)
dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
dummy_metadata.ms_config.num_micro_batches = 4
dummy_metadata.end_layer = 10
mock_get_ctx.return_value = (
5, # layer_index
dummy_metadata, # ms_metadata
"dummy_attn_metadata" # ms_attn_metadata
)
dummy_metadata.merge_micro_batches.return_value = "merged_result"
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors)
# check wait_event
dummy_metadata.try_wait_event.assert_called_once_with(
9, # end_layer - 1
3, # num_micro_batches - 1
MSEventKey.FFN_AR_FINISH)
mock_reset_ctx.assert_called_once()
assert output == "merged_result"
@patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
@patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
def test_post_forward_with_custom_wait_layer(self, mock_reset_ctx,
mock_get_ctx, input_tensors):
A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
calculate_stream=MagicMock(),
communicate_stream=MagicMock(),
start_layer=0,
end_layer=1,
event_keys=[],
multistream_config=None,
)
dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
dummy_metadata.ms_config.num_micro_batches = 4
dummy_metadata.end_layer = 10
mock_get_ctx.return_value = (
3, # layer_index
dummy_metadata,
"dummy_attn_metadata")
dummy_metadata.merge_micro_batches.return_value = "merged_result"
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors, wait_layer_index=7)
dummy_metadata.try_wait_event.assert_called_once_with(
7, 3, MSEventKey.FFN_AR_FINISH)
mock_reset_ctx.assert_called_once()
assert output == "merged_result"

View File

@@ -0,0 +1,246 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
MultiStreamMetadata,
MultiStreamStepMetadata,
split_micro_batches_tensors)
class TestMetaData(TestBase):
def setUp(self):
self.test_tensors_list = [torch.randn(100, 1024) for i in range(3)]
self.test_tensors = torch.randn(100, 1024)
self.test_tensors_dict = {
'query': torch.randn(100, 1024),
'key': torch.randn(100, 1024),
'value': torch.randn(100, 1024)
}
self.split_index = 50
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
self.metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
def test_split_micro_batches_tensors(self):
test_tensors_list_res = split_micro_batches_tensors(
self.test_tensors_list, self.split_index)
test_tensors_res = split_micro_batches_tensors(self.test_tensors,
self.split_index)
keys = ['query', 'key', 'value']
test_tensors_dict_res = split_micro_batches_tensors(
self.test_tensors_dict, self.split_index, keys)
for i in range(3):
self.assertEqual(len(test_tensors_list_res[i][0]),
self.split_index)
self.assertEqual(
len(test_tensors_list_res[i][0]) +
len(test_tensors_list_res[i][1]), 100)
self.assertEqual(len(test_tensors_res[0]), self.split_index)
self.assertEqual(
len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)
for key in keys:
self.assertEqual(len(test_tensors_dict_res[0][key]),
self.split_index)
self.assertEqual(
len(test_tensors_dict_res[0][key]) +
len(test_tensors_dict_res[1][key]), 100)
def test_default_init_multistream_step_metadata(self):
metadata = MultiStreamStepMetadata()
self.assertIsNone(metadata.comm_stream)
self.assertIsNone(metadata.before_comm_event)
self.assertIsNone(metadata.after_comm_event)
def test_custom_init_multistream_step_metadata(self):
mockStream = MagicMock(spec=torch.npu.Stream)
mockEvent1 = MagicMock(spec=torch.npu.Event)
mockEvent2 = MagicMock(spec=torch.npu.Event)
metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2)
self.assertEqual(metadata.comm_stream, mockStream)
self.assertEqual(metadata.before_comm_event, mockEvent1)
self.assertEqual(metadata.after_comm_event, mockEvent2)
def test_default_init_multistream_config(self):
config = MultiStreamConfig()
self.assertEqual(config.min_total_tokens_to_split, 256)
self.assertEqual(config.min_prefill_tokens_to_split, 64)
self.assertEqual(config.num_micro_batches, 2)
self.assertEqual(config.imbalance_ratio, 0.1)
def test_custom_init_multistream_config(self):
config = MultiStreamConfig(512, 128, 1, 0.2)
self.assertEqual(config.min_total_tokens_to_split, 512)
self.assertEqual(config.min_prefill_tokens_to_split, 128)
self.assertEqual(config.num_micro_batches, 1)
self.assertEqual(config.imbalance_ratio, 0.2)
def test_init_multistream_metadata(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock()]
multistream_config = MagicMock(spec=MultiStreamConfig)
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
self.assertEqual(metadata.calculate_stream, mock_stream)
self.assertEqual(metadata.communicate_stream, mock_stream)
self.assertEqual(metadata.start_layer, 1)
self.assertEqual(metadata.end_layer, 3)
self.assertEqual(metadata.ms_config, multistream_config)
self.assertTrue(metadata.causal_lm)
def test_build_events(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
with patch('torch.npu.Event', return_value=mock_event):
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MultiStreamConfig(
num_micro_batches=2,
min_total_tokens_to_split=256,
min_prefill_tokens_to_split=64)
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
expected_events = {
0: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
},
1: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
},
2: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
}
}
self.assertEqual(metadata.ms_events, expected_events)
def test_build_ms_split_config(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
multistream_config.num_micro_batches = 2
multistream_config.min_total_tokens_to_split = 256
multistream_config.min_prefill_tokens_to_split = 64
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
self.assertIsNotNone(metadata.ms_split_config)
self.assertEqual(metadata.ms_split_config.num_micro_batches,
multistream_config.num_micro_batches)
self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
multistream_config.min_total_tokens_to_split)
self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
multistream_config.min_prefill_tokens_to_split)
def test_try_wait_event(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
with patch('torch.npu.Event', return_value=mock_event):
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
metadata.try_wait_event(layer_index=1,
micro_batch_index=0,
event_key=event_keys[0])
mock_event.wait.assert_called_once()
def test_try_record_event(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
with patch('torch.npu.Event', return_value=mock_event):
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
metadata.try_record_event(layer_index=1,
micro_batch_index=0,
event_key=event_keys[0])
mock_event.record.assert_called_once()
def test_merge_batches_none_input(self):
input_tensors = None
result = self.metadata.merge_micro_batches(input_tensors)
self.assertIsNone(result)
def test_merge_batches_single_tensor_input(self):
input_tensors = [torch.tensor([1, 2, 3])]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 1)
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))
def test_merge_batches_list_of_tensors_input(self):
input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 2)
self.assertEqual(result, input_tensors)
def test_merge_batches_nested_list_input(self):
input_tensors = [[torch.tensor([1, 2]),
torch.tensor([3, 4])],
[torch.tensor([5, 6]),
torch.tensor([7, 8])]]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))

View File

@@ -0,0 +1,147 @@
from unittest.mock import MagicMock
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import (compute_split_seq_index,
model_input_split_v1_mla_attn,
split_attn_int_type,
split_attn_tensor_type)
class TestMsSplit(TestBase):
def test_decode_only(self):
result = compute_split_seq_index(
query_lens=None,
attn_state=AscendAttentionState.DecodeOnly,
num_tokens=10)
self.assertEqual(result, [5, 5])
def test_perfect_balance(self):
query_lens = [2, 3, 5]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [5, 2])
def test_imbalance(self):
query_lens = [1, 2, 3, 4]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_query_lens_none(self):
with self.assertRaises(AssertionError):
compute_split_seq_index(
query_lens=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
def test_empty_query_lens(self):
query_lens: list[int] = []
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_single_query_len(self):
query_lens = [10]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_split_attn_tensor_type_middle(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 3
expected_result = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_start(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 0
expected_result = [torch.tensor([]), torch.tensor([1, 2, 3, 4, 5])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_end(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 5
expected_result = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_empty_tensor(self):
input_tensor = torch.tensor([])
index = 0
expected_result = [torch.tensor([]), torch.tensor([])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_int_type_index_greater_than_var(self):
var = 5
index = 10
expected_result = [5, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_equal_to_var(self):
var = 5
index = 5
expected_result = [5, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_less_than_var(self):
var = 10
index = 5
expected_result = [5, 5]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_zero(self):
var = 10
index = 0
expected_result = [0, 10]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_var_zero(self):
var = 0
index = 5
expected_result = [0, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_both_zero(self):
var = 0
index = 0
expected_result = [0, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_v1_mla_attn_input_none(self):
attn_metadata = None
ascendMLAPrefillMetadata = MagicMock()
ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=1)
result = model_input_split_v1_mla_attn(attn_metadata,
ascendMLAPrefillMetadata,
ms_split_config)
self.assertEqual(result, [None])

View File

@@ -0,0 +1,17 @@
{
"moe_layer_count":
1,
"layer_list": [{
"layer_id":
0,
"device_count":
2,
"device_list": [{
"device_id": 0,
"device_expert": [7, 2, 0, 3, 5]
}, {
"device_id": 1,
"device_expert": [6, 1, 4, 7, 2]
}]
}]
}

View File

@@ -0,0 +1,61 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
@pytest.fixture
def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)
@patch("torch_npu.npu_fast_gelu", side_effect=lambda x: x + 1)
def test_QuickGELU_forward(mock_gelu, dummy_tensor):
layer = QuickGELU()
out = layer.forward(dummy_tensor)
expected_out = dummy_tensor + 1
assert torch.allclose(out, expected_out)
mock_gelu.assert_called_once()
@pytest.mark.parametrize("is_310p_return", [True, False])
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = SiluAndMul()
out = layer.forward(dummy_tensor)
if is_310p_return:
expected_arg = dummy_tensor.to(torch.float32)
else:
expected_arg = dummy_tensor
# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()
actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(
actual_arg,
expected_arg), "npu_swiglu called with unexpected input"
expected_out = dummy_tensor + 1
assert torch.allclose(out, expected_out)

View File

@@ -0,0 +1,69 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
class TestFusedExpertsMoGE(TestBase):
def test_fused_experts_moge(self):
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
patch('torch_npu.npu_swiglu') as mock_swiglu, \
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
mock_is_310p.return_value = False
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
torch.randn(x[0].shape[0], weight[0].shape[1])
]
mock_swiglu.side_effect = lambda x: x
hidden_states = torch.randn(4, 128)
w1 = torch.randn(4, 256, 128)
w2 = torch.randn(4, 128, 128)
topk_weights = torch.rand(4, 1)
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
top_k = 1
global_num_experts = 4
moe_parallel_config = type(
'MockConfig', (), {
'ep_size': 1,
'tp_size': 1,
'dp_size': 1,
'tp_rank': 0,
'dp_rank': 0,
'ep_rank': 0,
'use_ep': True
})()
output = fused_experts_moge(
hidden_states=hidden_states,
w1=w1,
w2=w2,
moe_parallel_config=moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
apply_router_weight_on_input=True,
)
self.assertEqual(output.shape, (4, 128))

View File

@@ -0,0 +1,141 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import json
import os
from typing import List, TypedDict
from unittest import mock
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
class Device(TypedDict):
device_id: int
device_expert: List[int]
class Layer(TypedDict):
layer_id: int
device_count: int
device_list: List[Device]
class MockData(TypedDict):
moe_layer_count: int
layer_list: List[Layer]
class TestExpertLoadBalancer(TestBase):
def setUp(self):
_TEST_DIR = os.path.dirname(__file__)
json_file = _TEST_DIR + "/expert_map.json"
with open(json_file, 'r') as f:
self.expert_map: MockData = json.load(f)
self.expert_load_balancer = ExpertLoadBalancer(json_file,
global_expert_num=8)
def test_init(self):
self.assertIsInstance(self.expert_load_balancer.expert_map_tensor,
torch.Tensor)
self.assertEqual(self.expert_load_balancer.layers_num,
self.expert_map["moe_layer_count"])
self.assertEqual(self.expert_load_balancer.ranks_num,
self.expert_map["layer_list"][0]["device_count"])
def test_generate_index_dicts(self):
tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]])
result = self.expert_load_balancer.generate_index_dicts(tensor_2d)
expected_result = [{
7: 0,
2: 1,
0: 2,
3: 3,
5: 4
}, {
6: 5,
1: 6,
4: 7,
7: 8,
2: 9
}]
self.assertEqual(result, expected_result)
def test_generate_expert_placement_map(self):
expert_placement_map = self.expert_load_balancer.generate_expert_placement_map(
)
self.assertEqual(expert_placement_map.shape,
(self.expert_load_balancer.layers_num,
self.expert_load_balancer.ranks_num, 8))
self.assertTrue(torch.all(expert_placement_map >= -1))
def test_generate_log2phy_expert_map(self):
layer_id = 0
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
layer_id)
self.assertEqual(log2phy_map.shape,
(self.expert_load_balancer.ranks_num, 8))
self.assertTrue(torch.all(log2phy_map >= -1))
@mock.patch("torch_npu.npu._lazy_init")
@mock.patch("torch.npu.current_device", return_value="cpu")
def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
layer_id = 0
rank_id = 0
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
layer_id, rank_id)
self.assertEqual(rank_local_expert_num, 5)
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
dtype=torch.int32).to(
rank_expert_map.device)
self.assertTrue(rank_expert_map.equal(expected_tensor))
rank_id = 1
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
layer_id, rank_id)
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
dtype=torch.int32).to(
rank_expert_map.device)
self.assertTrue(rank_expert_map.equal(expected_tensor))
def test_get_rank_log2phy_map(self):
layer_id = 0
rank_id = 0
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
layer_id, rank_id)
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
dtype=torch.int32).to(
log2phy_map.device)
self.assertTrue(log2phy_map.equal(expected_tensor))
rank_id = 1
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
layer_id, rank_id)
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
dtype=torch.int32).to(
log2phy_map.device)
self.assertTrue(log2phy_map.equal(expected_tensor))
def test_get_global_redundant_expert_num(self):
redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num(
)
expected_redundant_expert_num = len(self.expert_map["layer_list"][0]["device_list"][0]["device_expert"]) * \
self.expert_map["layer_list"][0]["device_count"] - 8
self.assertEqual(redundant_expert_num, expected_redundant_expert_num)

View File

@@ -0,0 +1,741 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch
adapt_patch(True)
def mock_ep_and_mc2_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.rank = 0
mock_group.world_size = 4
mock_group.device_group = "mock_group_ep"
mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
return mock_group
def mock_dp_and_tp_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 2
mock_group.device_group = "mock_group"
mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
return mock_group
def mock_npu_format_cast(weight_data, format):
return weight_data
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_setup_token_dispatchers = MagicMock()
mock_token_dispatcher_with_allgather = MagicMock()
mock_token_dispatcher_with_all2allv = MagicMock()
mock_token_dispatcher_with_mc2 = MagicMock()
mock_dispatch_result_allgather = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([8, 16], dtype=torch.int64),
"group_list_type": 0,
}
mock_combine_result_allgather = torch.randn(16, 2)
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
mock_dispatch_result_all2allv = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
}
mock_combine_result_all2allv = torch.randn(16, 2)
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
mock_dispatch_result_mc2 = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
"assist_info_for_combine": torch.randn(16, 2),
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
}
mock_combine_result_mc2 = torch.randn(16, 2)
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
captured_dispatchers = {}
def capture_register(dispatcher_instance):
key = dispatcher_instance.__class__.__name__
captured_dispatchers[key] = dispatcher_instance
if key == 'TokenDispatcherWithAllGather':
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
elif key == 'TokenDispatcherWithAll2AllV':
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
elif key == 'TokenDispatcherWithMC2':
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
mock_forward_context_obj = MagicMock(
fused_moe_state=FusedMoEState.AllGather,
token_dispatcher=default_mock_token_dispatcher,
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather'), \
patch('torch.distributed.all_to_all_single'), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
expert_map_path=None
)), \
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_token_dispatcher_with_allgather':
mock_token_dispatcher_with_allgather,
'mock_token_dispatcher_with_all2allv':
mock_token_dispatcher_with_all2allv,
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
}
mock_register_token_dispatcher_patcher.stop()
mock_get_token_dispatcher_patcher.stop()
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
None
)), \
patch('torch_npu.npu_moe_init_routing', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
torch.randn(8, 2)
)), \
patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_distribute_combine", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_grouped_matmul", return_value=(
[torch.randn(16, 2)]
)), \
patch("torch_npu.npu_swiglu", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_finalize_routing", return_value=(
torch.randn(16, 2)
)):
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
torch.randn(16, 2))), \
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
torch.randn(16, 2))):
yield
else:
yield
@pytest.fixture
def default_moe_config():
return {
'num_experts': 8,
'top_k': 2,
'hidden_size': 512,
'intermediate_size': 1024
}
@pytest.fixture
def moe_method(mock_dist_env):
moe = MagicMock()
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
return AscendUnquantizedFusedMoEMethod(moe)
class Device(TypedDict):
device_id: int
device_expert: List[int]
class Layer(TypedDict):
layer_id: int
device_count: int
device_list: List[Device]
class MockData(TypedDict):
moe_layer_count: int
layer_list: List[Layer]
class MockQuantMethod(nn.Module):
def __init__(self, shared_experts, num_tokens):
super().__init__()
if shared_experts:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
torch.randn(num_tokens, 10)))
else:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
class MockFusedMoEMethod(FusedMoEMethodBase):
moe = MagicMock()
def __init__(self):
super().__init__(self.moe)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
pass
def apply(self, hidden_states: torch.Tensor,
expert_weights: torch.Tensor) -> torch.Tensor:
pass
class TestAscendFusedMoe:
def test_init_no_quant(self, mock_dist_env, default_moe_config):
layer = AscendFusedMoE(**default_moe_config)
layer.w13_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['intermediate_size'] * 2,
default_moe_config['hidden_size']))
layer.w2_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['hidden_size'],
default_moe_config['intermediate_size']))
assert layer.num_experts == default_moe_config['num_experts']
assert layer.top_k == default_moe_config['top_k']
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = AscendFusedMoE(**error_config)
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
layer = AscendFusedMoE(**error_config)
def test_init_with_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
moe = AscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert moe.quant_method == mock_quant_method
@pytest.mark.parametrize(
"others_param",
[[None,
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
[2, None, False, 5, None], [None, None, True, 5, None],
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
moe = AscendFusedMoE(**default_moe_config)
if ep_size == 1:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once()
if shared_experts:
assert output[0].shape == (num_tokens, 32)
assert output[1].shape == (num_tokens, 10)
else:
assert output.shape == (num_tokens, 32)
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
default_moe_config):
inputs = torch.randn(5, 32)
router_logits = torch.randn(5, 8)
moe = AscendFusedMoE(**default_moe_config)
moe.quant_method = MockQuantMethod(None, 5)
output = moe._forward_ms_fused_moe_comp(inputs,
router_logits,
is_prefill=False,
real_top_k=1)
moe.quant_method.apply.assert_called_once()
assert output.shape == (5, 32)
class TestAscendUnquantizedFusedMoEMethod:
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
layer = MagicMock()
layer.w13_weight.data = torch.randn(16, 32)
layer.w2_weight.data = torch.randn(16, 32)
with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
patch('vllm_ascend.utils.is_310p', return_value=False):
moe_method.process_weights_after_loading(layer)
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=global_num_experts,
is_prefill=is_prefill)
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
ep_size = others_param
is_prefill = False
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, True),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
if ep_size == 1:
x = x.view(-1, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=128,
expert_map=expert_map,
is_prefill=is_prefill)
expected_shape = (16, 2)
assert result.shape == expected_shape
class TestExpertsSelector:
@pytest.mark.parametrize("global_num_experts", [[256], [128]])
def test_select_experts(self, mock_dist_env, mock_moe_env,
global_num_experts):
x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids, _ = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,
use_grouped_topk=False,
renormalize=True,
topk_group=None,
num_expert_group=None,
custom_routing_function=None,
scoring_func="softmax",
e_score_correction_bias=None,
global_num_experts=global_num_experts)
assert topk_weights.shape == (8, 2)
assert topk_ids.shape == (8, 2)
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant')
def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_npu_dynamic_quant,
mock_npu_grouped_matmul,
mock_is_310p,
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 20),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
mock_npu_grouped_matmul.side_effect = [[
torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_dequant.return_value = (torch.randn(10,
40,
dtype=torch.bfloat16),
torch.randn(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
w1_scale = torch.randn(5, 40, dtype=torch.float32)
w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=None,
with_quant=True)
mock_get_forward_context.assert_called()
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)
mock_npu_dynamic_quant.assert_called()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_dequant.assert_called_once()
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization(self,
mock_npu_dynamic_quant,
mock_npu_swiglu,
mock_npu_grouped_matmul,
mock_is_310p):
mock_is_310p.return_value = False
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.float16)
], [torch.randn(10, 20, dtype=torch.float16)]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales,
with_quant=False)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = "NOT_MC2"
mock_get_forward_context.return_value = mock_forward_context
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.bfloat16)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_swiglu.return_value = torch.randn(10,
40,
dtype=torch.bfloat16)
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 40),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=provided_dynamic_scale,
group_list_type=1,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=None,
with_quant=True)
mock_get_forward_context.assert_called()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
mock_npu_dynamic_quant.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization_310p(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_is_310p):
mock_is_310p.return_value = True
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
[mock_gmm2_out]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales,
with_quant=False)
mock_is_310p.assert_called_once()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)

View File

@@ -0,0 +1,53 @@
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
@pytest.fixture
def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)
def mock_rms_norm(x, weight, eps):
return x + 1, None
def mock_add_rms_norm(x, residual, weight, eps):
return 2 * x, None, 2 * residual
@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
[None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
residual, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = RMSNorm(hidden_size=32, eps=1e-05)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
if is_310p_return:
expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
expected_out_x = expected_arg_x + 1
expected_out_residual = expected_arg_x.to(residual.dtype)
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual
mock_add_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
out_x = layer.forward(dummy_tensor, residual)
expected_out_x = dummy_tensor + 1
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)

363
tests/ut/ops/test_linear.py Normal file
View File

@@ -0,0 +1,363 @@
import os
import unittest
from unittest import mock
import torch
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
AscendMlpMergedColumnParallelLinear,
AscendMlpRowParallelLinear, LinearBase,
QuantizationConfig)
class TestAscendMlpRowParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
self.tensor_parallel_world_size = 2
self.tensor_parallel_rank = 0
self.mlp_tensor_parallel_world_size = 2
self.mlp_tensor_parallel_rank = 1
self.get_tensor_model_parallel_world_size_patch = mock.patch(
'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
return_value=self.tensor_parallel_world_size)
self.get_tensor_model_parallel_rank_patch = mock.patch(
'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
return_value=self.tensor_parallel_rank)
self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
return_value=self.mlp_tensor_parallel_world_size)
self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
return_value=self.mlp_tensor_parallel_rank)
self.get_tensor_model_parallel_world_size_mock = \
self.get_tensor_model_parallel_world_size_patch.start()
self.get_tensor_model_parallel_rank_mock = \
self.get_tensor_model_parallel_rank_patch.start()
self.get_mlp_tensor_model_parallel_world_size_mock = \
self.get_mlp_tensor_model_parallel_world_size_patch.start()
self.get_mlp_tensor_model_parallel_rank_mock = \
self.get_mlp_tensor_model_parallel_rank_patch.start()
self.split_tensor_along_last_dim_patch = mock.patch(
'vllm_ascend.ops.linear.split_tensor_along_last_dim',
return_value=(torch.randn(10, 8), torch.randn(10, 8)))
self.tensor_model_parallel_all_reduce_patch = mock.patch(
'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
return_value=torch.randn(10, 8))
self.tensor_model_parallel_all_reduce_mock = \
self.tensor_model_parallel_all_reduce_patch.start()
self.split_tensor_along_last_dim_mock = \
self.split_tensor_along_last_dim_patch.start()
self.get_mlp_tp_group_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
mock.MagicMock()
def tearDown(self):
self.get_tensor_model_parallel_world_size_patch.stop()
self.get_tensor_model_parallel_rank_patch.stop()
self.get_mlp_tensor_model_parallel_world_size_patch.stop()
self.get_mlp_tensor_model_parallel_rank_patch.stop()
self.split_tensor_along_last_dim_patch.stop()
self.tensor_model_parallel_all_reduce_patch.stop()
self.get_mlp_tp_group_patch.stop()
def test_init_with_down_proj_prefix(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
prefix="down_proj")
self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
self.assertTrue(layer.enable_mlp_optimze)
def test_forward_with_mlp_optimize(self):
layer = AscendMlpRowParallelLinear(
input_size=16,
output_size=8,
prefix="down_proj",
input_is_parallel=False,
)
input_tensor = torch.randn(16, 8) # (batch_size, input_size)
layer(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once_with(
input_tensor, num_partitions=layer.tp_size)
def test_forward_without_mlp_optimize(self):
layer = AscendMlpRowParallelLinear(
input_size=16,
output_size=8,
prefix="other",
input_is_parallel=False,
)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once_with(
input_tensor, num_partitions=layer.tp_size)
self.tensor_model_parallel_all_reduce_mock.assert_called_once()
def test_skip_bias_add(self):
layer = AscendMlpRowParallelLinear(
input_size=16,
output_size=8,
skip_bias_add=True,
)
input_tensor = torch.randn(16, 8)
output, bias = layer(input_tensor)
self.assertIsNotNone(bias)
def test_no_reduce_results(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
reduce_results=False,
bias=False)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.tensor_model_parallel_all_reduce_mock.assert_not_called()
def test_input_not_parallel(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
input_is_parallel=False)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once()
def test_exception_when_reduce_false_and_bias(self):
with self.assertRaises(ValueError):
AscendMlpRowParallelLinear(input_size=16,
output_size=8,
reduce_results=False,
bias=True,
skip_bias_add=False)
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
# Mock distributed functions
self.mlp_tp_size_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group
self.mlp_tp_rank_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank
self.tp_size_patch = \
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
self.tp_size_mock = self.tp_size_patch.start()
self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group
self.tp_rank_patch = \
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
self.tp_rank_mock = self.tp_rank_patch.start()
self.tp_rank_mock.return_value = 1 # Current GPU rank
# Mock divide function (assumed to be in your module)
self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
self.divide_mock = self.divide_patch.start()
self.divide_mock.side_effect = lambda x, y: x // y # Simulate division
# Mock QuantizationConfig and QuantMethod
self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
# Mock LinearBase initialization
self.linear_base_init_patch = mock.patch.object(
LinearBase, "__init__", side_effect=self.mock_linear_base_init)
self.linear_base_init_patch.start()
self.quant_method_mock = mock.MagicMock()
def mock_linear_base_init(self, instance, *args, **kwargs):
instance.quant_method = self.quant_method_mock
instance.params_dtype = mock.MagicMock()
instance.input_size = 16
instance.output_size = 8
instance.output_size_per_partition = 4
instance.params_dtype = torch.float32
def tearDown(self):
self.mlp_tp_size_patch.stop()
self.mlp_tp_rank_patch.stop()
self.tp_size_patch.stop()
self.tp_rank_patch.stop()
self.divide_patch.stop()
self.linear_base_init_patch.stop()
def test_mlp_optimize_initialization(self):
# Test when prefix contains "gate_up_proj"
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
prefix="model.layers.0.gate_up_proj",
bias=False,
)
# Verify MLP optimization flags
self.assertTrue(layer.enable_mlp_optimze)
self.assertEqual(layer.tp_size, 2)
self.assertEqual(layer.tp_rank, 0)
self.assertEqual(layer.input_size_per_partition, 16)
self.assertEqual(layer.output_size_per_partition, 4)
# Check quant_method.create_weights was called
self.quant_method_mock.create_weights.assert_called_once()
def test_regular_parallel_initialization(self):
# Test when prefix does NOT contain "gate_up_proj"
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
prefix="model.layers.0.q_proj",
quant_config=self.quant_config_mock,
bias=False,
)
# Verify regular TP flags
self.assertFalse(layer.enable_mlp_optimze)
self.assertEqual(layer.tp_size, 4)
self.assertEqual(layer.tp_rank, 1)
self.assertEqual(layer.input_size_per_partition, 16)
self.assertEqual(layer.output_size_per_partition, 4)
# Check quant_method.create_weights was called
self.quant_method_mock.create_weights.assert_called_once()
def test_output_sizes_handling(self):
# Test when output_sizes is provided
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
output_sizes=[4, 4],
prefix="model.layers.0.qkv_proj",
quant_config=self.quant_config_mock,
bias=False,
)
# Verify output_partition_sizes
self.assertEqual(layer.output_partition_sizes, [2])
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
# Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
self.mlp_world_size_patch = \
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
self.tensor_world_size_patch = \
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
self.mlp_world_size_patch.start()
self.tensor_world_size_patch.start()
# Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
self.mlp_rank_patch = \
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
self.tensor_rank_patch = \
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
self.mlp_rank_patch.start()
self.tensor_rank_patch.start()
# Mock all_gather methods
self.get_mlp_tp_group_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
self.tensor_model_parallel_all_gather_patch = mock.patch(
'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
return_value=torch.randn(10, 8))
self.tensor_model_parallel_all_gather_mock = \
self.tensor_model_parallel_all_gather_patch.start()
# Mock AscendMlpColumnParallelLinear's __init__
self.linear_init_patch = mock.patch.object(
AscendMlpColumnParallelLinear,
"__init__",
side_effect=self.mock_linear_init)
self.linear_init_patch.start()
# Create mock objects
self.quant_method_mock = mock.MagicMock()
self.apply_output = torch.randn(2, 8)
self.quant_method_mock.apply.return_value = self.apply_output
def mock_linear_init(self, instance, *args, **kwargs):
torch.nn.Module.__init__(instance)
# Set quant_method and other attributes
instance.quant_method = self.quant_method_mock
instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias
instance.input_size = 16
instance.output_size = 8
instance.gather_output = False
instance.skip_bias_add = False
instance.return_bias = True
def test_forward_with_enable_mlp_optimze(self):
# Setup input
input_tensor = torch.randn(1, 16)
# Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
output_sizes=[8],
bias=True,
gather_output=False,
skip_bias_add=False,
params_dtype=torch.float32,
quant_config=None,
prefix="other_proj")
# Call forward
output, bias = layer(input_tensor)
# Validate calls
self.assertEqual(output.shape, self.apply_output.shape)
def test_forward_without_enable_mlp_optimze(self):
# Setup input
input_tensor = torch.randn(1, 16)
# Create instance with prefix not containing "gate_up_proj"
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
output_sizes=[8],
bias=True,
gather_output=False,
skip_bias_add=False,
params_dtype=torch.float32,
quant_config=None,
prefix="other_proj")
# Call forward
output, bias = layer(input_tensor)
# Validate calls
self.quant_method_mock.apply.assert_called_once_with(
layer, input_tensor, layer.bias)
self.tensor_model_parallel_all_gather_mock.assert_not_called()
self.assertEqual(output.shape, self.apply_output.shape)
def tearDown(self):
self.linear_init_patch.stop()
self.mlp_world_size_patch.stop()
self.tensor_world_size_patch.stop()
self.mlp_rank_patch.stop()
self.tensor_rank_patch.stop()
self.get_mlp_tp_group_mock.stop()
self.tensor_model_parallel_all_gather_mock.stop()

View File

@@ -0,0 +1,318 @@
import math
import unittest
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from tests.ut.base import TestBase
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
def test_custom_rotary_embedding_enabled(self):
# Test when all conditions are True
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertTrue(result)
# Test when dtype is not float16
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
query = self.query.to(torch.float32)
result = _custom_rotary_embedding_enabled(query, True,
self.head_size)
self.assertFalse(result)
# Test when neox_style is False
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, False,
self.head_size)
self.assertFalse(result)
# Test when head_size is not divisible by 32
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
self.assertFalse(result)
# Test when custom op is disabled
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=False):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertFalse(result)
class TestAscendRotaryEmbedding(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
self.head_size = 32
self.rotary_dim = self.head_size
self.max_position = 16
self.rope_theta = 10000
self.is_neox_style = True
self.cos_sin_cache = torch.randn(3, 1, 32)
self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
self.max_position, self.rope_theta,
self.is_neox_style, torch.float16)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = self.is_neox_style
@patch('torch.ops._C')
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled, mock_is_310p,
mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock__c.rotary_embedding.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = self.layer.forward(self.positions,
non_contig_query,
non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
def test_rope_forward_oot_with_offsets(self):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
self.layer.forward(self.positions, self.query, self.key, offsets)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test neox_style override
result_q, result_k = self.layer.forward(self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
class MockRopeModule:
def __init__(self, max_seq_len=2048, is_neox_style=True):
self.max_seq_len = max_seq_len
self.is_neox_style = is_neox_style
self.cos_cached = None
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
self.head_size = 32
self.rotary_dim = self.head_size
self.max_position = 16
self.rope_theta = 10000
self.is_neox_style = True
self.scaling_factor = 1
self.layer = None
def _create_layer(self):
self.layer = DeepseekScalingRotaryEmbedding(
self.head_size, self.rotary_dim, self.max_position,
self.rope_theta, self.is_neox_style, self.scaling_factor,
torch.float16)
return self.layer
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
return_value=(self.query,
self.key)) as mock_rope_forward_oot:
q_pe, k_pe = self.layer.forward(self.positions, self.query,
self.key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_cache_handling(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
self.layer.max_seq_len = 1024
# Test cache situation is true
with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
mock_rope_forward_oot.return_value = (self.query, self.key)
q_pe, k_pe = self.layer.forward(self.positions,
self.query,
self.key,
max_seq_len=2048)
mock_set_cache.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_key_reshaping(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
key = torch.randn(1, 32)
mock_rope_forward_oot.return_value = (self.query, key)
q_pe, k_pe = self.layer.forward(self.positions, self.query, key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
mock_rope_forward_oot.return_value = (self.query, self.key)
q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_basic_case(self, mock_npuplatform):
# Test with standard values
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
num_rotations = 100
dim = 512
base = 10000
max_position_embeddings = 2048
result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
# Calculate expected value manually
expected = (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
self.assertTrue(torch.allclose(result, expected))
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_yarn_get_mscale(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
# test_scale_less_than_or_equal_1
self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0)
# test_scale_greater_than_1:
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
(math.e, 1.0, 1.0 + 0.1)]
for scale, mscale, expected in test_cases:
result = self.layer._yarn_get_mscale(scale, mscale)
self.assertAlmostEqual(
result,
expected,
places=6,
msg=f"Failed for scale={scale}, mscale={mscale}")

View File

@@ -0,0 +1,606 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
class TestTokenDispatcherWithMC2(TestBase):
def setUp(self):
self.mc2_group = MagicMock()
self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
self.mc2_group.rank_in_group = 0
self.mc2_group.world_size = 8
self.mc2_group_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
return_value=self.mc2_group)
self.mc2_group_patch.start()
self.rank_group_patch = patch("torch.distributed.get_rank",
return_value=0)
self.rank_group_patch.start()
# Mock get_forward_context().mc2_mask
self.forward_context = MagicMock()
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
self.forward_context_patch = patch(
"vllm.forward_context.get_forward_context",
return_value=self.forward_context)
self.forward_context_patch.start()
# Mock get_ascend_soc_version()
self.ascend_soc_version_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start()
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.mc2_group_patch.stop()
self.forward_context_patch.stop()
self.ascend_soc_version_patch.stop()
def test_init(self):
self.assertEqual(self.dispatcher.ep_rank_id, 0)
self.assertEqual(self.dispatcher.ep_world_size, 8)
self.assertFalse(self.dispatcher.with_quant)
self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args)
self.assertTrue(self.dispatcher.a3_need_extra_args)
def test_get_dispatch_mc2_kwargs_without_quant(self):
hidden_states = torch.randn(10, 128)
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map)
self.assertIn("x", kwargs)
self.assertIn("expert_ids", kwargs)
self.assertEqual(kwargs["moe_expert_num"], 8)
def test_token_permutation_dispatch(self):
hidden_states = torch.randn(10, 128)
topk_weights = torch.randn(10, 1)
topk_ids = torch.randint(0, 8, (10, 1))
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
1) # group_list_type == 1
def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
torch.tensor(1.0))
self.shared_experts.act_fn.return_value = torch.randn(10, 128)
self.dispatcher.with_quant = False
self.dispatcher.shared_act = torch.randn(10, 128)
self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
self.hidden_states = torch.randn(10, 128)
self.topk_weights = torch.randn(10, 1)
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5):
self.dispatcher.token_dispatch(self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
self.row_idx,
torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128)
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.output = torch.randint(0, 8, (10, 1))
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
self.assertIn("tp_send_counts", kwargs)
def test_token_combine_with_shared_experts(self):
self.dispatcher.shared_experts = MagicMock()
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
10, 128), torch.tensor(1.0))
self.dispatcher.shared_act = torch.randn(10, 128)
self.dispatcher.with_quant = True
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
self.dispatcher.output = torch.randint(0, 8, (10, 1))
self.hidden_states = torch.randn(10, 128)
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
self.dispatcher.token_combine(self.hidden_states)
class TestTokenDispatcherWithAllGather(TestBase):
def setUp(self):
# Mock dependencies
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
"with_quant": False,
}
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
# Mock NPU functions
self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
self.mock_moe_init_routing.return_value = (
torch.randn(6, 128), # sorted_hidden_states
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
)
self.patcher_moe_compute_expert_tokens = patch(
'torch_npu.npu_moe_compute_expert_tokens')
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
)
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
[3, 3]) # expert_tokens
self.patcher_moe_finalize_routing = patch(
'torch_npu.npu_moe_finalize_routing')
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
)
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.patcher_moe_init_routing.stop()
self.patcher_moe_compute_expert_tokens.stop()
self.patcher_moe_finalize_routing.stop()
def test_token_dispatch_without_expert_map(self):
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
# Verify npu_moe_init_routing is called
self.mock_moe_init_routing.assert_called_once()
args, kwargs = self.mock_moe_init_routing.call_args
self.assertEqual(results["group_list_type"], 0)
def test_token_dispatch_with_quant(self):
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
}
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, None)
self.assertEqual(results["group_list_type"], 0)
def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
# Verify index_add_ is applied correctly
self.assertEqual(final_hidden_states.shape, (3, 128))
def test_token_combine_without_expert_map(self):
self.dispatcher.with_quant = False
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
# Verify npu_moe_finalize_routing is called
self.mock_moe_finalize_routing.assert_called_once()
args, kwargs = self.mock_moe_finalize_routing.call_args
self.assertEqual(final_hidden_states.shape, (3, 128))
def test_token_dispatch_with_router_weight(self):
self.dispatcher.apply_router_weight_on_input = True
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
topk_ids = torch.tensor([[0], [1], [2]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
self.assertEqual(results["hidden_states"].shape, (6, 128))
class TestTokenDispatcherWithAll2AllV(TestBase):
def setUp(self):
# Patch properties
patcher1 = patch.object(TokenDispatcherWithAll2AllV,
'ep_group',
new_callable=PropertyMock,
return_value=MagicMock())
patcher2 = patch.object(TokenDispatcherWithAll2AllV,
'ep_rank',
new_callable=PropertyMock,
return_value=0)
patcher3 = patch.object(TokenDispatcherWithAll2AllV,
'ep_size',
new_callable=PropertyMock,
return_value=2)
self.addCleanup(patcher1.stop)
self.addCleanup(patcher2.stop)
self.addCleanup(patcher3.stop)
self.mock_ep_group_prop = patcher1.start()
self.mock_ep_rank_prop = patcher2.start()
self.mock_ep_size_prop = patcher3.start()
# Mock torch_npu.npu_moe_token_permute
patcher4 = patch('torch_npu.npu_moe_token_permute')
self.mock_npu_moe_token_permute = patcher4.start()
self.addCleanup(patcher4.stop)
self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
torch.arange(16))
# Mock torch_npu.npu_moe_token_unpermute
patcher5 = patch('torch_npu.npu_moe_token_unpermute')
self.mock_npu_moe_token_unpermute = patcher5.start()
self.addCleanup(patcher5.stop)
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
# Mock async_all_to_all
patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all')
self.mock_async_all_to_all = patcher6.start()
self.addCleanup(patcher6.stop)
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
MagicMock())
# Mock gather_from_sequence_parallel_region
patcher7 = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
)
self.mock_gather_from_sequence_parallel_region = patcher7.start()
self.addCleanup(patcher7.stop)
self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor(
[[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64)
# Mock torch.histc
patcher8 = patch('torch.histc')
self.mock_histc = patcher8.start()
self.addCleanup(patcher8.stop)
self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
dtype=torch.int64)
# Mock torch.npu.current_device
patcher9 = patch('torch.npu.current_device')
self.mock_current_device = patcher9.start()
self.addCleanup(patcher9.stop)
self.mock_current_device.return_value = 'cpu'
# Mock torch_npu.npu_dynamic_quant
patcher10 = patch('torch_npu.npu_dynamic_quant')
self.mock_npu_dynamic_quant = patcher10.start()
self.addCleanup(patcher10.stop)
self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
torch.randn(16))
# Mock torch_npu.npu_moe_init_routing_v2
patcher11 = patch('torch_npu.npu_moe_init_routing_v2')
self.mock_npu_moe_init_routing_v2 = patcher11.start()
self.addCleanup(patcher11.stop)
self.mock_npu_moe_init_routing_v2.return_value = (torch.randn(
16, 16), torch.arange(16), None, torch.randn(16))
# Mock torch.repeat_interleave
patcher12 = patch('torch.repeat_interleave')
self.mock_repeat_interleave = patcher12.start()
self.addCleanup(patcher12.stop)
self.mock_repeat_interleave.return_value = torch.arange(16)
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2,
with_quant=False)
self.row_idx = torch.arange(10, dtype=torch.int32)
def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
def test_token_combine(self):
self.dispatcher.hidden_shape = (8, 16)
self.dispatcher.hidden_shape_before_permute = (8, 16)
self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
8)
self.dispatcher.topk_weights = torch.rand(8, 4)
self.dispatcher.input_splits = [4, 4]
self.dispatcher.output_splits = [4, 4]
self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
16)
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
[[2, 2], [2, 2]], dtype=torch.int64)
expert_output = torch.randn(16, 16)
output = self.dispatcher.token_combine(expert_output)
self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16))
def test_token_dispatch_with_quant(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2)
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
def test_token_dispatch_with_quant_no_active_tokens(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2)
self.mock_repeat_interleave.return_value = torch.tensor(
[], dtype=torch.long)
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
def test_token_dispatch_with_log2phy(self):
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
log2phy = torch.tensor([1, 0, 3, 2])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
log2phy=log2phy)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
class TestDispatcherRegistry(TestBase):
def setUp(self):
_Dispatchers.clear()
def tearDown(self):
_Dispatchers.clear()
def test_register_and_get_token_dispatcher(self):
mock_dispatcher = MagicMock()
mock_dispatcher.__class__.__name__ = "MockDispatcher"
_register_token_dispatcher(mock_dispatcher)
self.assertIn("MockDispatcher", _Dispatchers)
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
self.assertIs(retrieved_dispatcher, mock_dispatcher)
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
self, mock_register, mock_allgather_class):
kwargs = {"top_k": 2, "num_experts": 8}
mock_instance = MagicMock()
mock_allgather_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
setup_token_dispatchers(ep_size=1, **kwargs)
mock_allgather_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
self, mock_register, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
mock_instance = MagicMock()
mock_all2allv_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
setup_token_dispatchers(ep_size=2, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_all2allv_instance = MagicMock()
mock_mc2_instance = MagicMock()
mock_all2allv_class.return_value = mock_all2allv_instance
mock_mc2_class.return_value = mock_mc2_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_mc2_class.assert_called_once_with(**kwargs)
self.assertEqual(mock_register.call_count, 2)
mock_register.assert_any_call(mock_all2allv_instance)
mock_register.assert_any_call(mock_mc2_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_existing_all2allv = MagicMock()
mock_existing_mc2 = MagicMock()
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_not_called()
mock_mc2_class.assert_not_called()
mock_register.assert_not_called()
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
mock_existing_all2allv)
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
mock_existing_mc2)

View File

@@ -0,0 +1,232 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/lora/test_layers.py
import unittest
from unittest.mock import MagicMock, patch
import torch
from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
class TestCustomVocabParallelEmbedding(unittest.TestCase):
def setUp(self):
self.num_embeddings = 50
self.embedding_dim = 10
self.org_num_embeddings = 40
self.padding_size = 8
def _create_layer(self):
# Patch methods and dependencies for VocabParallelEmbedding
mock_group = MagicMock()
mock_group.world_size = 2
mock_group.rank_in_group = 0
with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
# Create an instance of VocabParallelEmbedding
layer = AscendVocabParallelEmbedding(
num_embeddings=self.num_embeddings,
embedding_dim=self.embedding_dim,
org_num_embeddings=self.org_num_embeddings,
padding_size=self.padding_size,
quant_config=None, # Mock quantization config
prefix="")
layer.shard_indices = MagicMock()
layer.shard_indices.org_vocab_start_index = 10
layer.shard_indices.org_vocab_end_index = 20
layer.shard_indices.num_org_vocab_padding = 5
layer.shard_indices.added_vocab_start_index = 30
layer.shard_indices.added_vocab_end_index = 40
# Mock the quantization method
layer.quant_method.embedding = MagicMock(
side_effect=lambda _, x: torch.randn(x.shape[0], self.
embedding_dim))
return layer
def test_get_masked_input_and_mask(self):
"""Test the mask and offset calculation helper function."""
layer = self._create_layer()
input_ = torch.tensor([5, 15, 25, 35, 45])
masked_input, mask = layer._get_masked_input_and_mask(
input_,
org_vocab_start_index=10,
org_vocab_end_index=20,
num_org_vocab_padding=5,
added_vocab_start_index=30,
added_vocab_end_index=40)
expected_mask = torch.tensor([True, False, True, False, True])
self.assertTrue(
torch.equal(mask, expected_mask),
f"Mask mismatch. Expected {expected_mask}, got {mask}")
expected_masked = torch.tensor([0, 5, 0, 20, 0])
self.assertTrue(
torch.equal(masked_input, expected_masked),
f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
)
def test_forward_with_tp_size_1(self):
"""Test forward pass without tensor parallelism."""
# Create a fresh mock embedding with tp_size=1
layer = self._create_layer()
layer.tp_size = 1
layer.quant_method.embedding = MagicMock(
return_value=torch.randn(3, layer.embedding_dim))
input_ = torch.tensor([1, 2, 3])
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp1:
output = layer.forward(input_)
# Should just pass through without masking
layer.quant_method.embedding.assert_called_once_with(
layer, input_.long())
self.assertEqual(output.shape, (3, layer.embedding_dim))
# Verify all_reduce was called once
mock_reduce_tp1.assert_called_once()
def test_forward_with_tp(self):
layer = self._create_layer()
layer.tp_size = 2
input_ = torch.tensor([15, 35]) # one org vocab, one added vocab
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp:
# Call the forward method
output = layer.forward(input_)
# Check that masking was applied correctly
layer.quant_method.embedding.assert_called_once()
called_input = layer.quant_method.embedding.call_args[0][1]
expected_input = torch.tensor([5, 20]) # after offset calculation
self.assertTrue(torch.all(called_input == expected_input))
# Check that all reduce was called
mock_reduce_tp.assert_called_once()
self.assertEqual(output.shape, (2, self.embedding_dim))
def test_forward_with_invalid_vocab(self):
"""Test that invalid vocab indices are properly masked out."""
# Create a fresh embedding layer
layer = self._create_layer()
input_ = torch.tensor([5, 15, 25, 35, 45]) # includes invalid cases
# Create predictable mock output
mock_output = torch.randn(5, self.embedding_dim)
layer.quant_method.embedding = MagicMock(
return_value=mock_output.clone())
# Patch tensor_model_parallel_all_reduce to mock its behavior
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x):
# Call the forward method
output = layer.forward(input_)
# Check that invalid positions (0, 2, 4) were zeroed out
self.assertTrue(torch.all(output[0] == 0))
self.assertTrue(torch.all(output[2] == 0))
self.assertTrue(torch.all(output[4] == 0))
self.assertTrue(torch.all(output[1] == mock_output[1]))
self.assertTrue(torch.all(output[3] == mock_output[3]))
self.assertEqual(output.shape, (5, self.embedding_dim))
def test_output_shape(self):
"""Test that output shape is correct."""
# Create a fresh embedding layer
layer = self._create_layer()
test_cases = [
(torch.tensor([15]), (1, self.embedding_dim)),
(torch.tensor([15, 35]), (2, self.embedding_dim)),
(torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
]
for input_, expected_shape in test_cases:
with self.subTest(input=input_):
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x):
# Call the forward method
output = layer.forward(input_)
self.assertEqual(output.shape, expected_shape)
class TestAscendLogitsProcessor(unittest.TestCase):
def setUp(self):
self.vocab_size = 50
self.num_embeddings = 50
self.embedding_dim = 10
self.org_num_embeddings = 40
self.padding_size = 8
self.mock_group = MagicMock()
self.mock_group.world_size = 2
self.mock_group.rank_in_group = 0
self.mock_ascend_config = MagicMock()
self.mock_quant_method = MagicMock()
self.mock_quant_method.apply = MagicMock(
return_value=torch.randn(1, self.vocab_size))
self.patches = [
patch("vllm_ascend.ascend_config.get_ascend_config",
return_value=self.mock_ascend_config),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
return_value=self.mock_group),
patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
return_value=True),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
return_value=torch.randn(1, self.vocab_size))
]
for p in self.patches:
p.start()
def tearDown(self):
for p in self.patches:
p.stop()
def test_create_processor(self):
processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
self.assertEqual(processor.vocab_size, self.vocab_size)
def test_get_logits(self):
processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
embedding_dim=self.embedding_dim,
prefix="lm_head")
lmhead.quant_method = self.mock_quant_method
lmhead.quant_method.apply = self.mock_quant_method.apply
hidden_state = torch.randn(1, self.org_num_embeddings)
processor._get_logits(hidden_state, lmhead)
self.mock_quant_method.apply.assert_called_once()

Some files were not shown because too many files have changed in this diff Show More