[Model] Support DeepSeek-V4
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
# Quantization format names accepted at the configuration surface.
QUANTIZATION_CHOICES = ['int8', 'int4', 'e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
# All torch integer dtypes, including aliases (torch.short == torch.int16,
# torch.int == torch.int32, torch.long == torch.int64).
# NOTE: the name keeps its historical misspelling ("INTERGER"); it is read by
# is_integer_dtype() and may be referenced elsewhere — do not rename without
# providing a backward-compatible alias.
INTERGER_DTYPES = [torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short,
                   torch.int32, torch.int, torch.int64, torch.long]
# All torch floating-point dtypes, including the four FP8 variants and the
# aliases (torch.float == float32, torch.double == float64, torch.half == float16).
FLOAT_DTYPES = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16,
                torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.half]
# The four FP8 torch dtypes and their corresponding string spellings
# (kept in the same order).
FP8_DTYPE = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
FP8_STR_DTYPE = ['e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
# Group sizes supported by the grouped GEMM kernels.
GEMM_GROUP_SIZE = [64, 128, 256, 512]


# String dtype name -> torch dtype (private; use str_dtype_to_torch()).
_STR_TO_TORCH_DTYPE_DICT = dict(
    bfloat16=torch.bfloat16,
    float16=torch.float16,
    float32=torch.float32,
    int64=torch.int64,
    int32=torch.int32,
    int8=torch.int8,
    bool=torch.bool,
    e4m3fn=torch.float8_e4m3fn,
    e4m3fnuz=torch.float8_e4m3fnuz,
    e5m2=torch.float8_e5m2,
    e5m2fnuz=torch.float8_e5m2fnuz,
)


# Inverse mapping: torch dtype -> string name (see torch_dtype_to_str()).
TORCH_DTYPE_TO_STR_DICT = {
    torch.bfloat16: "bfloat16",
    torch.float16: "float16",
    torch.float32: "float32",
    torch.int64: "int64",
    torch.int32: "int32",
    torch.int8: "int8",
    torch.bool: "bool",
    torch.float8_e4m3fn: "e4m3fn",
    torch.float8_e4m3fnuz: "e4m3fnuz",
    torch.float8_e5m2: "e5m2",
    torch.float8_e5m2fnuz: "e5m2fnuz",
}


# String dtype name -> bit width. Includes 'int4' (no torch dtype exists
# for it) and counts bool as 1 bit.
STR_DTYPE_TO_BITS_DICT = {
    "bfloat16": 16,
    "float16": 16,
    "float32": 32,
    "int64": 64,
    "int32": 32,
    "int8": 8,
    'int4': 4,
    "bool": 1,
    "e4m3fn": 8,
    "e4m3fnuz": 8,
    "e5m2": 8,
    "e5m2fnuz": 8,
}
|
||||
|
||||
|
||||
def str_dtype_to_torch(str_dtype: str) -> torch.dtype:
    '''
    Convert a string dtype name to the corresponding torch dtype.

    Args:
        str_dtype: Name such as 'bfloat16', 'int8' or 'e4m3fn'.

    Returns:
        The matching torch dtype; unknown names fall back to torch.float16.
    '''
    # The original docstring claimed the opposite direction (torch -> str);
    # this function maps str -> torch. `.get` with a default collapses the
    # lookup + None-check into one expression (the dict holds no None values).
    return _STR_TO_TORCH_DTYPE_DICT.get(str_dtype, torch.float16)
||||
|
||||
|
||||
def torch_dtype_to_str(dtype: torch.dtype) -> str:
    '''
    Convert a torch dtype to its string name.

    Args:
        dtype: A torch dtype such as torch.bfloat16 or torch.float8_e4m3fn.

    Returns:
        The matching name; unknown dtypes fall back to "float16".
    '''
    # `.get` with a default replaces the manual None-check of the original
    # (the dict holds no None values, so behavior is identical).
    return TORCH_DTYPE_TO_STR_DICT.get(dtype, "float16")
||||
|
||||
|
||||
def str_dtype_to_bits(str_dtype: str) -> int:
    '''
    Convert a string dtype name to its bit width.

    Args:
        str_dtype: Name such as 'float32', 'int4' or 'e5m2'.

    Returns:
        The bit width; unknown names fall back to 8 bits.
    '''
    # The original docstring said "torch dtype" but the argument is a string
    # name. `.get` with a default collapses the lookup + None-check.
    return STR_DTYPE_TO_BITS_DICT.get(str_dtype, 8)
||||
|
||||
|
||||
def is_integer_dtype(dtype: torch.dtype) -> bool:
    '''
    Return True when ``dtype`` is one of the known integer torch dtypes.
    '''
    for candidate in INTERGER_DTYPES:
        if dtype == candidate:
            return True
    return False
|
||||
|
||||
|
||||
def is_float_dtype(dtype: torch.dtype) -> bool:
    '''
    Return True when ``dtype`` is one of the known floating-point torch
    dtypes (including the FP8 variants).
    '''
    return any(dtype == candidate for candidate in FLOAT_DTYPES)
|
||||
|
||||
|
||||
def is_fp8_dtype(dtype: torch.dtype) -> bool:
    '''
    Return True when ``dtype`` is one of the FP8 torch dtypes.
    '''
    matched = False
    for candidate in FP8_DTYPE:
        if candidate == dtype:
            matched = True
            break
    return matched
|
||||
|
||||
|
||||
def is_fp8_str_dtype(str_dtype: str) -> bool:
    '''
    Return True when ``str_dtype`` names an FP8 format (e.g. 'e4m3fn').
    '''
    return any(name == str_dtype for name in FP8_STR_DTYPE)
|
||||
424
vllm_mlu/model_executor/layers/quantization/utils/fp8_utils.py
Normal file
424
vllm_mlu/model_executor/layers/quantization/utils/fp8_utils.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
||||
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform

from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)

