# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.warning("Import error msg: %s", e.msg)


class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_quick(x, out)

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim)
    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim,
                                                     cos_sin_cache_offsets)

    @staticmethod
    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
        window_size_left: float,
        window_size_right: float,
        logits_soft_cap: float,
    ) -> None:
        if ipex.__version__.endswith("cpu"):
            if logits_soft_cap != 0.0:
                raise ValueError("IPEX CPU does not support logits_soft_cap")
            assert alibi_slopes is None
            assert window_size_left < 0 and window_size_right < 0
            ipex.llm.functional.varlen_attention(
                query.contiguous(), key.contiguous(), value.contiguous(), out,
                seqlen_q.int(), seqlen_k.int(), max_seqlen_q, max_seqlen_k,
                pdropout, softmax_scale, zero_tensors, is_causal,
                return_softmax, gen_)
        else:  # XPU build
            ipex.llm.functional.varlen_attention(
                query.contiguous(), key.contiguous(), value.contiguous(), out,
                seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q,
                max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal,
                return_softmax, gen_, window_size_left, window_size_right,
                logits_soft_cap)

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def reshape_and_cache_flash(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: Optional[torch.Tensor] = None,
        v_scale: Optional[torch.Tensor] = None,
        k_scale_float: float = 1.0,
        v_scale_float: float = 1.0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        # TODO: support FP8 kv cache.
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def flash_attn_varlen_func(
        out: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        seqused_k: torch.Tensor,  # not supported by the ipex kernel
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        causal: bool,
        block_table: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        window_size: Optional[list[int]] = None,
        softcap: Optional[float] = 0.0,
        cu_seqlens_k: Optional[torch.Tensor] = None,
        # The following parameters are not used in the ipex kernel currently;
        # they are kept for API compatibility with the CUDA version.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
    ):
        if cu_seqlens_k is None:
            # Derive cu_seqlens_k from seqused_k when the caller does not
            # provide it.
            cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
            cu_seqlens_k = torch.cat([
                torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
                cu_seqlens_k
            ]).to(torch.int32)

        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            q.contiguous(),
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal,
            block_table,
            alibi_slopes,
            softcap=softcap,
            window_size_left=real_window_size[0],
            window_size_right=real_window_size[1],
            k_scale=1.0,
            v_scale=1.0,
        )

    @staticmethod
    def get_scheduler_metadata(
            batch_size,
            max_seqlen_q,
            max_seqlen_k,
            num_heads_q,
            num_heads_kv,
            headdim,
            cache_seqlens: torch.Tensor,
            qkv_dtype=torch.bfloat16,
            headdim_v=None,
            cu_seqlens_q: Optional[torch.Tensor] = None,
            cu_seqlens_k_new: Optional[torch.Tensor] = None,
            cache_leftpad: Optional[torch.Tensor] = None,
            page_size: Optional[int] = None,
            max_seqlen_k_new=0,
            causal=False,
            window_size=(-1, -1),  # -1 means infinite context window
            has_softcap=False,
            num_splits=0,  # Can be tuned for speed
            pack_gqa=None,  # Can be tuned for speed
            sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        logger.warning_once(
            "get_scheduler_metadata is not implemented for ipex_ops, "
            "returning None.")
        return None

    @staticmethod
    def copy_blocks(key_caches: list[torch.Tensor],
                    value_caches: list[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore
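
# Minimal usage sketch (not part of this module's API): illustrates the
# calling convention of the activation and norm ops above, which write into
# pre-allocated output tensors. The shapes, dtype, and the "xpu" device below
# are illustrative assumptions and require an XPU build of IPEX.
#
#     num_tokens, hidden = 16, 4096
#     x = torch.randn(num_tokens, 2 * hidden, device="xpu", dtype=torch.half)
#     out = torch.empty(num_tokens, hidden, device="xpu", dtype=torch.half)
#     ipex_ops.silu_and_mul(out, x)  # gated SiLU activation written into `out`
#
#     weight = torch.ones(hidden, device="xpu", dtype=torch.half)
#     normed = ipex_ops.rms_norm(out, weight, 1e-6)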