160 lines
5.2 KiB
Python
160 lines
5.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
def register_fake(fn):
|
|
return lambda name: fn
|
|
else:
|
|
try:
|
|
from torch.library import register_fake
|
|
except ImportError:
|
|
from torch.library import impl_abstract as register_fake
|
|
|
|
if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):

    @register_fake("_xpu_C::fp8_gemm_w8a16")
    def _fp8_gemm_w8a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        weight_scale: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fake implementation registered for ``_xpu_C::fp8_gemm_w8a16``.

        Only output metadata is produced; no values are computed.
        ``weight_scale`` and ``bias`` are accepted but not read here.

        Returns:
            An uninitialized 2-D tensor with ``input``'s dtype and device.
            Its first dimension is the row count of ``input`` after all
            leading dimensions are flattened, and its second dimension is
            ``q_weight.size(1)``.
        """
        # Collapse every leading dimension of the input into one row axis.
        flattened = input.view(-1, input.shape[-1])
        out_shape = (flattened.size(0), q_weight.size(1))
        return torch.empty(out_shape, dtype=input.dtype, device=input.device)
|
|
|
|
|
|
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):

    @register_fake("_xpu_C::int4_gemm_w4a16")
    def _int4_gemm_w4a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        bias: torch.Tensor | None,
        weight_scale: torch.Tensor,
        qzeros: torch.Tensor,
        group_size: int,
        group_idx: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fake implementation registered for ``_xpu_C::int4_gemm_w4a16``.

        Only output metadata is produced; no values are computed.
        ``bias``, ``weight_scale``, ``qzeros``, ``group_size`` and
        ``group_idx`` are accepted but not read here.

        Returns:
            An uninitialized 2-D tensor with ``input``'s dtype and device.
            Its first dimension is the row count of ``input`` after all
            leading dimensions are flattened, and its second dimension is
            ``q_weight.size(1)``.
        """
        # Collapse every leading dimension of the input into one row axis.
        flattened = input.view(-1, input.shape[-1])
        out_shape = (flattened.size(0), q_weight.size(1))
        return torch.empty(out_shape, dtype=input.dtype, device=input.device)
|
|
|
|
|
|
class xpu_ops:
|
|
@staticmethod
|
|
def flash_attn_varlen_func(
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
v: torch.Tensor,
|
|
cu_seqlens_q: torch.Tensor,
|
|
max_seqlen_q: int,
|
|
max_seqlen_k: int,
|
|
softmax_scale: float | None = None,
|
|
causal: bool = False,
|
|
out: torch.Tensor | None = None,
|
|
block_table: torch.Tensor | None = None,
|
|
alibi_slopes: torch.Tensor | None = None,
|
|
window_size: list[int] | None = None,
|
|
softcap: float | None = 0.0,
|
|
seqused_k: torch.Tensor | None = None,
|
|
cu_seqlens_k: torch.Tensor | None = None,
|
|
# passed in qwen vl
|
|
dropout_p: float = 0.0,
|
|
# The following parameters are not used in xpu kernel currently,
|
|
# we keep API compatible to CUDA's.
|
|
scheduler_metadata=None,
|
|
fa_version: int = 2,
|
|
q_descale=None,
|
|
k_descale=None,
|
|
v_descale=None,
|
|
num_splits=0,
|
|
return_softmax_lse: bool | None = False,
|
|
s_aux: torch.Tensor | None = None,
|
|
):
|
|
assert cu_seqlens_k is not None or seqused_k is not None, (
|
|
"cu_seqlens_k or seqused_k must be provided"
|
|
)
|
|
assert cu_seqlens_k is None or seqused_k is None, (
|
|
"cu_seqlens_k and seqused_k cannot be provided at the same time"
|
|
)
|
|
assert block_table is None or seqused_k is not None, (
|
|
"when enable block_table, seqused_k is needed"
|
|
)
|
|
assert block_table is not None or cu_seqlens_k is not None, (
|
|
"when block_table is disabled, cu_seqlens_k is needed"
|
|
)
|
|
if out is None:
|
|
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
|
|
real_window_size: tuple[int, int]
|
|
if window_size is None:
|
|
real_window_size = (-1, -1)
|
|
else:
|
|
assert len(window_size) == 2
|
|
real_window_size = (window_size[0], window_size[1]) # noqa: F841
|
|
|
|
# In encode attention, k and v maybe not contiguous and current
|
|
# kernel can't handle it
|
|
if block_table is None:
|
|
k = k.contiguous()
|
|
v = v.contiguous()
|
|
return flash_attn_varlen_func(
|
|
out=out,
|
|
q=q.contiguous(),
|
|
k=k,
|
|
v=v,
|
|
cu_seqlens_q=cu_seqlens_q,
|
|
cu_seqlens_k=cu_seqlens_k,
|
|
seqused_k=seqused_k,
|
|
max_seqlen_q=max_seqlen_q,
|
|
max_seqlen_k=max_seqlen_k,
|
|
softmax_scale=softmax_scale,
|
|
causal=causal,
|
|
block_table=block_table,
|
|
s_aux=s_aux,
|
|
window_size=real_window_size,
|
|
# alibi_slopes = alibi_slopes,
|
|
# softcap=softcap,
|
|
return_softmax_lse=return_softmax_lse,
|
|
)
|
|
|
|
@staticmethod
|
|
def get_scheduler_metadata(
|
|
batch_size,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
num_heads_q,
|
|
num_heads_kv,
|
|
headdim,
|
|
cache_seqlens: torch.Tensor,
|
|
qkv_dtype=torch.bfloat16,
|
|
headdim_v=None,
|
|
cu_seqlens_q: torch.Tensor | None = None,
|
|
cu_seqlens_k_new: torch.Tensor | None = None,
|
|
cache_leftpad: torch.Tensor | None = None,
|
|
page_size: int | None = None,
|
|
max_seqlen_k_new=0,
|
|
causal=False,
|
|
window_size=(-1, -1), # -1 means infinite context window
|
|
has_softcap=False,
|
|
num_splits=0, # Can be tuned for speed
|
|
pack_gqa=None, # Can be tuned for speed
|
|
sm_margin=0, # Can be tuned if some SMs are used for communication
|
|
) -> None:
|
|
logger.warning_once(
|
|
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
|
|
)
|
|
return None
|