'''
=============================
Modify by vllm_mlu
=============================
@brief: get total core for split triton kernel
'''

# MLU-specific: query device properties through the Triton MLU backend so
# kernels can be launched with one program per physical core.
import triton.backends.mlu.driver as driver

_devprob = driver.BangUtils().get_device_properties(torch.mlu.current_device())
# Number of clusters on the current MLU device.
TOTAL_CLUSTER_NUM = _devprob.get("cluster_num")
# Total physical core count (clusters * cores per cluster); used below as
# the fixed Triton launch-grid size.
TOTAL_CORE_NUM = TOTAL_CLUSTER_NUM * _devprob.get("core_num_per_cluster")

'''
==================
End of MLU Hijack
==================
'''
|
||||
|
||||
|
||||
def apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
    use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
    """Linear layer with block-wise FP8-quantized weights.

    Dynamically quantizes ``input`` per token group (group size
    ``block_size[1]``) and multiplies against the FP8 ``weight`` whose
    per-block scales are ``weight_scale``.

    Args:
        input: Activation tensor; flattened to 2D over its last dim.
        weight: FP8 weight; output features along dim 0.
        block_size: [block_n, block_k] used for weight quantization.
        weight_scale: Per-block scales for ``weight``.
        input_scale: Must be None — only dynamic activation quant is supported.
        bias: Optional bias added after the matmul.
        cutlass_block_fp8_supported: Whether the CUTLASS kernel may be used.
        use_aiter_and_is_supported: Unused here; kept for signature parity
            with the upstream vLLM function this replaces.

    Returns:
        Tensor of shape ``input.shape[:-1] + (weight.shape[0],)`` in
        ``input.dtype``.
    """
    assert input_scale is None
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    # CUTLASS path requires both weight dims to be multiples of 128.
    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
                                  and weight.shape[1] % 128 == 0)
    if current_platform.is_rocm():
        # TODO this is never used, as cutlass_block_fp8_supported is False
        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
                         input_2d.shape[:-1])[::-1]
        scale_b_shape = (weight_scale.view(-1, 1)
                         if weight_scale.dim() <= 1 else weight_scale.T).shape
        ar, ac = scale_a_shape
        br, bc = scale_b_shape
        if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
                or br not in (1, weight.shape[0])):
            shape_supported_by_cutlass = False
    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
        # CUTLASS requires column-major activation scales.
        # NOTE(review): this branch needs `ops` (vllm._custom_ops) importable
        # at module level; on MLU this path is presumably never taken —
        # confirm CUTLASS_BLOCK_FP8_SUPPORTED is False on this platform.
        q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                     block_size[1],
                                                     column_major_scales=True)
        output = ops.cutlass_scaled_mm(q_input,
                                       weight.T,
                                       out_dtype=input.dtype,
                                       scale_a=x_scale,
                                       scale_b=weight_scale.T)
    else:
        # Triton/MLU path: row-major scales, block-wise matmul below.
        q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                     block_size[1],
                                                     column_major_scales=False)
        output = w8a8_block_fp8_matmul(q_input,
                                       weight,
                                       x_scale,
                                       weight_scale,
                                       block_size,
                                       output_dtype=input.dtype)
    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)
