# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility methods for model layers."""
from typing import Callable, Optional

import torch

from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils import direct_register_custom_op


def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of ``w`` along its last dimension.

    The last dimension is split into a first and a second half and the
    halves are folded into adjacent locations, e.g.::

        input:
            [[1, 2, 3, 4, 5, 6],
             [7, 8, 9, 10, 11, 12]]
        output:
            [[1, 4, 2, 5, 3, 6],
             [7, 10, 8, 11, 9, 12]]

    This is used together with the triton swiglu kernel.

    Args:
        w: Tensor whose last dimension has an even size.

    Returns:
        A tensor with the same shape as ``w`` and interleaved halves.

    Raises:
        ValueError: If the last dimension of ``w`` has an odd size.
    """
    shape = w.shape
    last_dim = shape[-1]
    if last_dim % 2 != 0:
        # Without this check torch.stack fails with a size-mismatch error
        # that does not point at the real cause.
        raise ValueError(
            "shuffle_weight requires an even last dimension, "
            f"got {last_dim}")
    half = last_dim // 2
    first = w[..., :half]
    second = w[..., half:]
    # Stacking on a new trailing axis pairs first[i] with second[i];
    # reshaping back to the original shape flattens those pairs.
    stacked = torch.stack((first, second), dim=-1)
    return stacked.reshape(shape)


def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Count token occurrences per sequence.

    Args:
        tokens: Index tensor with one row per sequence. It is used as the
            index of a ``scatter_add_`` along dim 1, so its values must lie
            in ``[0, vocab_size]``; ``vocab_size`` itself is the padding id.
        vocab_size: Number of valid token ids.
        num_seqs: Number of rows allocated for the count table.

    Returns:
        ``(bin_counts, mask)``, both with ``vocab_size`` columns:
        ``bin_counts`` holds the per-row occurrence count of every token id
        and ``mask`` is ``bin_counts > 0``.
    """
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    # Drop the extra column that collected the padding entries.
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask


def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor.

    Args:
        logits: The input logits tensor of shape [num_seqs, vocab_size].
        prompt_tokens_tensor: A tensor containing the prompt tokens. The
            prompts are padded to the maximum prompt length within the
            batch using `vocab_size` as the padding value. The value
            `vocab_size` is used for padding because it does not
            correspond to any valid token ID in the vocabulary.
        output_tokens_tensor: The output tokens tensor.
        presence_penalties: The presence penalties of shape (num_seqs, ).
        frequency_penalties: The frequency penalties of shape (num_seqs, ).
        repetition_penalties: The repetition penalties of shape
            (num_seqs, ).

    Returns:
        The same ``logits`` tensor, modified in place.
    """
    num_seqs, vocab_size = logits.shape
    # Only the presence mask of the prompt is needed, not its counts.
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Apply repetition penalties as a custom op
    from vllm._custom_ops import apply_repetition_penalties
    apply_repetition_penalties(logits, prompt_mask, output_mask,
                               repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    # Frequency penalty scales with how often a token was generated;
    # presence penalty is applied once per token that was generated at all.
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


def default_unquantized_gemm(
        layer: torch.nn.Module,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute the linear layer with ``torch.nn.functional.linear``.

    ``layer`` is unused; it is accepted so that all unquantized GEMM
    implementations in this module share one call signature.
    """
    return torch.nn.functional.linear(x, weight, bias)


