# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Optional

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

try:
    from lmslim.layers.gemm.int8_utils import (per_token_group_quant_int8,
                                               per_token_quant_int8)
except Exception:
    # lmslim is optional. Bind the helpers to None so that a later int8
    # request fails with an explicit message instead of a bare NameError.
    per_token_group_quant_int8 = None
    per_token_quant_int8 = None
    print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv


def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
    """
    Shrink the given tensor and apply the given view to it.  This is
    used to resize the intermediate fused_moe caches.

    The first ``prod(v)`` elements of the flattened ``x`` are viewed with
    shape ``v``; ``x`` must therefore hold at least ``prod(v)`` elements.
    """
    numel = prod(v)
    assert numel <= x.numel(
    ), f"{v} ({numel}) <= {x.shape} ({x.numel()})"  # CUDAGRAPH unfriendly?
    return x.flatten()[:numel].view(*v)


def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs.  If a block_shape
    is provided, the output will be blocked.

    Without a block_shape, A and A_scale are forwarded to
    ops.scaled_fp8_quant (per-token if per_act_token is set).  With a
    block_shape, the incoming A_scale is ignored and fresh group scales
    are computed with group size block_shape[1]; per_act_token must be
    False in that case.

    Returns the quantized tensor and its scales.
    """
    if block_shape is None:
        A, A_scale = ops.scaled_fp8_quant(
            A, A_scale, use_per_token_if_dynamic=per_act_token)
    else:
        assert not per_act_token
        assert len(block_shape) == 2
        # Only the K-dimension block size is needed for activations.
        block_k = block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

    return A, A_scale


def _int8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Perform int8 quantization on the inputs.  If a block_shape
    is provided, the output will be blocked.

    The incoming A_scale is not used: scales are always recomputed by the
    lmslim helpers.  Exactly one of per-token (per_act_token=True,
    block_shape=None) or block-wise (per_act_token=False, block_shape
    given) quantization must be requested.

    Raises ImportError if the lmslim int8 helpers could not be imported.
    """
    if per_token_quant_int8 is None or per_token_group_quant_int8 is None:
        raise ImportError(
            "int8 quantization requires lmslim, which could not be imported.")

    # If weights are per-channel (per_channel_quant=True), then
    # activations apply per-token quantization. Otherwise, assume
    # activation tensor-wise fp8/int8 quantization, dynamic or static
    if block_shape is None:
        assert per_act_token, \
            "int8 quantization only supports block or channel-wise"
        A, A_scale = per_token_quant_int8(A)
    else:
        assert not per_act_token
        assert len(block_shape) == 2
        # Only the K-dimension block size is needed for activations.
        block_k = block_shape[1]
        A, A_scale = per_token_group_quant_int8(A, block_k)
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

    return A, A_scale


def moe_kernel_quantize_input(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    quant_dtype: Optional[torch.dtype],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the activations A according to quant_dtype.

    torch.float8_e4m3fn dispatches to _fp8_quantize and torch.int8 to
    _int8_quantize.  For any other quant_dtype (including None) the
    inputs are returned unchanged.
    """
    if quant_dtype == torch.float8_e4m3fn:
        return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
    elif quant_dtype == torch.int8:
        return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
    else:
        return A, A_scale


def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    A permutation routine that works on fp8 types.

    One-byte floating point tensors are reinterpreted as uint8 for the
    row gather and viewed back to their original dtype afterwards; all
    other tensors are indexed directly.
    """
    if torch.is_floating_point(m) and m.dtype.itemsize == 1:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]
def normalize_scales_shape(
        scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """
    Reshape a scale tensor to two dimensions.

    A single-element tensor becomes shape (1, 1); anything else is viewed
    as (-1, scales.size(-1)).  None is passed through unchanged.
    """
    if scales is not None:
        if scales.numel() == 1:
            scales = scales.view(1, 1)
        else:
            scales = scales.view(-1, scales.size(-1))
    return scales


def normalize_batched_scales_shape(
    scales: Optional[torch.Tensor],
    num_experts: int,
) -> Optional[torch.Tensor]:
    """
    Reshape a scale tensor to three dimensions with a leading expert axis.

    A single-element tensor is repeated num_experts times and returned
    with shape (num_experts, 1, 1); other tensors of rank below 3 are
    viewed as (num_experts, -1, scales.size(-1)).  None and tensors that
    already have three or more dimensions are returned unchanged.
    """
    if scales is not None and scales.ndim < 3:
        if scales.numel() == 1:
            scales = scales.view(1)
            scales = torch.repeat_interleave(scales, num_experts,
                                             dim=0).view(num_experts, 1, 1)
        else:
            scales = scales.view(num_experts, -1, scales.size(-1))

    return scales


def _validate_scale_shape(
    a: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
) -> None:
    """
    Assert that a_scale has the shape implied by the quantization mode.

    - per-token: a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1
    - per-tensor (no per-token, no block_shape): a single element
    - block-wise: (a.shape[0], cdiv(a.shape[1], block_shape[1]))

    A missing scale (None) is accepted without checks.  Raises
    AssertionError on a mismatch.
    """
    if a_scale is None:
        return

    if per_act_token_quant:
        # Check the rank first so that a 1-D scale is reported through the
        # assertion rather than as an IndexError on shape[1].
        assert a_scale.ndim >= 2, f"{a_scale.shape}"
        assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, (
            f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1")
    elif block_shape is None:
        assert a_scale.numel() == 1, f"{a_scale.shape}"
    else:
        expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
        assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"