||||
|
||||
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.
    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum to avoid dividing zero.
        dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn`
        is supported for now.
        column_major_scales: Lay scales out column-major (CUTLASS path).
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
        scaling factor for quantization.
    """
    dtype = current_platform.fp8_dtype() if dtype is None else dtype
    assert (x.shape[-1] % group_size == 0), (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}")
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    # Representable range of the target FP8 format.
    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    # M = number of groups = number of kernel programs launched below.
    M = x.numel() // group_size
    N = group_size
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: split for limit the memory usage(65536)
    '''
    # Double groups-per-program until the program count fits under the
    # 65536 launch limit.
    group_per_block = 1
    while M >= 65536:
        group_per_block *= 2
        M = x.numel() // (group_size * group_per_block)
    # NOTE(review): group_per_block is not forwarded to the colmajor kernel
    # launch below — confirm that kernel still covers every group when M
    # has been reduced here.
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if column_major_scales:
        # Allocate transposed then permute so the scale tensor's strides
        # are column-major.
        shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device,
                          dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size, )
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: set num_warps to 1 for triton-mlu
    '''
    num_warps = 1
    num_stages = 1
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M, )](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[1],
            x.stride(0),
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: replaced the 'scaled_quantize' kernel from the 'tmo' library with
                '_per_token_group_quant_fp8' kernel
        '''
        # Check if x is contiguous, if not, create a new tensor for contiguous x
        if not x.is_contiguous():
            x = x.contiguous()
        # Fold each group into its own trailing dim so the per-token mode of
        # scaled_quantize quantizes one group at a time.
        x_origin_shape = x.shape
        x = x.reshape(*x.shape[:-1], -1, group_size)
        # NOTE: the x_q / x_s preallocated above are discarded on this path
        # and replaced by the kernel's own outputs.
        x_q, x_s = mlu_ops.scaled_quantize(x,
                                           None,
                                           quant_type=dtype,
                                           quant_mode='dynamic_per_token')
        x_q = x_q.reshape(x_origin_shape)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

    return x_q, x_s
|
||||
|
||||
@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and
    store the result in output tensor `C`.

    `As` holds per-row activation scales (one per K-group), `Bs` holds
    per-block weight scales; both are applied inside the K loop.
    """

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: split for limit the memory usage(65536)
    '''
    # The launch grid is fixed (one program per core), so each program
    # iterates over its share of the num_pid_m * num_pid_n output tiles;
    # programs with pid < remainder take one extra tile.
    num_block_size_all = num_pid_m * num_pid_n
    num_block_size_per = num_block_size_all // tl.num_programs(axis=0)
    num_block_size_rem = num_block_size_all % tl.num_programs(axis=0)

    core_deal_num_block_size = num_block_size_per + (pid < num_block_size_rem)
    core_deal_num_block_start = num_block_size_per * pid + min(num_block_size_rem, pid)

    for pid_i in range(0, core_deal_num_block_size):
        pid_in_core_deal_block = core_deal_num_block_start + pid_i

        # Grouped tile ordering (groups of GROUP_SIZE_M rows of tiles).
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid_in_core_deal_block // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        # NOTE(review): upstream Triton grouped ordering computes
        # pid_m from ((pid % num_pid_in_group) % group_size_m); here the
        # modulo is taken on the raw tile index — confirm tile coverage is
        # exhaustive for all grid shapes.
        pid_m = first_pid_m + (pid_in_core_deal_block % group_size_m)
        pid_n = (pid_in_core_deal_block % num_pid_in_group) // group_size_m

        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        # Row/col offsets wrap via % M / % N; the final store is masked.
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        # Per-row activation scales; per-block weight scales along N.
        As_ptrs = As + offs_am * stride_As_m
        offs_bsn = offs_bn // group_n
        Bs_ptrs = Bs + offs_bsn * stride_Bs_n

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            a = tl.load(a_ptrs,
                        mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                        other=0.0)
            b = tl.load(b_ptrs,
                        mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                        other=0.0)

            # Scales for the current K-group.
            k_start = k * BLOCK_SIZE_K
            offs_ks = k_start // group_k
            a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
            b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

            # Dequantize on the fly: scale the partial dot product per
            # row (a_s) and per output-column block (b_s).
            accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

        # Cast the fp32 accumulator down to C's element type.
        if C.dtype.element_ty == tl.bfloat16:
            c = accumulator.to(tl.bfloat16)
        elif C.dtype.element_ty == tl.float16:
            c = accumulator.to(tl.float16)
        else:
            c = accumulator.to(tl.float32)

        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise
    quantization.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight; 2D with shape (N, K).
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should
        be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
    Returns:
        torch.Tensor: The result of matmul, shape A.shape[:-1] + (N,).
    """

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: replaced the 'scaled_matmul' kernel from the 'tmo' library with
            '_w8a8_block_fp8_matmul' kernel
    '''

    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    # B is (N, K): its last dim must match A's reduction dim.
    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert B.ndim == 2 and Bs.ndim == 2

    if (B.shape[0] % 128 == 0) and (B.shape[1] % 128 == 0):
        # Fast path: fused MLU kernel.
        C = mlu_ops.scaled_matmul(A, B, As, Bs, output_dtype, bias=None, c=None, act_mode="none",
                                  quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
                                  a_quant_bit_size=8, a_calib=None, b_calib=None)
    else:
        # NOTE(wulingchao): the underlying scaled_matmul kernel only supports
        # n and k that are multiples of 128, so fall back to Triton here.
        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
        M = A.numel() // A.shape[-1]

        assert B.ndim == 2 and Bs.ndim == 2
        N, K = B.shape
        assert triton.cdiv(N, block_n) == Bs.shape[0]
        assert triton.cdiv(K, block_k) == Bs.shape[1]

        C_shape = A.shape[:-1] + (N, )
        C = A.new_empty(C_shape, dtype=output_dtype)

        # Default config
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
        # BLOCK_SIZE_K must be divisible by block_size[1]
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_size[0],
            "BLOCK_SIZE_K": block_size[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 1,
            "num_stages": 1,
        }

        # One program per physical core; the kernel loops over its tiles.
        def grid(META):
            return (TOTAL_CORE_NUM, )

        _w8a8_block_fp8_matmul[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            block_n,
            block_k,
            A.stride(-2),
            A.stride(-1),
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(-2),
            As.stride(-1),
            Bs.stride(1),
            Bs.stride(0),
            **config,
        )

    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    return C
|
||||
178
vllm_mlu/model_executor/layers/quantization/utils/w8a8_utils.py
Normal file
178
vllm_mlu/model_executor/layers/quantization/utils/w8a8_utils.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional, Callable
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, USE_ROWWISE_TORCH_SCALED_MM, cutlass_w8a8_scaled_mm,
|
||||
flashinfer_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm,
|
||||
torch_per_tensor_w8a8_scaled_mm, torch_per_token_w8a8_scaled_mm,
|
||||
torch_channelwise_w8a8_scaled_mm)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def mlu_w8a8_scaled_mm(
    qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
    scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
    output_shape: list, **kwargs
) -> torch.Tensor:
    """Run a W8A8 scaled matmul on MLU via ``mlu_ops.scaled_matmul``.

    ``qinput`` and ``weight`` are the quantized operands, scaled by
    ``scale_a`` / ``scale_b``; the result is reshaped to ``output_shape``.
    Extra ``kwargs`` are accepted (and ignored) so the signature matches
    the other dispatch targets.
    """
    # Fixed kernel configuration: plain 8-bit matmul, no fused activation,
    # no pre-supplied output buffer, no calibration tensors.
    fixed_kwargs = dict(
        c=None,
        act_mode="none",
        quant_bit_size=8,
        alpha=1,
        beta=1,
        use_hp_active=False,
        a_quant_bit_size=8,
        a_calib=None,
        b_calib=None,
    )
    result = mlu_ops.scaled_matmul(qinput, weight, scale_a, scale_b,
                                   out_dtype, bias, **fixed_kwargs)
    return result.view(*output_shape)
|
||||
|
||||
|
||||
def dispatch_w8a8_scaled_mm(
    preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool,
    weight_per_channel: bool, activation_per_token: bool
) -> Callable[..., torch.Tensor]:
    """Select the W8A8 scaled-matmul implementation for the given
    quantization granularity and preferred backend.
    """
    # Fully per-tensor W and A: choose by backend, defaulting to torch.
    if per_tensor_weights and per_tensor_activations:
        backend_table = {
            "rocm": rocm_per_tensor_w8a8_scaled_mm,
            "flashinfer": flashinfer_w8a8_scaled_mm,
            "cutlass": cutlass_w8a8_scaled_mm,
        }
        return backend_table.get(preferred_backend,
                                 torch_per_tensor_w8a8_scaled_mm)

    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
    if preferred_backend in ("cutlass", "flashinfer"):
        return cutlass_w8a8_scaled_mm

    # If torch.scaled_mm supports per-channel (weights) per-token (inputs)
    if (not per_tensor_weights
            and not per_tensor_activations
            and USE_ROWWISE_TORCH_SCALED_MM):
        return torch_per_token_w8a8_scaled_mm

    # Modify by vllm_mlu: route per-channel-weight / per-token-activation
    # quantization to the MLU kernel.
    if weight_per_channel and activation_per_token:
        return mlu_w8a8_scaled_mm

    # Normally, torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    return torch_channelwise_w8a8_scaled_mm
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype | None = None,
    input_scale: torch.Tensor | None = None,
    input_scale_ub: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
    weight_per_channel: bool = True,
    activation_per_token: bool = True,
) -> torch.Tensor:
    """MLU replacement for ``Fp8LinearOp.apply`` (installed below via
    ``MluHijackObject.apply_hijack``).

    Quantizes the activation — dynamically per token on the MLU fast path —
    and dispatches the scaled matmul via ``dispatch_w8a8_scaled_mm``.
    """
    # ops.scaled_fp8_quant supports both dynamic and static quant.
    # If dynamic, layer.input_scale is None and x_scale computed from x.
    # If static, layer.input_scale is scalar and x_scale is input_scale.
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add mlu_fp8_supported
    '''
    # Recomputed on every call: the MLU fast path only handles per-channel
    # weights with per-token activations.
    self.mlu_fp8_supported = False
    if weight_per_channel and activation_per_token:
        self.mlu_fp8_supported = True
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    # NOTE(review): uses weight.shape[1] here but the MLU branch below
    # overrides it with weight.shape[0] — the two paths appear to assume
    # different weight layouts; confirm against the callers.
    output_shape = [*input.shape[:-1], weight.shape[1]]

    if out_dtype is None:
        out_dtype = input.dtype

    if self.mlu_fp8_supported:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Add support for activation-per-token weight-per-channel quantization.
        '''
        # Dynamic per-token FP8 (e4m3fn) quantization of the activation.
        qinput, x_scale = mlu_ops.scaled_quantize(
            input_2d,# x
            None, # scale
            None, # zero
            None, # scale_ub
            quant_type=torch.float8_e4m3fn,
            quant_mode='dynamic_per_token'
        )
        output_shape = [*input.shape[:-1], weight.shape[0]]
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
    else:
        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        if input.dtype != current_platform.fp8_dtype():
            qinput, x_scale = self.quant_fp8(
                input_2d,
                input_scale,
                input_scale_ub,
            )
        else:
            qinput, x_scale = input_2d, input_scale

    # Must have dim() conditions
    # In per-token quant scenario, when the number of token is 1,
    # the scale will only have 1 elements.
    # Without checking the dim(),
    # we cannot distingushes between per-tensor and per-token quant.
    # Example:
    # When the number of token is 1, per-token scale is [[1]]
    # When per-tensor scale is [1] or ().
    per_tensor_weights = weight_scale.numel() == 1
    per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2

    # TODO(luka) do this dispatch during init (after ScaledMM refactor)
    w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
        self.preferred_backend, per_tensor_weights, per_tensor_activations,
        weight_per_channel, activation_per_token)
    return w8a8_scaled_mm_func(
        qinput=qinput,
        weight=weight,
        out_dtype=out_dtype,
        scale_a=x_scale,
        scale_b=weight_scale,
        bias=bias,
        output_shape=output_shape,
    )
|
||||
|
||||
|
||||
# Module-level side effect: replace the upstream ``Fp8LinearOp.apply`` with
# the MLU-aware implementation above for the lifetime of the process.
MluHijackObject.apply_hijack(
    Fp8LinearOp,
    Fp8LinearOp.apply,
    vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply
)
|
||||
Reference in New Issue
Block a user