update
vllm/_xpu_ops.py (new file, 159 lines)
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

import torch
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func

from vllm.logger import init_logger

logger = init_logger(__name__)

if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake
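

# The fake (meta) implementations below only compute output shapes, dtypes and
# devices, so that fake-tensor tracing (e.g. under torch.compile) can reason
# about these custom ops without running the real XPU kernels. Each is
# registered only if the corresponding op was built into torch.ops._xpu_C.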
if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):

    @register_fake("_xpu_C::fp8_gemm_w8a16")
    def _fp8_gemm_w8a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        weight_scale: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        input_2d = input.view(-1, input.shape[-1])
        M = input_2d.size(0)
        N = q_weight.size(1)
        return torch.empty((M, N), dtype=input.dtype, device=input.device)


if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):

    @register_fake("_xpu_C::int4_gemm_w4a16")
    def _int4_gemm_w4a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        bias: torch.Tensor | None,
        weight_scale: torch.Tensor,
        qzeros: torch.Tensor,
        group_size: int,
        group_idx: torch.Tensor | None = None,
    ) -> torch.Tensor:
        input_2d = input.view(-1, input.shape[-1])
        M = input_2d.size(0)
        N = q_weight.size(1)
        return torch.empty((M, N), dtype=input.dtype, device=input.device)
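

# Thin wrappers around the XPU attention kernels. The call signatures mirror
# vLLM's CUDA flash-attention ops, so callers do not need device-specific code.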
class xpu_ops:
    @staticmethod
    def flash_attn_varlen_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float | None = None,
        causal: bool = False,
        out: torch.Tensor | None = None,
        block_table: torch.Tensor | None = None,
        alibi_slopes: torch.Tensor | None = None,
        window_size: list[int] | None = None,
        softcap: float | None = 0.0,
        seqused_k: torch.Tensor | None = None,
        cu_seqlens_k: torch.Tensor | None = None,
        # passed in by qwen vl
        dropout_p: float = 0.0,
        # The following parameters are not used by the XPU kernel currently;
        # they are kept so the API stays compatible with CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
        return_softmax_lse: bool | None = False,
        s_aux: torch.Tensor | None = None,
    ):
        # Exactly one of cu_seqlens_k / seqused_k must be given: paged KV
        # (block_table) requires seqused_k, contiguous KV requires cu_seqlens_k.
        assert cu_seqlens_k is not None or seqused_k is not None, (
            "cu_seqlens_k or seqused_k must be provided"
        )
        assert cu_seqlens_k is None or seqused_k is None, (
            "cu_seqlens_k and seqused_k cannot be provided at the same time"
        )
        assert block_table is None or seqused_k is not None, (
            "when block_table is enabled, seqused_k is needed"
        )
        assert block_table is not None or cu_seqlens_k is not None, (
            "when block_table is disabled, cu_seqlens_k is needed"
        )
        if out is None:
            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])

        # In encode attention, k and v may not be contiguous, and the current
        # kernel cannot handle that, so force contiguity here.
        if block_table is None:
            k = k.contiguous()
            v = v.contiguous()
        return flash_attn_varlen_func(
            out=out,
            q=q.contiguous(),
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_k=seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=causal,
            block_table=block_table,
            s_aux=s_aux,
            window_size=real_window_size,
            # alibi_slopes=alibi_slopes,
            # softcap=softcap,
            return_softmax_lse=return_softmax_lse,
        )

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: torch.Tensor | None = None,
        cu_seqlens_k_new: torch.Tensor | None = None,
        cache_leftpad: torch.Tensor | None = None,
        page_size: int | None = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        logger.warning_once(
            "get_scheduler_metadata is not implemented for xpu_ops, "
            "returning None."
        )
        return None
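
A minimal sketch (not part of this commit) of the register_fake mechanism this file relies on: a fake/meta implementation lets fake-tensor tracing infer output shapes without executing the real kernel. The op name "demo::scale" and its schema are made up for illustration, and torch >= 2.4 is assumed for torch.library.custom_op / register_fake.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Hypothetical custom op, standing in for something like _xpu_C::fp8_gemm_w8a16.
    @torch.library.custom_op("demo::scale", mutates_args=())
    def scale(x: torch.Tensor, s: float) -> torch.Tensor:
        return x * s

    # The fake impl only describes the output (shape/dtype/device); no math runs.
    @torch.library.register_fake("demo::scale")
    def _scale_fake(x: torch.Tensor, s: float) -> torch.Tensor:
        return torch.empty_like(x)

    # Under FakeTensorMode (as used by torch.compile), calling the op dispatches
    # to the fake impl, so output shapes can be inferred without real kernels.
    with FakeTensorMode():
        t = torch.empty(4, 8)
        print(scale(t, 2.0).shape)  # torch.Size([4, 8])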