# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional import numpy import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types from .quant_utils import pack_cols, unpack_cols logger = init_logger(__name__) GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # In case there is a performance issue with Marlin, the variable below can be # changed to False, which allows Marlin to perform global reductions in fp16 # precision (instead of fp32), and therefore, save on some memory movements. USE_FP32_REDUCE_DEFAULT = True # For binary size and compile time, we don't support the same types for with and # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
# TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(
    has_zp: Optional[bool] = None,
    include_fp_type: bool = True,
    device_capability: Optional[int] = None,
):
    """Return the quantized scalar types the Marlin kernels accept.

    Args:
        has_zp: ``True`` selects the types used with a runtime zero-point,
            ``False`` the types used without one, and ``None`` both (the
            no-zero-point types come first in the returned list).
        include_fp_type: Whether ``float8_e4m3fn`` and ``float4_e2m1f`` are
            included among the no-zero-point types.
        device_capability: Integer device capability. When ``None`` it is
            read from ``current_platform``; a platform that reports no
            capability is treated as ``-1``.

    Returns:
        A list of scalar types; empty when the capability is below 80.
    """
    if device_capability is None:
        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
                             capability_tuple.to_int())

    if device_capability < 80:
        return []

    # - has_zp is True: return quant_types that have zero points
    # - has_zp is False: return quant_types that have no zero points
    # - has_zp is None: both
    if has_zp is None:
        types0 = query_marlin_supported_quant_types(False, include_fp_type,
                                                    device_capability)
        types1 = query_marlin_supported_quant_types(True, include_fp_type,
                                                    device_capability)
        return types0 + types1

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4]

    # GPTQ style, unsigned + symmetric bias
    res = [scalar_types.uint4b8, scalar_types.uint8b128]
    if include_fp_type:
        res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
    return res


def _check_marlin_supported(
        quant_type: ScalarType,
        group_size: Optional[int],
        has_zp: bool,
        device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
    """Check a (quant_type, group_size, has_zp) combination against Marlin.

    Args:
        quant_type: Weight scalar type to validate.
        group_size: Quantization group size; must be one of
            ``MARLIN_SUPPORTED_GROUP_SIZES``.
        has_zp: Whether a runtime zero-point is used.
        device_capability: Integer device capability; read from
            ``current_platform`` when ``None`` (``-1`` if unavailable).

    Returns:
        ``(True, None)`` when supported, otherwise ``(False, reason)``.
    """
    if device_capability is None:
        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
                             capability_tuple.to_int())

    supported_types = query_marlin_supported_quant_types(
        has_zp, True, device_capability)

    if quant_type not in supported_types:
        return (False, f"Marlin does not support weight_bits = {quant_type}. "
                f"Only types = {supported_types} "
                f"are supported (for group_size = {group_size}, "
                f"device_capability = {device_capability}, zp = {has_zp}).")
    if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
        return (False, f"Marlin does not support group_size = {group_size}. "
                f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")

    return True, None


def check_marlin_supported(quant_type: ScalarType,
                           group_size: int,
                           has_zp: bool = False,
                           device_capability: Optional[int] = None) -> bool:
    """Return ``True`` when ``_check_marlin_supported`` accepts the inputs."""
    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
                                      device_capability)
    return cond


def verify_marlin_supported(quant_type: ScalarType,
                            group_size: int,
                            has_zp: bool = False) -> None:
    """Raise ``ValueError`` when the combination is rejected by Marlin.

    The error message is the reason produced by ``_check_marlin_supported``.
    """
    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError(err_msg)


