# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import triton
import triton.language as tl

from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
from vllm.utils.torch_utils import direct_register_custom_op


@triton.jit
def _sgmv_shrink_kernel_mlu(
    input_ptr,          # (token_nums, hidden_size) input activations
    lora_ptr,           # (lora_num, rank, hidden_size) stacked LoRA-A weights
    out_ptr,            # (token_nums, rank) output buffer
    N,                  # rank (output feature dim of the shrink GEMM)
    K,                  # hidden_size (reduction dim)
    b_seq_start_loc,    # (batches,) cumulative start offset of each sequence
    seq_lens,           # (batches,) length of each sequence
    lora_indices,       # (batches,) LoRA id per sequence; -1 means "no LoRA"
    scaling,            # LoRA scaling factor applied to the GEMM result
    xm_stride,  # hidden_size
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,   # True when K divides BLOCK_K * SPLIT_K exactly
    SPLIT_K: tl.constexpr,  # number of programs cooperating on the K reduction
):
    """
    The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
    The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
    introducing SPLIT-K can improve performance.

    Grid layout: axis 0 packs the (M, N) tile index, axis 1 is the
    SPLIT-K slice, axis 2 is the batch (sequence) index.
    """
    # Unpack the flattened (pid_m, pid_n) tile coordinates from axis 0.
    pid = tl.program_id(axis=0)
    pid_sk = tl.program_id(axis=1)
    cur_batch = tl.program_id(axis=2)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    M = tl.load(seq_lens + cur_batch)
    # The grid is sized for max_seq_length; tiles past this sequence's
    # actual length have no work. (A tile starting exactly at M still runs
    # but every lane is masked out by the offset_m < M masks below.)
    if pid_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    # -1 is the sentinel for "no LoRA applied to this sequence".
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # Each SPLIT-K program starts at its own K slice and strides by
    # BLOCK_K * SPLIT_K per loop iteration.
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: adjust kernel impl to fit mlu.
    '''
    # A tile: rows are tokens of this sequence, columns are the K slice.
    a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
        offset_k[None, :] * xk_stride
    # B tile: LoRA-A weight of the selected adapter, laid out so the dot
    # contracts over K. NOTE(review): despite the names, lora_k_stride is
    # the per-rank stride and lora_n_stride the per-hidden-element stride —
    # this matches the strides passed by the launcher below.
    b_ptr = lora_ptr + l0_stride * lora_index + offset_n[None, :] * lora_k_stride + \
        offset_k[:, None] * lora_n_stride
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    # Accumulate in fp32 regardless of the fp16/bf16 input dtype.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Each program reduces ceil(K / (BLOCK_K * SPLIT_K)) K-chunks.
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: adjust kernel impl to fit mlu.
        '''
        if EVEN_K:
            # K is exactly covered, so only the M/N edges need masking.
            tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
            tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
        else:
            # Ragged final chunk: also mask the K tail, padding with zeros
            # so the dot product is unaffected.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            tiled_a = tl.load(a_ptr,
                              mask=((offset_k[None, :] < k_remaining) &
                                    (offset_m[:, None] < M)),
                              other=0.0)
            tiled_b = tl.load(b_ptr,
                              mask=((offset_k[:, None] < k_remaining) &
                                    (offset_n[None, :] < N)),
                              other=0.0)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        accumulator += tl.dot(tiled_a, tiled_b)
        # Advance both operands to the next K chunk owned by this program.
        a_ptr += BLOCK_K * SPLIT_K * xk_stride
        b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
    # Output rows are absolute token positions (sequence start + tile rows).
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N)
    accumulator *= scaling
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        # SPLIT-K partial sums are combined via atomics.
        # NOTE(review): this assumes output_tensor is zero-initialized by
        # the caller before launch — the launcher below does not zero it;
        # confirm against call sites.
        tl.atomic_add(c_ptr, accumulator, mask=c_mask)


@torch.inference_mode()
def sgmv_shrink_mlu(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> None:
    """
    Validate inputs and launch the SGMV shrink kernel on MLU.

    Args:
        inputs (torch.Tensor): input tensor
        lora_a_weights (torch.Tensor): lora'a weight
        output_tensor (torch.Tensor): output tensor
        b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
            sequence lengths of the sequences in the batch, used to index
            into sequence. E.g., if the sequence length is [4, 6], it is
            [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should
            be applied.
        batches (int): batch size
        max_seq_length (int): The max sequence lengths of the sequences in
            the batch.
        token_nums (int): The token numbers in the batch. Used to verify if
            the token numbers in the inputs matches the one in the metadata.
        scaling (float): Scaling factor.

    Raises:
        AssertionError: if dtypes, shapes, or contiguity of the inputs do
            not satisfy the kernel's requirements.
    """
    # Kernel only supports matching half-precision input/weight dtypes.
    assert inputs.dtype == lora_a_weights.dtype
    assert inputs.dtype in [torch.float16, torch.bfloat16]
    assert lora_a_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    # Shape sanity checks against the batch metadata.
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_a_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()

    if lora_a_weights.ndim == 4:  # shape:(lora_num,1,rank, size)
        assert lora_a_weights.size(1) == 1
        # Drop the singleton dim so the kernel sees (lora_num, rank, size).
        lora_a_weights = lora_a_weights.squeeze(dim=1)
    else:
        assert lora_a_weights.ndim == 3  # shape:(lora_num,rank, size)
    assert lora_a_weights.is_contiguous()
    assert output_tensor.is_contiguous()
    # TODO tuning this config
    N, K = lora_a_weights.shape[-2:]  # K=hidden_size,N=rank
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Workaround: adjust block size to meet mlu restrictions.
        The grid of mlu triton kernel must less than 65536, it will be
        out of bound when the input seq is very long, and causes runtime
        error. So we need to adjust the block size to avoid this.
    '''
    # Grow BLOCK_M/BLOCK_N from their 32/16 baselines as needed so the
    # flattened tile grid stays within the MLU's 65536-program limit.
    BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 16)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    BLOCK_K = 32
    SPLIT_K = 8
    # True when every SPLIT-K chunk is full, letting the kernel skip the
    # K-tail masking path.
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0

    # Axis 0: flattened (M-tiles x N-tiles); axis 1: SPLIT-K; axis 2: batch.
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        SPLIT_K,
        batches,
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: call _sgmv_shrink_kernel_mlu
    '''
    _sgmv_shrink_kernel_mlu[grid](
        inputs,
        lora_a_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        scaling,
        inputs.stride(0),
        inputs.stride(1),
        lora_a_weights.stride(0),
        lora_a_weights.stride(1),
        lora_a_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return