# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch
import triton
import triton.language as tl

from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
from vllm.utils.torch_utils import direct_register_custom_op


@triton.jit
def _sgmv_expand_kernel_mlu(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    The sgmv's expand triton kernel is based on GroupGEMM.

    Launch layout (see ``sgmv_expand_mlu``):
      * grid axis 0 enumerates (M-tile, N-tile) pairs, N-tile fastest;
      * grid axis 1 is the sequence (batch) index.

    Each program computes one ``BLOCK_M x BLOCK_N`` tile of
    ``input[seq] @ lora_b[lora_index]`` for its sequence, accumulating in
    float32 over ``K`` in steps of ``BLOCK_K``, and writes (or, with
    ``ADD_INPUTS``, adds) the tile into ``out_ptr``.

    Stride naming note: the launcher passes ``weight.stride(1)`` as
    ``lora_k_stride`` and ``weight.stride(2)`` as ``lora_n_stride``. Inside
    the kernel ``lora_n_stride`` therefore steps along the K (rank) axis and
    ``lora_k_stride`` steps along the N axis.

    Compile-time flags:
      * EVEN_K: ``K`` is a multiple of ``BLOCK_K`` so no K-tail mask is needed.
      * ADD_INPUTS: accumulate into the existing output instead of
        overwriting it.
      * CAST_TYPE: cast the input tile to the LoRA weight dtype before
        ``tl.dot``.
    """
    pid = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    M = tl.load(seq_lens + cur_batch)
    # Early exit for tiles that start at or beyond the end of this sequence.
    # `>=` (not `>`): a tile starting exactly at row M holds no valid rows
    # (this also covers M == 0), so running the K loop for it is wasted work.
    if pid_m * BLOCK_M >= M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    # -1 marks a sequence that has no LoRA applied; leave its output as is.
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_k = tl.arange(0, BLOCK_K)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: adjust kernel impl to fit mlu.
    '''
    # (BLOCK_M, BLOCK_K) tile of the input rows belonging to this sequence.
    a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
        offset_k[None, :] * xk_stride
    # (BLOCK_K, BLOCK_N) tile of the selected LoRA B matrix.
    b_ptr = lora_ptr + l0_stride * lora_index + \
        offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K)):
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: adjust kernel impl to fit mlu.
        '''
        if EVEN_K:
            # Only the M / N tails need masking. Masked-out lanes only feed
            # accumulator rows / columns that the final store masks away.
            tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
            tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
        else:
            # The K tail must read as zero so it does not pollute the
            # dot-product of valid rows / columns.
            tiled_a = tl.load(a_ptr,
                              mask=((offset_k[None, :] < K - k * BLOCK_K)
                                    & (offset_m[:, None] < M)),
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=((offset_k[:, None] < K - k * BLOCK_K)
                                    & (offset_n[None, :] < N)),
                              other=0)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        if CAST_TYPE:
            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        # Advance both tiles by one BLOCK_K step along the K axis.
        a_ptr += BLOCK_K * xk_stride
        b_ptr += BLOCK_K * lora_n_stride
    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    # M was loaded at the top of the kernel and is unchanged; no reload needed.
    c_mask = (offset_cm[:, None] <
              (cur_seq_start + M)) & (offset_cn[None, :] < N)
    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)


@torch.inference_mode()
def sgmv_expand_mlu(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> None:
    """
    Launch the SGMV expand kernel; the result is written into
    ``output_tensor`` in place.

    Args:
        inputs (torch.Tensor): input tensor. Must be contiguous, with
            ``size(0) == token_nums`` and ``size(1)`` equal to the last
            dimension of ``lora_b_weights``.
        lora_b_weights (torch.Tensor): LoRA B weights (float16 or bfloat16),
            contiguous, either 3-D or 4-D with a singleton dimension 1 that
            is squeezed away.
        output_tensor (torch.Tensor): output tensor, contiguous; updated in
            place.
        b_seq_start_loc (torch.Tensor): (batch_size,). The start offset of
            each sequence in the flattened token dimension, i.e. the
            exclusive cumulative sum of the sequence lengths. E.g., if the
            sequence lengths are [4, 6], it is [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should
            be applied.
        batches (int): batch size
        max_seq_length (int): The max sequence lengths of the sequences in
            the batch.
        token_nums (int): The token numbers in the batch. Used to verify if
            the token numbers in the inputs matches the one in the metadata.
        add_inputs (bool, optional): Defaults to False. If True, the LoRA
            result is added to the existing contents of ``output_tensor``
            instead of overwriting them.
    """

    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)

    assert lora_b_weights.is_contiguous()

    # TODO tuning this config
    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Workaround: Adjust block size to meet mlu restrictions. The grid of an mlu triton kernel
            must be less than 65536; it would go out of bounds when the input seq is very long and
            cause a runtime error. So we need to adjust the block size to avoid this.
    '''
    BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    BLOCK_K = 16
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    # float32 activations with half-precision weights: the kernel casts the
    # input tile to the weight dtype before tl.dot.
    CAST_TYPE = False
    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True
    # Axis 0: one program per (M-tile, N-tile) pair sized for the longest
    # sequence; axis 1: one slice per sequence in the batch.
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: call _sgmv_expand_kernel_mlu
    '''
    _sgmv_expand_kernel_mlu[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return