def verify_marlin_supports_shape(output_size_per_partition: int,
                                 input_size_per_partition: int,
                                 input_size: int, group_size: int) -> None:
    """Validate per-partition weight dimensions against Marlin tiling limits.

    Args:
        output_size_per_partition: Must be a multiple of
            ``GPTQ_MARLIN_MIN_THREAD_N``.
        input_size_per_partition: Must be a multiple of
            ``GPTQ_MARLIN_MIN_THREAD_K`` and, when ``group_size`` is smaller
            than ``input_size``, also a multiple of ``group_size``.
        input_size: Full (unpartitioned) input size.
        group_size: Quantization group size.

    Raises:
        ValueError: If any divisibility requirement is violated.
    """
    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(f"Weight output_size_per_partition = "
                         f"{output_size_per_partition} is not divisible by "
                         f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(f"Weight input_size_per_partition = "
                         f"{input_size_per_partition} is not divisible "
                         f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    if (group_size < input_size
            and input_size_per_partition % group_size != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq.")


def check_marlin_supports_shape(
        output_size_per_partition: int, input_size_per_partition: int,
        input_size: int, group_size: int) -> tuple[bool, Optional[str]]:
    """Non-raising wrapper around ``verify_marlin_supports_shape``.

    Returns:
        ``(True, None)`` when the shape is accepted, otherwise
        ``(False, message)`` with the text of the ``ValueError``.
    """
    try:
        verify_marlin_supports_shape(output_size_per_partition,
                                     input_size_per_partition, input_size,
                                     group_size)
    except ValueError as e:
        return False, str(e)
    return True, None


def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
    """Report whether ``layer`` can run on the Marlin linear kernel.

    Marlin linear layers are not supported for now, so this always returns
    ``False`` regardless of ``layer`` and ``group_size``.
    """
    # NOTE: to re-enable, pass the layer's per-partition sizes (falling back
    # to layer.output_size / layer.input_size when the per-partition
    # attributes are missing or falsy), layer.input_size and group_size to
    # check_marlin_supports_shape(...) and return the first element of its
    # result.
    return False


def check_moe_marlin_supports_layer(layer: LinearBase,
                                    group_size: int) -> bool:
    """Report whether a fused-MoE ``layer`` can use the MoE Marlin kernel.

    All of the following must hold: the router weight is not applied on the
    input, the activation is ``"silu"``, ``group_size`` is 64, the sizes pass
    the divisibility checks below, and the current CUDA device name contains
    ``"BW"``.

    Args:
        layer: Layer exposing ``hidden_size``,
            ``intermediate_size_per_partition``,
            ``apply_router_weight_on_input`` and ``activation``.
        group_size: Quantization group size.

    Returns:
        ``True`` only when every requirement is met. Returns ``False`` (rather
        than raising) when no CUDA device is available.
    """
    hidden_size = layer.hidden_size
    intermediate_size_per_partition = layer.intermediate_size_per_partition
    # apply_router_weight_on_input is not supported for moe marlin
    supports_router_weight = not layer.apply_router_weight_on_input
    # moe marlin requires the activation to be silu
    supports_activation = layer.activation == "silu"

    # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
    # down: (n, k) = (hidden_size, intermediate_size_per_partition)
    # moe marlin requires n % 128 == 0 and k % 64 == 0
    supports_shape = hidden_size % 128 == 0 and \
        intermediate_size_per_partition % max(64, group_size) == 0
    # Only group_size 64 is supported for now.
    supports_group_size = group_size == 64

    if not (supports_shape and supports_group_size
            and supports_router_weight and supports_activation):
        return False

    # Only "BW" devices are supported for now. Query the device last so the
    # cheap checks can reject a layer without touching CUDA, and treat a
    # host without CUDA as unsupported instead of raising.
    if not torch.cuda.is_available():
        return False
    device_name = torch.cuda.get_device_properties(
        torch.cuda.current_device()).name
    return "BW" in device_name


def marlin_make_workspace(output_size_per_partition: int,
                          device: torch.device) -> torch.Tensor:
    """Allocate the zero-initialised integer workspace for Marlin.

    The workspace holds
    ``(output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N) *
    GPTQ_MARLIN_MAX_PARALLEL`` ``torch.int`` elements on ``device``.
    """
    max_workspace_size = (output_size_per_partition //
                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL

    return torch.zeros(max_workspace_size,
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)


def marlin_make_workspace_new(device: torch.device,
                              max_blocks_per_sm: int = 1) -> torch.Tensor:
    """Allocate the zero-initialised workspace for the new Marlin kernel.

    The workspace has ``multi_processor_count * max_blocks_per_sm``
    ``torch.int`` elements on ``device``.
    """
    # In the new marlin kernel, we use the num of threadblocks as workspace
    # size. The num of threadblocks is sms_count * max_blocks_per_sm.
    sms = torch.cuda.get_device_properties(device).multi_processor_count
    return torch.zeros(sms * max_blocks_per_sm,
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """Return ``False`` only when both ``act_order`` and ``is_row_parallel``
    are set; ``True`` otherwise."""
    return not act_order or not is_row_parallel


def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    """Decide whether scales must be repeated on every rank."""
    # Need to repeat scales on every rank if act_ordering or
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Return an empty, non-trainable ``torch.int`` parameter on ``device``
    to stand in for a group index."""
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)


def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    """Return an empty, non-trainable ``torch.int`` parameter on ``device``
    to stand in for zero-points."""
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)


def marlin_sort_g_idx(
        g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sort ``g_idx`` in ascending order.

    Returns:
        ``(sorted_g_idx, sort_indices)`` where ``sort_indices`` is the
        ``torch.int`` argsort of ``g_idx``.
    """
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices


# def get_scale_perms():
#     scale_perm: list[int] = []
#     for i in range(8):
#         scale_perm.extend([i + 8 * j for j in range(8)])
#     scale_perm_single: list[int] = []
#     for i in range(4):
#         scale_perm_single.extend(
#             [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
#     return scale_perm, scale_perm_single

# def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
#                           group_size: int) -> torch.Tensor:
#     scale_perm,
#     scale_perm_single = get_scale_perms()
#     if group_size < size_k and group_size != -1:
#         s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
#     else:
#         s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
#     s = s.reshape((-1, size_n)).contiguous()
#     return s


def get_scale_perms():
    """Return the 128-entry index permutation applied to scales and zeros.

    The list contains ``i + 16 * j`` for ``i`` in ``0..15`` (outer) and ``j``
    in ``0..7`` (inner), i.e. an 8 x 16 row-major index grid read column by
    column.
    """
    # NOTE(review): the original author's notes describe the outer loop as
    # walking the threads that hold different scales along the column
    # direction and the inner loop as collecting the matching position from
    # 8 data blocks -- confirm against the kernel layout.
    return [i + 16 * j for i in range(16) for j in range(8)]


def marlin_permute_scales(
        s: torch.Tensor,  # author's example: [56, 512], torch.float16
        size_k: int,  # author's example: 7168
        size_n: int,  # author's example: 512
        group_size: int  # author's example: 128
) -> torch.Tensor:
    """Permute scales into the layout expected by the kernel.

    ``s`` is viewed as rows of ``len(get_scale_perms())`` elements, each row
    is reordered by that permutation, and the result is reshaped to
    ``(-1, size_n)`` and made contiguous. ``size_k`` and ``group_size`` are
    accepted for interface compatibility but are not used.
    """
    # NOTE(review): per the original author's note this groups together the
    # values at matching positions of every [16, 16] compute block inside a
    # [128, 128] (fp16) B-matrix tile -- confirm against the kernel.
    scale_perm = get_scale_perms()
    s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    s = s.reshape((-1, size_n)).contiguous()
    return s


def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    """Apply ``marlin_permute_scales`` to each expert slice ``s[e]``.

    Returns:
        A new tensor with the same shape, dtype and device as ``s``.
    """
    num_experts = s.shape[0]
    output = torch.empty(
        (num_experts, s.shape[1], s.shape[2]),
        device=s.device,
        dtype=s.dtype,
    )

    for e in range(num_experts):
        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
    return output


def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Permute, interleave and pack zero-points for the kernel.

    Args:
        zp: Unpacked zero-points.
        size_k: Passed through to ``pack_cols``.
        size_n: Row width used for the final reshape and for ``pack_cols``.
        num_bits: Bit width of each zero-point; must be 4 or 8.

    Returns:
        The result of ``pack_cols`` on the reordered zero-points.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8.
    """
    # Use the same reordering as the scales so that zero-points and scales
    # stay aligned element-for-element.
    scale_perm = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave values within each packed word.
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    # Pack the low-bit values along the column dimension.
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp


def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                              size_n: int, num_bits: int) -> torch.Tensor:
    """Convert packed AWQ zero-points to the Marlin layout.

    Args:
        q_zp_packed: AWQ zero-points packed along the column dimension.
        size_k: Passed through to ``unpack_cols`` / ``marlin_zero_points``.
        size_n: Row width of the unpacked zero-points.
        num_bits: Bit width of each zero-point; must be 4 or 8.

    Returns:
        Zero-points repacked by ``marlin_zero_points``.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8.
    """
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp


def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                                  size_n: int, num_bits: int):
    """Apply ``awq_to_marlin_zero_points`` to each expert slice.

    Returns:
        A new tensor with the same shape, dtype and device as
        ``q_zp_packed``.
    """
    num_experts = q_zp_packed.shape[0]
    output = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    for e in range(num_experts):
        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
                                              num_bits)
    return output


def maybe_warn_marlin_atomic_add(device, dtype):
    """Log once that bf16 Marlin is slower on devices below SM90.

    Does nothing while dynamo is compiling.
    """
    if torch.compiler.is_dynamo_compiling():
        return
    device_capability = torch.cuda.get_device_capability(device)
    if device_capability[0] < 9 and dtype == torch.bfloat16:
        logger.info_once(
            "You are running Marlin kernel with bf16 on GPUs before SM90. "
            "You can consider change to fp16 to achieve better performance "
            "if possible.")


def maybe_warn_marlin_atomic_add_env():
    """Log once a hint about ``VLLM_MARLIN_USE_ATOMIC_ADD``.

    Does nothing while dynamo is compiling or when the variable is already
    enabled.
    """
    if torch.compiler.is_dynamo_compiling():
        return
    if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
        return
    logger.info_once(
        "Marlin kernel can achieve better performance for small size_n "
        "with experimental use_atomic_add feature. "
        "You can consider set environment variable "
        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")


def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
                                 dtype: torch.dtype) -> bool:
    """Decide whether the GEMM should reduce with atomicAdd.

    Args:
        m: Number of input rows (not consulted by the current heuristic).
        n: Output size.
        k: Reduction size.
        device: Device the GEMM runs on.
        dtype: Activation dtype.

    Returns:
        ``True`` only when ``n < 2048``, ``k >= 2048``, the device is CUDA,
        ``VLLM_MARLIN_USE_ATOMIC_ADD`` is enabled, and the dtype is not bf16
        on a device with major capability below 9.
    """
    # the performance of atomicAdd is better than global reduce
    # only when m*n is small and k is large
    if n >= 2048 or k < 2048 or device.type != "cuda":
        return False

    # disable atomicAdd reduce by default,
    # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
    if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
        maybe_warn_marlin_atomic_add_env()
        return False

    # sm8x doesn't support atomicAdd + bfloat16 natively
    device_capability = torch.cuda.get_device_capability(device)
    if device_capability[0] < 9 and dtype == torch.bfloat16:
        maybe_warn_marlin_atomic_add(device, dtype)
        return False

    return True


def apply_gptq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        wtype: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        is_k_full: bool,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    """Run the GPTQ Marlin GEMM on ``input`` and optionally add ``bias``.

    ``input`` is flattened to 2-D over its last dimension, passed to
    ``ops.gptq_marlin_gemm``, and the result is reshaped so that the leading
    dimensions of ``input`` are kept and the last dimension becomes
    ``output_size_per_partition``. ``bias`` is added in place when given.
    """
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
                                                  n=output_size_per_partition,
                                                  k=reshaped_x.size(1),
                                                  device=input.device,
                                                  dtype=input.dtype)

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  None,
                                  weight,
                                  weight_scale,
                                  None,
                                  weight_zp,
                                  g_idx,
                                  g_idx_sort_indices,
                                  workspace,
                                  wtype,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full,
                                  use_atomic_add=use_atomic_add,
                                  use_fp32_reduce=use_fp32_reduce,
                                  is_zp_float=False)

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def apply_awq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        quant_type: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    """Run the AWQ Marlin GEMM on ``input`` and optionally add ``bias``.

    Same flow as the GPTQ variant, except that ``is_k_full`` is not passed
    to ``ops.gptq_marlin_gemm``.
    """
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
                                                  n=output_size_per_partition,
                                                  k=reshaped_x.size(1),
                                                  device=input.device,
                                                  dtype=input.dtype)

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  None,
                                  weight,
                                  weight_scale,
                                  None,
                                  weight_zp,
                                  g_idx,
                                  g_idx_sort_indices,
                                  workspace,
                                  quant_type,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  use_atomic_add=use_atomic_add,
                                  use_fp32_reduce=use_fp32_reduce,
                                  is_zp_float=False)

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def merge_scales_zeros(marlin_s: torch.Tensor, marlin_zp: torch.Tensor,
                       data_num_0: int, data_num_1: int) -> torch.Tensor:
    """Interleave scale and zero-point blocks row by row.

    Each output row alternates ``data_num_0`` float16 scales with
    ``data_num_1`` int32 zero-points, and the raw bytes are reinterpreted as
    int32.

    Requirements:
        - every row of ``marlin_s`` has a length divisible by ``data_num_0``
        - every row of ``marlin_zp`` has a length divisible by ``data_num_1``
        - both tensors yield the same number of blocks per row
        - the merged byte count per row is a multiple of 4

    Args:
        marlin_s: 2-D float16 tensor of scales.
        marlin_zp: 2-D int32 tensor of zero-points with the same row count.
        data_num_0: Scales per block.
        data_num_1: Zero-point words per block.

    Returns:
        A 2-D int32 tensor with one row per input row.
    """
    assert marlin_s.shape[0] == marlin_zp.shape[0], "Batch size mismatch"
    assert marlin_s.dtype == torch.float16
    assert marlin_zp.dtype == torch.int32

    num_rows, s_cols = marlin_s.shape
    _, zp_cols = marlin_zp.shape

    assert s_cols % data_num_0 == 0, "marlin_s 每行必须能被 data_num_0 整除"
    assert zp_cols % data_num_1 == 0, "marlin_zp 每行必须能被 data_num_1 整除"

    total_blocks = s_cols // data_num_0
    assert total_blocks == zp_cols // data_num_1

    s_block_bytes = data_num_0 * 2  # float16 = 2 bytes
    zp_block_bytes = data_num_1 * 4  # int32 = 4 bytes

    # Byte views split into (row, block, bytes-in-block). Concatenating on the
    # last axis places each scale block directly before its zero-point block,
    # which is the same byte order as appending the blocks alternately.
    s_blocks = marlin_s.view(torch.uint8).reshape(num_rows, total_blocks,
                                                  s_block_bytes)
    zp_blocks = marlin_zp.view(torch.uint8).reshape(num_rows, total_blocks,
                                                    zp_block_bytes)
    merged_bytes = torch.cat((s_blocks, zp_blocks), dim=2)

    # Flatten the blocks of each row and reinterpret the bytes as int32.
    row_bytes = total_blocks * (s_block_bytes + zp_block_bytes)
    return merged_bytes.reshape(num_rows, row_bytes).view(torch.int32)


def awq_marlin_moe_permute_sz(
    s: torch.Tensor,
    z: torch.Tensor,
    size_k: int,
    size_n: int,
) -> torch.Tensor:
    """Merge per-expert scales and zero-points into one tensor.

    For every expert ``e``, ``merge_scales_zeros(s[e], z[e], 128, 16)`` is
    computed and the results are stacked along a new leading dimension.
    ``size_k`` and ``size_n`` are accepted but not used.
    """
    # NOTE(review): a disabled preallocation in the original suggests the
    # result shape is (num_experts, size_k // 16, size_n // 2 + size_n // 8)
    # -- confirm against callers.
    num_experts = s.shape[0]
    return torch.stack(
        [merge_scales_zeros(s[e], z[e], 128, 16) for e in range(num_experts)],
        dim=0)