import contextlib
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.library

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType


def cutlass_scaled_mm_vacc(a: torch.Tensor,
                           b: torch.Tensor,
                           scale_a: torch.Tensor,
                           scale_b: torch.Tensor,
                           out_dtype: torch.dtype,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    `cutlass_scaled_mm_vacc` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where `scale_a * a` and `scale_b * b` are computed using numpy-style
    broadcasting.

    To support blockwise scaling as found in DeepSeek V3, we also support
    extended "group" broadcast rules. We extend the numpy-style broadcasting
    rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and the
        corresponding extent in the target shape, each element along that
        dimension is repeated target_shape[dim] // src_shape[dim] times
        consecutively"

    For example, given:
        a = [[1, 2],    and target_shape = (2, 4)
             [3, 4]]
    a is expanded to:
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]

    Currently only the following case is supported:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    m = a.shape[0]
    n = b.shape[1]

    if current_platform.is_rocm():
        triton_scaled_mm_module = importlib.import_module(
            "vllm.model_executor.layers.quantization.compressed_tensors."
            "triton_scaled_mm")
        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    # Example shapes observed while debugging:
    #   a:       torch.Size([8192, 3584]), torch.float8_e4m3fn
    #   scale_a: torch.Size([8192, 56])
    #   b:       torch.Size([3584, 1536]), torch.float8_e4m3fn
    #   scale_b: torch.Size([56, 12])
    use_a32_w32 = True  # dequantize to fp32 and compute the matmul in fp32
    if use_a32_w32 or (b.shape[1] // scale_b.shape[1] != 128
                       or a.shape[1] // scale_a.shape[1] != 128
                       or b.shape[0] // scale_b.shape[0] != 128):
        # cutlass_scaled_mm does not support quant blocks other than 128, so
        # dequantize both operands to fp32 and fall back to a plain matmul.
        a1 = a.to(torch.float32).reshape(a.shape[0], scale_a.shape[1], -1)
        scale_a = scale_a.reshape(scale_a.shape[0], scale_a.shape[1],
                                  1).to(torch.float32)
        a = (a1 * scale_a).reshape(a.shape).contiguous()
        b1 = b.to(torch.float32).reshape(scale_b.shape[0],
                                         b.shape[0] // scale_b.shape[0],
                                         scale_b.shape[1],
                                         b.shape[1] // scale_b.shape[1])
        scale_b = scale_b.reshape(scale_b.shape[0], 1, scale_b.shape[1],
                                  1).to(torch.float32)
        b = (b1 * scale_b).reshape(b.shape).contiguous()
        out = a @ b
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)
    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out
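

# Illustrative sketch (not part of the production path): the "group" broadcast
# rule described in the docstring above can also be reproduced with
# torch.repeat_interleave. The helper name is hypothetical and assumes
# per-block scales shaped (M, K // block) for `a` and (K // block, N // block)
# for `b`.
def _reference_blockwise_scaled_mm(a: torch.Tensor,
                                   b: torch.Tensor,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor,
                                   out_dtype: torch.dtype) -> torch.Tensor:
    # Expand each blockwise scale to the full shape of the tensor it scales,
    # repeating every block value consecutively over its group.
    sa = scale_a.repeat_interleave(a.shape[1] // scale_a.shape[1], dim=1)
    sb = scale_b.repeat_interleave(b.shape[0] // scale_b.shape[0], dim=0)
    sb = sb.repeat_interleave(b.shape[1] // scale_b.shape[1], dim=1)
    # Dequantize to fp32 and do a plain matmul, mirroring the fallback above.
    return ((a.to(torch.float32) * sa.to(torch.float32))
            @ (b.to(torch.float32) * sb.to(torch.float32))).to(out_dtype)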


def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    # Note: kv_cache_dtype and scale are accepted for interface compatibility
    # but are not forwarded to the vacc kernel.
    torch.vacc.concat_and_cache_attention(kv_c, k_pe, kv_cache, slot_mapping)
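

# Hedged usage sketch for cutlass_scaled_mm_vacc: the helper name, shapes, and
# dtypes below are illustrative only and assume a PyTorch build with
# float8_e4m3fn support. With use_a32_w32 = True the call takes the fp32
# dequantization fallback above; this helper is not part of vLLM.
def _example_cutlass_scaled_mm_vacc() -> torch.Tensor:
    M, K, N = 16, 256, 512
    a = torch.randn(M, K).to(torch.float8_e4m3fn)
    b = torch.randn(K, N).to(torch.float8_e4m3fn)
    # One scale per 1x128 block of `a` and per 128x128 block of `b`, matching
    # the supported case in the docstring.
    scale_a = torch.rand(M, K // 128, dtype=torch.float32)
    scale_b = torch.rand(K // 128, N // 128, dtype=torch.float32)
    return cutlass_scaled_mm_vacc(a, b, scale_a, scale_b, torch.bfloat16)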