Files
2026-04-24 09:58:03 +08:00

32 lines
904 B
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple
import torch
from vllm.utils.math_utils import cdiv
def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``A`` to fp8 using per-token-group (blocked) quantization.

    Only the blocked path is supported: ``block_shape`` is required and
    must have exactly two entries. Only ``block_shape[1]`` is used; it is
    passed to ``per_token_group_quant_fp8`` as the group size. The first
    entry is ignored.

    Args:
        A: Tensor to quantize.
        A_scale: Ignored. The scales are always recomputed by
            ``per_token_group_quant_fp8`` and returned in its place.
        block_shape: Two-element block shape whose second entry is the
            quantization group size.

    Returns:
        ``(A_q, A_scale)``: the quantized tensor and its scales, as returned
        by ``per_token_group_quant_fp8``.

    Raises:
        ValueError: If ``block_shape`` is ``None`` or does not have exactly
            two entries.
    """
    # Imported lazily inside the function, as in the original code
    # (presumably to avoid an import cycle / import-time cost — TODO confirm).
    from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)

    # Explicit checks instead of ``assert`` so the argument validation still
    # runs when Python is started with ``-O``.
    if block_shape is None:
        raise ValueError("_fp8_quantize requires a block_shape")
    if len(block_shape) != 2:
        raise ValueError(
            f"block_shape must have exactly 2 entries, got {block_shape!r}")
    block_k = block_shape[1]

    A, A_scale = per_token_group_quant_fp8(A, block_k)
    # Internal sanity check on the quantizer's output: the scale tensor's
    # last dimension must equal ceil(A.shape[-1] / block_k).
    assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
    return A, A_scale