init src 0.9.2
This commit is contained in:
83
vllm/model_executor/layers/quantization/deepgemm.py
Normal file
83
vllm/model_executor/layers/quantization/deepgemm.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils import direct_register_custom_op, has_deep_gemm
|
||||
|
||||
if has_deep_gemm():
|
||||
import deep_gemm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prepare_block_fp8_matmul_inputs(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> tuple[int, int, int, torch.Tensor]:
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1]
|
||||
assert A.is_contiguous()
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
assert B.ndim == 2
|
||||
assert B.is_contiguous()
|
||||
assert Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N, )
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
return M, N, K, C
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul_deepgemm(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
|
||||
output_dtype)
|
||||
# Deepgemm only supports output tensor type as bfloat16
|
||||
assert C.dtype == torch.bfloat16
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
return C
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul_deepgemm_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
|
||||
output_dtype)
|
||||
return C
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="w8a8_block_fp8_matmul_deepgemm",
|
||||
op_func=w8a8_block_fp8_matmul_deepgemm,
|
||||
mutates_args=[],
|
||||
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
Reference in New Issue
Block a user