import contextlib
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.library

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType


def cutlass_scaled_mm_vacc(a: torch.Tensor,
                           b: torch.Tensor,
                           scale_a: torch.Tensor,
                           scale_b: torch.Tensor,
                           out_dtype: torch.dtype,
                           bias: Optional[torch.Tensor] = None
                           ) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like that found in DeepSeek V3 we
    also support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and
        the corresponding extent in the target shape, we repeat each element
        along that dimension target_shape[dim] // src_shape[dim] times
        consecutively"

    Example: if we have
        a = [[1, 2],    and    target_shape = (2, 4)
             [3, 4]]
    then we would expand a to
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]

    Currently we only support the case:
        scale_a.shape * [1, 128]   == a.shape
        scale_b.shape * [128, 128] == b.shape
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    m = a.shape[0]
    n = b.shape[1]

    if current_platform.is_rocm():
        triton_scaled_mm_module = importlib.import_module(
            "vllm.model_executor.layers.quantization.compressed_tensors."
            "triton_scaled_mm")
        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    # Example shapes observed at runtime:
    #   a:       torch.Size([8192, 3584]), torch.float8_e4m3fn
    #   scale_a: torch.Size([8192, 56])
    #   b:       torch.Size([3584, 1536]), torch.float8_e4m3fn
    #   scale_b: torch.Size([56, 12])

    # Force the dequantize-to-fp32 matmul path.
    use_a32_w32 = True
    if use_a32_w32 or (b.shape[1] // scale_b.shape[1] != 128
                       or a.shape[1] // scale_a.shape[1] != 128
                       or b.shape[0] // scale_b.shape[0] != 128):
        # cutlass_scaled_mm does not support quantization block sizes other
        # than 128, so dequantize both operands to fp32 and fall back to a
        # plain matmul.
        # Expand the per-group activation scales along the K dimension.
        a1 = a.to(torch.float32).reshape(a.shape[0], scale_a.shape[1], -1)
        scale_a = scale_a.reshape(scale_a.shape[0], scale_a.shape[1],
                                  1).to(torch.float32)
        a = (a1 * scale_a).reshape(a.shape).contiguous()
        # Expand the per-block weight scales along both K and N dimensions.
        b1 = b.to(torch.float32).reshape(scale_b.shape[0],
                                         b.shape[0] // scale_b.shape[0],
                                         scale_b.shape[1],
                                         b.shape[1] // scale_b.shape[1])
        scale_b = scale_b.reshape(scale_b.shape[0], 1, scale_b.shape[1],
                                  1).to(torch.float32)
        b = (b1 * scale_b).reshape(b.shape).contiguous()
        out = a @ b
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    # kv_cache_dtype and scale are accepted for interface compatibility but
    # are not forwarded to the vacc kernel.
    torch.vacc.concat_and_cache_attention(kv_c, k_pe, kv_cache, slot_mapping)
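

# --- Usage sketch ---
# A minimal, illustrative example of calling cutlass_scaled_mm_vacc with
# DeepSeek-V3-style blockwise scales (group size 128 along K for the
# activations, 128x128 blocks for the weights). The shapes and random inputs
# below are hypothetical and chosen only so that
# scale_a.shape * [1, 128] == a.shape and scale_b.shape * [128, 128] == b.shape;
# this block is a sketch, not part of the op definitions above.
if __name__ == "__main__":
    m, k, n = 32, 256, 256
    a = torch.randn(m, k).to(torch.float8_e4m3fn)             # fp8 activations
    b = torch.randn(k, n).to(torch.float8_e4m3fn)             # fp8 weights
    scale_a = torch.rand(m, k // 128, dtype=torch.float32)    # per-group scales
    scale_b = torch.rand(k // 128, n // 128,
                         dtype=torch.float32)                 # per-block scales

    out = cutlass_scaled_mm_vacc(a, b, scale_a, scale_b,
                                 out_dtype=torch.bfloat16)
    assert out.shape == (m, n) and out.dtype == torch.bfloat16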