96 lines
3.7 KiB
Python
96 lines
3.7 KiB
Python
|
|
|
||
|
|
import contextlib
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.library

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
|
||
|
|
|
||
|
|
def cutlass_scaled_mm_vacc(a: torch.Tensor,
                           b: torch.Tensor,
                           scale_a: torch.Tensor,
                           scale_b: torch.Tensor,
                           out_dtype: torch.dtype,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
    `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like found in DeepSeek V3 we also
    support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and
        corresponding extent in the target shape we repeat each element along
        that dimension src_shape[dim] // target_shape[dim] times consecutively"
    example if we have:
        a = [[1, 2], and target_shape = (2, 4)
             [3, 4]]
    then we would expand a to:
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]
    currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape

    Args:
        a: left operand (quantized, e.g. float8), shape (m, k).
        b: right operand (quantized), shape (k, n); both dims must be
            multiples of 16 for the fused kernel.
        scale_a: per-group scales for `a`; each scale covers a contiguous
            run of k // scale_a.shape[1] columns.
        scale_b: per-block scales for `b`; each scale covers a 2-D block.
        out_dtype: torch.bfloat16 or torch.float16.
        bias: optional per-output-column bias of dtype `out_dtype`.

    Returns:
        (m, n) tensor of dtype `out_dtype`.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    # Parenthesized for clarity: the dtype check belongs to the non-None arm.
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    if current_platform.is_rocm():
        # ROCm has no CUTLASS kernel; defer to the Triton implementation.
        triton_scaled_mm_module = importlib.import_module(
            "vllm.model_executor.layers.quantization.compressed_tensors."
            "triton_scaled_mm")
        return triton_scaled_mm_module.triton_scaled_mm(
            a, b, scale_a, scale_b, out_dtype, bias)

    # Dequantize-to-fp32 fallback. It is forced on via `use_a32_w32`, and is
    # also required whenever any quant-block extent is not 128, which the
    # fused torch.ops._C.cutlass_scaled_mm kernel does not support.
    use_a32_w32 = True  # dequantize to fp32 and compute a plain matmul
    if use_a32_w32 or (b.shape[1] // scale_b.shape[1] != 128
                       or a.shape[1] // scale_a.shape[1] != 128
                       or b.shape[0] // scale_b.shape[0] != 128):
        # Dequantize `a`: group columns so each scale broadcasts over its
        # contiguous run of k // scale_a.shape[1] elements.
        a_f32 = a.to(torch.float32).reshape(a.shape[0], scale_a.shape[1], -1)
        scale_a_f32 = scale_a.reshape(scale_a.shape[0], scale_a.shape[1],
                                      1).to(torch.float32)
        a_deq = (a_f32 * scale_a_f32).reshape(a.shape).contiguous()

        # Dequantize `b`: split into 2-D quant blocks of shape
        # (k // scale_b.shape[0]) x (n // scale_b.shape[1]) so each scale
        # broadcasts over its block.
        b_f32 = b.to(torch.float32).reshape(
            scale_b.shape[0], b.shape[0] // scale_b.shape[0],
            scale_b.shape[1], b.shape[1] // scale_b.shape[1])
        scale_b_f32 = scale_b.reshape(scale_b.shape[0], 1, scale_b.shape[1],
                                      1).to(torch.float32)
        b_deq = (b_f32 * scale_b_f32).reshape(b.shape).contiguous()

        out = a_deq @ b_deq
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)

    # Fused CUTLASS path (128-wide quant blocks only). The output buffer is
    # allocated here, not earlier, so the fallback above never pays for it.
    out = torch.empty((a.shape[0], b.shape[1]),
                      dtype=out_dtype,
                      device=a.device)
    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out
|
||
|
|
|
||
|
|
|
||
|
|
def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Write MLA KV entries into the paged KV cache at the mapped slots.

    Thin wrapper over the backend op ``torch.vacc.concat_and_cache_attention``,
    which receives the compressed KV (`kv_c`), the rotary key part (`k_pe`),
    the cache tensor to mutate in place, and the per-token slot mapping.

    NOTE(review): `kv_cache_dtype` and `scale` are accepted (matching the
    upstream vLLM signature) but are NOT forwarded to the backend op — any
    FP8-style cache quantization requested by the caller is silently ignored
    here; confirm the vacc op handles dtype/scale internally.
    """
    torch.vacc.concat_and_cache_attention(
        kv_c, k_pe, kv_cache, slot_mapping)
|