Sync from v0.13
This commit is contained in:
109
tests/quantization/test_gptq_v2.py
Normal file
109
tests/quantization/test_gptq_v2.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests whether vllm correctly load and run gptq_v2 format checkpoints.
|
||||
|
||||
Run `pytest tests/quantization/test_gptq_v2.py --forked`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
|
||||
# A dummy small model quantized by GPTQModel, stored in GPTQ v2 format
|
||||
MODELS = ["XXXXyu/Qwen3-1.7B-w2g64-gptq_v2"]
|
||||
|
||||
# Generate multiple sequences for testing, because an 1.7B 2-bit model
|
||||
# cannot always generate normal texts.
|
||||
N_SEQ = 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", MODELS)
|
||||
def test_model_load(vllm_runner, model_id, monkeypatch):
|
||||
# `LLM.apply_model` requires pickling a function.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
# Only check the default GPTQ linear method (used for 2/3-bit models).
|
||||
# 4/8-bit linear methods like Marlin already support gptq_v2.
|
||||
linear_method_cls = GPTQLinearMethod
|
||||
|
||||
with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm:
|
||||
|
||||
def check_model(model_id):
|
||||
for name, submodule in model_id.named_modules():
|
||||
# Could check more modules if necessary
|
||||
if name == "model_id.layers.0.self_attn.qkv_proj":
|
||||
assert isinstance(submodule.quant_method, linear_method_cls)
|
||||
|
||||
config = submodule.quant_method.quant_config
|
||||
assert config.checkpoint_format == "gptq_v2"
|
||||
assert submodule.quant_method.use_v2_format
|
||||
|
||||
# Just break since currently we only check 1 module
|
||||
break
|
||||
|
||||
# Check if gptq_v2 format is correctly loaded
|
||||
llm.apply_model(check_model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", MODELS)
|
||||
def test_model_inference(vllm_runner, model_id):
|
||||
# Prepare prompt to test the model's generation result.
|
||||
prompt = "What is the meaning of life?"
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False, # If thinking model, set it to false
|
||||
)
|
||||
sampling_params = SamplingParams(
|
||||
n=N_SEQ,
|
||||
max_tokens=128,
|
||||
temperature=0.7,
|
||||
top_p=0.8,
|
||||
top_k=20,
|
||||
min_p=0,
|
||||
presence_penalty=2,
|
||||
)
|
||||
|
||||
with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm:
|
||||
# Generate a response to verify inference correctness
|
||||
output = llm.generate(text, sampling_params)
|
||||
|
||||
# Make sure the output exists
|
||||
assert output
|
||||
assert output[0][1]
|
||||
assert len(output[0][1]) == N_SEQ
|
||||
|
||||
def has_normal_char_distribution(texts, min_len):
|
||||
for text in texts:
|
||||
# Response too short
|
||||
if len(text) < min_len:
|
||||
return False
|
||||
|
||||
# Basic ratio checks
|
||||
letters = sum(c.isalpha() for c in text)
|
||||
spaces = sum(c.isspace() for c in text)
|
||||
total = len(text)
|
||||
|
||||
letter_ratio = letters / total
|
||||
space_ratio = spaces / total
|
||||
|
||||
# At least 1 normal text should exist within output sequences
|
||||
# Normal text should be mostly letters with reasonable spacing
|
||||
# Some magic numbers, could be adjusted
|
||||
if 0.5 <= letter_ratio <= 0.9 and 0.01 <= space_ratio <= 0.3:
|
||||
return True
|
||||
# No sequence contains normal text, output might be broken
|
||||
return False
|
||||
|
||||
# Apply some simple checks for giberish output
|
||||
# Print the output sequences if failed
|
||||
assert has_normal_char_distribution(output[0][1], 5), output[0][1]
|
||||
Reference in New Issue
Block a user