# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

# M (batch) threshold above which cuBLAS is preferred over the AllSpark
# kernel on Ampere — NOTE(review): used by callers elsewhere; not read here.
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
# Quantization formats the AllSpark Ampere kernels accept.
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
# Alignment requirements (in elements) for the N and K GEMM dimensions.
ALLSPARK_AMPERE_N_ALIGN = 16
ALLSPARK_AMPERE_K_ALIGN = 16


def check_allspark_supported_dtype_shape(input_size_per_partition: int,
                                         output_size_per_partition: int,
                                         group_size: int,
                                         weight_dtype: ScalarType,
                                         act_dtype: torch.dtype):
    """Check whether AllSpark kernels support the given quant config.

    Args:
        input_size_per_partition: K dimension of this partition's weight.
        output_size_per_partition: N dimension of this partition's weight.
        group_size: quantization group size (-1 means per-channel).
        weight_dtype: quantized weight scalar type.
        act_dtype: activation dtype of the matmul.

    Returns:
        A tuple ``(supported, reason)`` where ``supported`` is a bool and
        ``reason`` is an explanatory string when unsupported, else ``None``.
    """
    capability_tuple = current_platform.get_device_capability()
    # -1 when capability is unknown; falls into the unsupported branch below.
    device_capability = (-1 if capability_tuple is None else
                         capability_tuple.to_int())

    # AllSpark only has optimized kernels for Ampere-class GPUs (SM 8.x).
    if 80 <= device_capability < 90:
        # Only per-channel quantization is supported on Ampere.
        if group_size != -1:
            return False, \
                "For Ampere GPU, AllSpark does not support group_size "\
                f"= {group_size}. Only group_size = -1 are supported."

        if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
            return False, "For Ampere GPU, AllSpark does not support "\
                f"quant type ({weight_dtype}). Only quant type "\
                f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."

        # K and N must be aligned for the optimized kernels.
        if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \
                or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0:
            return False, \
                "AllSpark needs input_size_per_partition % "\
                f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
                f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
                "for Ampere GPU optimized kernels."

        if act_dtype not in (torch.float16, torch.bfloat16):
            # Fixed message: original was missing the space after the comma
            # ("bfloat16,for Ampere GPU").
            return False, \
                "AllSpark only supports act_dtype = float16 or bfloat16, "\
                f"for Ampere GPU, but got act_dtype = {act_dtype}."
    else:
        return False, "AllSpark currently does not support "\
            f"device_capability = {device_capability}."

    return True, None