Sync from v0.13
This commit is contained in:
72
tests/kernels/quantization/test_per_token_group_quant.py
Normal file
72
tests/kernels/quantization/test_per_token_group_quant.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
|
||||
@pytest.mark.parametrize("column_major", [False, True])
|
||||
@pytest.mark.parametrize("scale_ue8m0", [False, True])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_per_token_group_quant_fp8(
|
||||
shape, column_major: bool, scale_ue8m0: bool, group_size: int
|
||||
):
|
||||
device = "cuda"
|
||||
|
||||
torch.manual_seed(42)
|
||||
num_tokens, hidden_dim = shape
|
||||
|
||||
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8
|
||||
|
||||
# cuda path
|
||||
out_q, scale = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
column_major_scales=column_major,
|
||||
use_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
column_major_scales=column_major,
|
||||
use_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
|
||||
assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_per_token_group_quant_int8(shape, group_size: int):
|
||||
device = "cuda"
|
||||
|
||||
torch.manual_seed(42)
|
||||
num_tokens, hidden_dim = shape
|
||||
|
||||
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8
|
||||
|
||||
# cuda path
|
||||
out_q, scale = int8_utils.per_token_group_quant_int8(
|
||||
x,
|
||||
group_size,
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
ref_q, ref_s = int8_utils.per_token_group_quant_int8(
|
||||
x,
|
||||
group_size,
|
||||
)
|
||||
|
||||
assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
|
||||
assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
|
||||
Reference in New Issue
Block a user