148 lines
4.8 KiB
Python
148 lines
4.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from math import prod
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|
per_token_group_quant_fp8)
|
|
# lmslim is an optional dependency providing the int8 quantization helpers
# used by _int8_quantize. Import it lazily and degrade gracefully when it
# is not installed.
try:
    from lmslim.layers.gemm.int8_utils import (
        per_token_group_quant_int8, per_token_quant_int8)
except ImportError:
    # Only catch ImportError: a broad `except Exception` would also hide
    # genuine bugs raised while importing lmslim itself.
    print("INFO: Please install lmslim if you want to use int utils.\n")
|
|
from vllm.utils import cdiv
|
|
|
|
|
|
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
|
|
"""
|
|
Shrink the given tensor and apply the given view to it. This is
|
|
used to resize the intermediate fused_moe caches.
|
|
"""
|
|
assert prod(v) <= x.numel(
|
|
), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" # CUDAGRAPH unfriendly?
|
|
return x.flatten()[:prod(v)].view(*v)
|
|
|
|
|
|
def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``A`` to fp8 and return the quantized tensor with its scales.

    When ``block_shape`` is given the output is blocked along the last
    dimension (groups of ``block_shape[1]`` elements share one scale);
    otherwise quantization is tensor-wise, or per-token when
    ``per_act_token`` is set.
    """
    if block_shape is not None:
        # Blocked path: per-act-token scaling is mutually exclusive with
        # block quantization.
        assert not per_act_token
        assert len(block_shape) == 2
        block_k = block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        # One scale per block of block_k elements along the last dim.
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
    else:
        A, A_scale = ops.scaled_fp8_quant(
            A, A_scale, use_per_token_if_dynamic=per_act_token)

    return A, A_scale
|
|
|
|
|
|
def _int8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``A`` to int8 and return the quantized tensor with its scales.

    When ``block_shape`` is given the output is blocked along the last
    dimension; otherwise per-token quantization is required (tensor-wise
    int8 is not supported here).
    """
    # If weights are per-channel (per_channel_quant=True), then
    # activations apply per-token quantization. Otherwise, assume
    # activation tensor-wise fp8/int8 quantization, dynamic or static
    if block_shape is not None:
        # Blocked path: incompatible with per-act-token scaling.
        assert not per_act_token
        assert len(block_shape) == 2
        block_k = block_shape[1]
        A, A_scale = per_token_group_quant_int8(A, block_k)
        # One scale per block of block_k elements along the last dim.
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
    else:
        assert per_act_token, \
            "int8 quantization only supports block or channel-wise"
        A, A_scale = per_token_quant_int8(A)

    return A, A_scale
|
|
|
|
|
|
def moe_kernel_quantize_input(
|
|
A: torch.Tensor,
|
|
A_scale: Optional[torch.Tensor],
|
|
quant_dtype: Optional[torch.dtype],
|
|
per_act_token_quant: bool,
|
|
block_shape: Optional[list[int]] = None,
|
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
if quant_dtype == torch.float8_e4m3fn:
|
|
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
|
elif quant_dtype == torch.int8:
|
|
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
|
else:
|
|
return A, A_scale
|
|
|
|
|
|
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
A permutation routine that works on fp8 types.
|
|
"""
|
|
if torch.is_floating_point(m) and m.dtype.itemsize == 1:
|
|
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
|
else:
|
|
return m[idx, ...]
|
|
|
|
|
|
def normalize_scales_shape(
        scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """
    Canonicalize ``scales`` to a 2D view: a single-element tensor becomes
    shape (1, 1); anything else is flattened to (-1, last_dim). ``None``
    passes through unchanged.
    """
    if scales is None:
        return None
    shape = (1, 1) if scales.numel() == 1 else (-1, scales.size(-1))
    return scales.view(*shape)
|
|
|
|
|
|
def normalize_batched_scales_shape(
    scales: Optional[torch.Tensor],
    num_experts: int,
) -> Optional[torch.Tensor]:
    """
    Canonicalize ``scales`` to a 3D per-expert view of shape
    (num_experts, rows, last_dim). A single-element tensor is replicated
    once per expert as (num_experts, 1, 1). ``None`` or an already-3D
    tensor passes through unchanged.
    """
    if scales is None or scales.ndim >= 3:
        # Nothing to normalize: absent, or already batched per expert.
        return scales
    if scales.numel() == 1:
        # Single global scale: broadcast a copy to every expert.
        flat = scales.view(1)
        return torch.repeat_interleave(flat, num_experts,
                                       dim=0).view(num_experts, 1, 1)
    return scales.view(num_experts, -1, scales.size(-1))
|
|
|
|
|
|
def _validate_scale_shape(
|
|
a: torch.Tensor,
|
|
a_scale: Optional[torch.Tensor],
|
|
per_act_token_quant: bool,
|
|
block_shape: Optional[list[int]],
|
|
) -> None:
|
|
if a_scale is None:
|
|
return
|
|
|
|
if not per_act_token_quant and block_shape is None:
|
|
assert a_scale.numel() == 1, f"{a_scale.shape}"
|
|
elif per_act_token_quant:
|
|
assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, (
|
|
f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1")
|
|
else:
|
|
assert block_shape is not None
|
|
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
|
|
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
|