Sync from v0.13
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
|
||||
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
|
||||
ALLSPARK_AMPERE_N_ALIGN = 16
|
||||
ALLSPARK_AMPERE_K_ALIGN = 16
|
||||
|
||||
|
||||
def check_allspark_supported_dtype_shape(
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
group_size: int,
|
||||
weight_dtype: ScalarType,
|
||||
act_dtype: torch.dtype,
|
||||
):
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
|
||||
# For Ampere GPU
|
||||
if device_capability >= 80 and device_capability < 90:
|
||||
if group_size != -1:
|
||||
return (
|
||||
False,
|
||||
"For Ampere GPU, AllSpark does not support group_size "
|
||||
f"= {group_size}. Only group_size = -1 are supported.",
|
||||
)
|
||||
|
||||
if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
|
||||
return (
|
||||
False,
|
||||
"For Ampere GPU, AllSpark does not support "
|
||||
f"quant type ({weight_dtype}). Only quant type "
|
||||
f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported.",
|
||||
)
|
||||
|
||||
if (
|
||||
input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0
|
||||
or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"AllSpark needs input_size_per_partition % "
|
||||
f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "
|
||||
f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "
|
||||
"for Ampere GPU optimized kernels.",
|
||||
)
|
||||
|
||||
if act_dtype != torch.float16 and act_dtype != torch.bfloat16:
|
||||
return (
|
||||
False,
|
||||
"AllSpark only supports act_dtype = float16 or bfloat16,"
|
||||
f"for Ampere GPU, but got act_dtype = {act_dtype}.",
|
||||
)
|
||||
else:
|
||||
return (
|
||||
False,
|
||||
"AllSpark currently does not support "
|
||||
f"device_capability = {device_capability}.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
Reference in New Issue
Block a user