Sync from v0.13
This commit is contained in:
78
tests/v1/tpu/test_tpu_int8.py
Normal file
78
tests/v1/tpu/test_tpu_int8.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests whether TPU Int8 computation is enabled correctly.
|
||||
|
||||
Run `pytest tests/v1/tpu/test_tpu_int8.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...models.registry import HF_EXAMPLE_MODELS
|
||||
|
||||
# Model identifiers that test_model_tpu_int8 is parametrized over.
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
    "hf_overrides",
    [
        # w8a8 dynamic activation
        {
            "quantization_config": {
                "quant_method": "tpu_int8",
                "activation_scheme": "dynamic",
            }
        }
    ],
)
def test_model_tpu_int8(
    vllm_runner,
    model: str,
    dtype: str,
    max_tokens: int,
    hf_overrides: dict,
    monkeypatch,
) -> None:
    """Check that the TPU Int8 linear method is applied and generation works.

    The test loads ``model`` through ``vllm_runner`` with ``hf_overrides``
    and then:

    1. asserts that every ``LinearBase`` module of the loaded model uses
       ``TPUInt8LinearMethod`` and that its ``quantize_activation`` flag
       matches the ``activation_scheme`` requested in ``hf_overrides``
       (``True`` only for ``"dynamic"``);
    2. greedily generates ``max_tokens`` tokens for a fixed prompt and
       asserts the expected answer is a substring of the generated text.

    Args:
        vllm_runner: Fixture used as a context manager to build the runner.
        model: Model identifier to load.
        dtype: Dtype string forwarded to the runner.
        max_tokens: Number of tokens passed to ``generate_greedy``.
        hf_overrides: Overrides forwarded to the runner; its
            ``quantization_config.activation_scheme`` entry (if any)
            determines the expected ``quantize_activation`` value.
        monkeypatch: pytest fixture used to set environment variables for
            the duration of the test.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_transformers_version(on_fail="skip")

    activation_scheme = hf_overrides.get("quantization_config", {}).get(
        "activation_scheme"
    )
    quantize_activation = activation_scheme == "dynamic"

    # Allows using apply_model
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    # Prevent error from re-initializing cache
    monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")

    prompts = [
        "A robot may not injure a human being",
    ]
    answers = [
        "or kill a human being",
    ]

    with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

        # NOTE: inside this helper ``model`` is the object handed over by
        # ``apply_model`` (it must expose ``named_modules()``), not the
        # model-name string of the enclosing test.
        def check_model(model):
            linear_layers = [
                (name, module)
                for name, module in model.named_modules()
                if isinstance(module, LinearBase)
            ]
            # Guard against a vacuous pass: with no linear layers the loop
            # below would assert nothing at all.
            assert linear_layers, "no LinearBase modules found in the model"

            for name, module in linear_layers:
                quant_method = module.quant_method
                assert isinstance(quant_method, TPUInt8LinearMethod), (
                    f"{name}: expected TPUInt8LinearMethod, "
                    f"got {type(quant_method).__name__}"
                )
                assert quant_method.quantize_activation == quantize_activation, (
                    f"{name}: quantize_activation is "
                    f"{quant_method.quantize_activation!r}, "
                    f"expected {quantize_activation!r}"
                )

        vllm.apply_model(check_model)

        outputs = vllm.generate_greedy(prompts, max_tokens)
        # zip() truncates silently, so make a length mismatch an explicit
        # failure instead of skipping answer checks.
        assert len(outputs) == len(answers), (
            f"expected {len(answers)} outputs, got {len(outputs)}"
        )
        for (_, output), answer in zip(outputs, answers):
            assert answer in output, (
                f"expected {answer!r} to appear in generated text {output!r}"
            )
|
||||
Reference in New Issue
Block a user