def rocm_unquantized_gemm_impl(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """ROCm linear layer that uses custom kernels for small batches.

    Falls back to ``torch.nn.functional.linear`` unless
    ``VLLM_ROCM_USE_SKINNY_GEMM`` is set, ``on_gfx9()`` is true, ``x`` is
    float16/bfloat16 and the weight's inner dimension is a multiple of 8.

    Args:
        x: Input activations; the last dimension is the inner dimension.
        weight: Weight matrix indexed as ``(m, k)``.
        bias: Optional bias.

    Returns:
        A tensor of shape ``(*x.shape[:-1], weight.shape[0])``.
    """
    from vllm.platforms.rocm import on_gfx9
    k = weight.shape[1]
    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9()
                  and x.dtype in [torch.float16, torch.bfloat16]
                  and k % 8 == 0)

    if not use_skinny:
        return torch.nn.functional.linear(x, weight, bias)

    # Flatten all leading dimensions so the kernels see a 2-D input.
    x_view = x.view(-1, x.size(-1))
    n = x_view.shape[0]
    m = weight.shape[0]
    cu_count = current_platform.get_cu_count()

    if m > 8 and 0 < n <= 4:
        out = ops.wvSplitK(weight, x_view, cu_count, bias)
        return out.view(*x.shape[:-1], weight.shape[0])
    elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        # This path is only taken without a bias; no bias is passed to
        # LLMM1.
        out = ops.LLMM1(weight, x_view, 4)
        return out.view(*x.shape[:-1], weight.shape[0])
    return torch.nn.functional.linear(x, weight, bias)


def rocm_unquantized_gemm_impl_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Fake implementation registered for ``rocm_unquantized_gemm_impl``.

    Returns an uninitialized tensor with the output shape
    ``(*x.shape[:-1], weight.shape[0])`` without computing anything.
    """
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Call the registered ``rocm_unquantized_gemm_impl`` custom op.

    ``layer`` is unused; it is accepted so that all unquantized GEMM
    implementations in this module share one call signature.
    """
    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)


direct_register_custom_op(
    op_name="rocm_unquantized_gemm_impl",
    op_func=rocm_unquantized_gemm_impl,
    fake_impl=rocm_unquantized_gemm_impl_fake,
)


def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
    """Return whether the CPU SGL kernel path can be used.

    Requires AMX tile support, a bfloat16 or int8 ``dtype``, ``k``
    divisible by 32 and ``n`` divisible by 16.
    """
    return (torch._C._cpu._is_amx_tile_supported()
            and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0
            and n % 16 == 0)


def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
    """Select a CPU linear implementation and store it on the layer.

    Sets ``layer.cpu_linear`` to a callable ``(x, weight, bias)`` chosen
    in this order:

    1. packed-weight kernel, when ``VLLM_CPU_SGL_KERNEL`` is set and
       ``check_cpu_sgl_kernel`` accepts the weight's shape and dtype;
    2. oneDNN matmul, when ``ops._supports_onednn`` is true on x86;
    3. ``torch.nn.functional.linear`` otherwise.

    Args:
        layer: Module with a 2-D ``weight`` and optionally a ``bias``.
        remove_weight: If true, ``layer.weight`` is replaced by an empty
            parameter in branches 1 and 2, where the selected callable
            keeps its own reference to the weight data.
    """
    N, K = layer.weight.size()
    dtype = layer.weight.dtype
    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        if getattr(layer, "bias", None) is not None:
            bias_f32 = layer.bias.to(torch.float32)
        else:
            bias_f32 = None
        # The float32 bias captured here is what the kernel receives; the
        # call-time ``bias`` only decides whether a bias is applied at all.
        # The call-time ``weight`` is ignored in favor of the packed copy.
        layer.cpu_linear = (
            lambda x, weight, bias: torch.ops._C.weight_packed_linear(
                x, packed_weight, bias_f32
                if bias is not None else None, True))
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
    elif (ops._supports_onednn
          and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
        # Keep a reference to the original weight before it may be
        # replaced, since the handler is built from it below.
        origin_weight = layer.weight
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
        handler = ops.create_onednn_mm(origin_weight.t(), 32)
        # The call-time ``weight`` is ignored; the handler is used instead.
        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(
            handler, x, bias)
    else:
        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
            x, weight, bias)


def cpu_unquantized_gemm(
        layer: torch.nn.Module,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the implementation stored in ``layer.cpu_linear``.

    ``dispatch_cpu_unquantized_gemm`` in this module is what assigns
    ``layer.cpu_linear``.
    """
    return layer.cpu_linear(x, weight, bias)


def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    """Return the unquantized GEMM function for the current platform.

    ROCm and CPU get their dedicated implementations; every other
    platform uses ``default_unquantized_gemm``.
    """
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    elif current_platform.is_cpu():
        return cpu_unquantized_gemm
    else:
        return default_unquantized_gemm