[Model] Support DeepSeek-V4

chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions


@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project


@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
QUANTIZATION_CHOICES = ['int8', 'int4', 'e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
INTEGER_DTYPES = [torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short,
torch.int32, torch.int, torch.int64, torch.long]
FLOAT_DTYPES = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16,
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.half]
FP8_DTYPE = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
FP8_STR_DTYPE = ['e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
GEMM_GROUP_SIZE = [64, 128, 256, 512]
_STR_TO_TORCH_DTYPE_DICT = dict(
bfloat16=torch.bfloat16,
float16=torch.float16,
float32=torch.float32,
int64=torch.int64,
int32=torch.int32,
int8=torch.int8,
bool=torch.bool,
e4m3fn=torch.float8_e4m3fn,
e4m3fnuz=torch.float8_e4m3fnuz,
e5m2=torch.float8_e5m2,
e5m2fnuz=torch.float8_e5m2fnuz,
)
TORCH_DTYPE_TO_STR_DICT = {
torch.bfloat16: "bfloat16",
torch.float16: "float16",
torch.float32: "float32",
torch.int64: "int64",
torch.int32: "int32",
torch.int8: "int8",
torch.bool: "bool",
torch.float8_e4m3fn: "e4m3fn",
torch.float8_e4m3fnuz: "e4m3fnuz",
torch.float8_e5m2: "e5m2",
torch.float8_e5m2fnuz: "e5m2fnuz",
}
STR_DTYPE_TO_BITS_DICT = {
"bfloat16": 16,
"float16": 16,
"float32": 32,
"int64": 64,
"int32": 32,
"int8": 8,
'int4': 4,
"bool": 1,
"e4m3fn": 8,
"e4m3fnuz": 8,
"e5m2": 8,
"e5m2fnuz": 8,
}
def str_dtype_to_torch(str_dtype: str):
    '''
    Convert a str dtype to a torch dtype.
    '''
ret = _STR_TO_TORCH_DTYPE_DICT.get(str_dtype)
dtype = ret if ret is not None else torch.float16
return dtype
def torch_dtype_to_str(dtype: torch.dtype):
    '''
    Convert a torch dtype to a str dtype.
    '''
ret = TORCH_DTYPE_TO_STR_DICT.get(dtype)
str_dtype = ret if ret is not None else "float16"
return str_dtype
def str_dtype_to_bits(str_dtype):
    '''
    Convert a str dtype to its bit width.
    '''
ret = STR_DTYPE_TO_BITS_DICT.get(str_dtype)
bits = ret if ret is not None else 8
return bits
def is_integer_dtype(dtype: torch.dtype):
    '''
    Check whether the dtype is an integer dtype.
    '''
    return dtype in INTEGER_DTYPES
def is_float_dtype(dtype: torch.dtype):
    '''
    Check whether the dtype is a floating-point dtype.
    '''
return dtype in FLOAT_DTYPES
def is_fp8_dtype(dtype: torch.dtype):
    '''
    Check whether the torch dtype is an fp8 dtype.
    '''
return dtype in FP8_DTYPE
def is_fp8_str_dtype(str_dtype: str):
    '''
    Check whether the str dtype is an fp8 dtype.
    '''
return str_dtype in FP8_STR_DTYPE
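
For reference, a minimal usage sketch of the helpers above (illustrative only, not part of the committed file; it assumes a PyTorch build that exposes the float8 dtypes):

# Hedged usage sketch of the dtype helpers defined above.
assert str_dtype_to_torch("e4m3fn") is torch.float8_e4m3fn
assert torch_dtype_to_str(torch.bfloat16) == "bfloat16"
assert str_dtype_to_bits("int4") == 4
assert is_fp8_dtype(torch.float8_e5m2) and is_fp8_str_dtype("e5m2")
assert is_integer_dtype(torch.int8) and is_float_dtype(torch.float16)
assert str_dtype_to_torch("unknown") is torch.float16  # unknown names fall back to float16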


@@ -0,0 +1,424 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor)
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
'''
=============================
Modify by vllm_mlu
=============================
@brief: get the total core count used to split the triton kernel grid
'''
import triton.backends.mlu.driver as driver
_devprob = driver.BangUtils().get_device_properties(torch.mlu.current_device())
TOTAL_CLUSTER_NUM = _devprob.get("cluster_num")
TOTAL_CORE_NUM = TOTAL_CLUSTER_NUM * _devprob.get("core_num_per_cluster")
'''
==================
End of MLU Hijack
==================
'''
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
and weight.shape[1] % 128 == 0)
if current_platform.is_rocm():
# TODO this is never used, as cutlass_block_fp8_supported is False
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
input_2d.shape[:-1])[::-1]
scale_b_shape = (weight_scale.view(-1, 1)
if weight_scale.dim() <= 1 else weight_scale.T).shape
ar, ac = scale_a_shape
br, bc = scale_b_shape
if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
or br not in (1, weight.shape[0])):
shape_supported_by_cutlass = False
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = ops.cutlass_scaled_mm(q_input,
weight.T,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale.T)
else:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=False)
output = w8a8_block_fp8_matmul(q_input,
weight,
x_scale,
weight_scale,
block_size,
output_dtype=input.dtype)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
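# Illustrative shape note (added commentary, not in the upstream vLLM version):
# for an input of shape (B, T, K), a weight of shape (N, K) quantized with
# block_size=[128, 128], and a weight_scale of shape (ceil(N/128), ceil(K/128)),
# apply_w8a8_block_fp8_linear returns a tensor of shape (B, T, N) in input.dtype.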
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
        eps: The minimum value to avoid division by zero.
        dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
'''
=============================
Modify by vllm_mlu
=============================
    @brief: split into groups so that M stays below the memory-usage limit (65536)
'''
group_per_block = 1
while M >= 65536:
group_per_block *= 2
M = x.numel() // (group_size * group_per_block)
'''
==================
End of MLU Hijack
==================
'''
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
'''
=============================
Modify by vllm_mlu
=============================
@brief: set num_warps to 1 for triton-mlu
'''
num_warps = 1
num_stages = 1
'''
==================
End of MLU Hijack
==================
'''
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
'''
=============================
Modify by vllm_mlu
=============================
        @brief: replaced the '_per_token_group_quant_fp8' triton kernel with the
        'scaled_quantize' kernel from the 'tmo' library
'''
# Check if x is contiguous, if not, create a new tensor for contiguous x
if not x.is_contiguous():
x = x.contiguous()
x_origin_shape = x.shape
x = x.reshape(*x.shape[:-1], -1, group_size)
x_q, x_s = mlu_ops.scaled_quantize(x,
None,
quant_type=dtype,
quant_mode='dynamic_per_token')
x_q = x_q.reshape(x_origin_shape)
'''
==================
End of MLU Hijack
==================
'''
return x_q, x_s
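# Illustrative shape note (added commentary): for x of shape (T, 4096) with
# group_size=128, per_token_group_quant_fp8 returns x_q with the same shape as x
# in the fp8 dtype, plus one float32 scale per (token, group), i.e. 4096 / 128 = 32
# scales per token (laid out column-major when column_major_scales=True).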
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
'''
=============================
Modify by vllm_mlu
=============================
    @brief: distribute all output blocks across the fixed grid of TOTAL_CORE_NUM programs; each program loops over its share
'''
num_block_size_all = num_pid_m * num_pid_n
num_block_size_per = num_block_size_all // tl.num_programs(axis=0)
num_block_size_rem = num_block_size_all % tl.num_programs(axis=0)
core_deal_num_block_size = num_block_size_per + (pid < num_block_size_rem)
core_deal_num_block_start = num_block_size_per * pid + min(num_block_size_rem, pid)
for pid_i in range(0, core_deal_num_block_size):
pid_in_core_deal_block = core_deal_num_block_start + pid_i
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid_in_core_deal_block // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid_in_core_deal_block % group_size_m)
pid_n = (pid_in_core_deal_block % num_pid_in_group) // group_size_m
'''
==================
End of MLU Hijack
==================
'''
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
'''
=============================
Modify by vllm_mlu
=============================
    @brief: replaced the '_w8a8_block_fp8_matmul' triton kernel with the
    'scaled_matmul' kernel from the 'tmo' library, falling back to the triton
    kernel when N or K is not a multiple of 128
'''
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert B.ndim == 2 and Bs.ndim == 2
if (B.shape[0] % 128 == 0) and (B.shape[1] % 128 == 0):
C = mlu_ops.scaled_matmul(A, B, As, Bs, output_dtype, bias=None, c=None, act_mode="none",
quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
else:
        # NOTE(wulingchao): the underlying scaled_matmul kernel only supports N and K that are multiples of 128
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
# Default config
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 1,
"num_stages": 1,
}
def grid(META):
return (TOTAL_CORE_NUM, )
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
'''
==================
End of MLU Hijack
==================
'''
return C
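
To make the block-wise scaling explicit, below is a small dequantize-then-matmul reference in plain PyTorch (a hedged sketch, not part of the commit; `block_fp8_matmul_reference` is a hypothetical helper, and it assumes inputs already quantized to a dtype castable to float32, e.g. fp8 or int8):

import torch

def block_fp8_matmul_reference(A_q: torch.Tensor, B_q: torch.Tensor,
                               As: torch.Tensor, Bs: torch.Tensor,
                               block_size, output_dtype=torch.float16) -> torch.Tensor:
    # Dequantize block-wise, then matmul; mirrors the semantics of the kernel:
    # C[m, n] = sum_k A_q[m, k] * As[m, k // block_k] * B_q[n, k] * Bs[n // block_n, k // block_k]
    block_n, block_k = block_size
    M, K = A_q.shape
    N = B_q.shape[0]
    a_scale = As.repeat_interleave(block_k, dim=1)[:, :K]        # (M, K)
    b_scale = Bs.repeat_interleave(block_n, dim=0)[:N]           # (N, ceil(K / block_k))
    b_scale = b_scale.repeat_interleave(block_k, dim=1)[:, :K]   # (N, K)
    a = A_q.to(torch.float32) * a_scale
    b = B_q.to(torch.float32) * b_scale
    return (a @ b.T).to(output_dtype)

# Example shapes: A_q (8, 256), As (8, 2); B_q (512, 256), Bs (4, 2) with
# block_size=[128, 128] -> result of shape (8, 512).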


@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Callable
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, USE_ROWWISE_TORCH_SCALED_MM, cutlass_w8a8_scaled_mm,
flashinfer_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm,
torch_per_tensor_w8a8_scaled_mm, torch_per_token_w8a8_scaled_mm,
torch_channelwise_w8a8_scaled_mm)
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def mlu_w8a8_scaled_mm(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: list, **kwargs
) -> torch.Tensor:
output = mlu_ops.scaled_matmul(
qinput, # a
weight, # b
scale_a, # a_scale
scale_b, # b_scale
out_dtype, # output_dtype
bias, # bias
        c=None, act_mode="none", quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None
)
return output.view(*output_shape)
def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool,
weight_per_channel: bool, activation_per_token: bool
) -> Callable[..., torch.Tensor]:
if per_tensor_weights and per_tensor_activations:
if preferred_backend == "rocm":
return rocm_per_tensor_w8a8_scaled_mm
if preferred_backend == "flashinfer":
return flashinfer_w8a8_scaled_mm
if preferred_backend == "cutlass":
return cutlass_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
return cutlass_w8a8_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if (
not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM
):
return torch_per_token_w8a8_scaled_mm
# Normally, torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
'''
=============================
Modify by vllm_mlu
=============================
@brief: dispatch to mlu_w8a8_scaled_mm
'''
if weight_per_channel and activation_per_token:
return mlu_w8a8_scaled_mm
'''
==================
End of MLU Hijack
==================
'''
return torch_channelwise_w8a8_scaled_mm
def vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype | None = None,
input_scale: torch.Tensor | None = None,
input_scale_ub: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
weight_per_channel: bool = True,
activation_per_token: bool = True,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
'''
=============================
Modify by vllm_mlu
=============================
@brief: add mlu_fp8_supported
'''
self.mlu_fp8_supported = False
if weight_per_channel and activation_per_token:
self.mlu_fp8_supported = True
'''
==================
End of MLU Hijack
==================
'''
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
if out_dtype is None:
out_dtype = input.dtype
if self.mlu_fp8_supported:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add support for activation-per-token weight-per-channel quantization.
'''
qinput, x_scale = mlu_ops.scaled_quantize(
        input_2d,  # x
None, # scale
None, # zero
None, # scale_ub
quant_type=torch.float8_e4m3fn,
quant_mode='dynamic_per_token'
)
output_shape = [*input.shape[:-1], weight.shape[0]]
'''
==================
End of MLU Hijack
==================
'''
else:
# If input not quantized
# TODO(luka) remove this path if not used anymore
if input.dtype != current_platform.fp8_dtype():
qinput, x_scale = self.quant_fp8(
input_2d,
input_scale,
input_scale_ub,
)
else:
qinput, x_scale = input_2d, input_scale
    # We must check dim() here: in the per-token quant scenario, when the
    # number of tokens is 1, the scale has only 1 element, so without
    # checking dim() we cannot distinguish per-tensor from per-token quant.
    # Example: when the number of tokens is 1, the per-token scale is [[1]],
    # while a per-tensor scale is [1] or ().
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.preferred_backend, per_tensor_weights, per_tensor_activations,
weight_per_channel, activation_per_token)
return w8a8_scaled_mm_func(
qinput=qinput,
weight=weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
output_shape=output_shape,
)
MluHijackObject.apply_hijack(
Fp8LinearOp,
Fp8LinearOp.apply,
vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply
)
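
For reference, a minimal check of the dispatch logic above (a hedged sketch, not part of the commit; it assumes vllm and vllm_mlu are importable and that USE_ROWWISE_TORCH_SCALED_MM is False on MLU, so the rowwise torch path is not taken):

# Hypothetical sanity check: per-channel weight scales plus per-token activation
# scales should select the MLU scaled_matmul wrapper defined above.
fn = dispatch_w8a8_scaled_mm(
    preferred_backend="torch",    # illustrative value: anything other than cutlass/flashinfer/rocm
    per_tensor_weights=False,
    per_tensor_activations=False,
    weight_per_channel=True,
    activation_per_token=True,
)
assert fn is mlu_w8a8_scaled_mm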