From c404af3a41e7839e3e05fdc90a22cfdb0c2a2ef6 Mon Sep 17 00:00:00 2001 From: WANG HAO <30464411+15050188022@users.noreply.github.com> Date: Tue, 20 Jan 2026 21:27:02 +0800 Subject: [PATCH] [Feature] totaly support multi-lora support,latest xspeedgate needed (#133) Co-authored-by: wanghao --- vllm_kunlun/lora/__init__.py | 0 vllm_kunlun/lora/ops/__init__.py | 0 vllm_kunlun/lora/ops/kunlun_ops/__init__.py | 16 + vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py | 133 +++++ vllm_kunlun/lora/punica_wrapper/__init__.py | 0 .../lora/punica_wrapper/punica_kunlun.py | 548 ++++++++++++++++++ vllm_kunlun/platforms/kunlun.py | 5 +- 7 files changed, 701 insertions(+), 1 deletion(-) create mode 100644 vllm_kunlun/lora/__init__.py create mode 100644 vllm_kunlun/lora/ops/__init__.py create mode 100644 vllm_kunlun/lora/ops/kunlun_ops/__init__.py create mode 100644 vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py create mode 100644 vllm_kunlun/lora/punica_wrapper/__init__.py create mode 100644 vllm_kunlun/lora/punica_wrapper/punica_kunlun.py diff --git a/vllm_kunlun/lora/__init__.py b/vllm_kunlun/lora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_kunlun/lora/ops/__init__.py b/vllm_kunlun/lora/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_kunlun/lora/ops/kunlun_ops/__init__.py b/vllm_kunlun/lora/ops/kunlun_ops/__init__.py new file mode 100644 index 0000000..ccf44c7 --- /dev/null +++ b/vllm_kunlun/lora/ops/kunlun_ops/__init__.py @@ -0,0 +1,16 @@ +"""# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project""" + +from vllm_kunlun.lora.ops.kunlun_ops.lora_ops import (bgmv_expand,bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink" +] \ No newline at end of file diff --git a/vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py b/vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py new file mode 100644 index 0000000..e8a9fe0 --- /dev/null +++ b/vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py @@ -0,0 +1,133 @@ +"""kunlun_ops for lora""" + +import torch +import xspeedgate_ops +import time +from torch._C import dtype +import os +from torch._dynamo import disable + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + """ + sgmv_shrink + """ + + + return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling) + + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + """ + sgmv_expand + """ + + return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0) + + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + + """ + sgmv_expand_slice + """ + + + return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, slice_offset) + + + + + + +def bgmv_shrink( + inputs: torch.Tensor, # [m, hidden_dim] + lora_a_weights: torch.Tensor, # [n, 1, r, hidden_dim] + output_tensor: torch.Tensor, # [m, r] + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + lora_indices_tensor: torch.Tensor, # [m] + scaling: float = 1.0 +) -> torch.Tensor: + """ + bgmv_shrink + """ + return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + """" + bgmv_expand + """ + return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0) +# @my_wrapper + +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True +): + """ + bgmv_expand_slice + """ + return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset) \ No newline at end of file diff --git a/vllm_kunlun/lora/punica_wrapper/__init__.py b/vllm_kunlun/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_kunlun/lora/punica_wrapper/punica_kunlun.py b/vllm_kunlun/lora/punica_wrapper/punica_kunlun.py new file mode 100644 index 0000000..87aaf19 --- /dev/null +++ b/vllm_kunlun/lora/punica_wrapper/punica_kunlun.py @@ -0,0 +1,548 @@ +# +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# Author: Wang Hao +# Email: wanghao129@baidu.com +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import TYPE_CHECKING, Optional, Union, final + +import torch +# Disable torchdynamo for all functions in this file +torch._dynamo.config.disable = True + + +# SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional, Tuple, Union + + +from vllm_kunlun.lora.ops.kunlun_ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, +) + +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperKunlun(PunicaWrapperBase): + """ + PunicaWrapperKunlun with moe_fc + """ + + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + scale: float, + ): + + expert_m = torch.zeros(9, dtype=torch.int32, device=x.device) + + sgmv_shrink( + x, + w_t_all, + y, + block_statistic, + sorted_tokens_num_lod, + moe_index, + expert_m, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + scale: float, + ): + + expert_m = torch.zeros(9, dtype=torch.int32, device=x.device) + bgmv_shrink( + x, + w_t_all, + y, + block_statistic, + sorted_tokens_num_lod, + moe_index, + expert_m, + self.token_lora_indices, + scale, + ) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + add_inputs: bool, + ): + + sgmv_expand( + x, + w_t_all, + y, + block_statistic, + sorted_tokens_num_lod, + moe_index, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand( + x, + w_t_all, + y, + block_statistic, + sorted_tokens_num_lod, + moe_index, + self.token_lora_indices, + add_inputs, + ) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + block_statistic, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + + normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device) + + sgmv_expand_slice( + x, + w_t_all, + y, + block_statistic, + sorted_tokens_num_lod, + moe_index, + normed_scale, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + + normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device) + + bgmv_expand_slice( + x, + w_t_all, + y, + block_statistic, + sorted_tokens_num_lod, + moe_index, + normed_scale, + self.token_lora_indices, + y_offset, + y_slice_size, + add_inputs, + ) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + block_statistic, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = ( + self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode + ) + expand_slice_fun( + y, + x, + w_t_all, + block_statistic, + sorted_tokens_num_lod, + moe_index, + y_offset, + y_slice_size, + add_inputs, + ) + + def _apply_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + scale: float, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + + shrink_fun: Callable = ( + self._shrink_prefill if self.is_prefill else self._shrink_decode + ) + + shrink_fun( + y, x, w_t_all, block_statistic, sorted_tokens_num_lod, moe_index, scale + ) + + y = y.view_as(y_org) + + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + scale: float, + **kwargs, + ): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + + for slice_idx in range(len(lora_a_stacked)): # Each slice represents a layer + + self._apply_shrink( + y[slice_idx], + x, + lora_a_stacked[slice_idx], + block_statistic, + sorted_tokens_num_lod, + moe_index, + scale, + ) + + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + + if lora_bias_stacked is not None: + self._apply_bias( + self.token_lora_indices, y, output_slices, lora_bias_stacked + ) + + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + block_statistic, + sorted_tokens_num_lod, + moe_index, + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + expand_fun: Callable = ( + self._expand_prefill if self.is_prefill else self._expand_decode + ) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + if self.no_lora: + return + + expert_num = 9 + block_statistic = torch.zeros( + [12, expert_num], dtype=torch.int32, device=x.device + ) + sorted_tokens_num_lod = torch.zeros( + expert_num + 1, dtype=torch.int32, device=x.device + ) + token_nums = x.size(0) + moe_index = torch.zeros(token_nums, dtype=torch.int32, device=x.device) + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias( + self.token_lora_indices, y, output_slices, lora_bias_stacked + ) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + buffer = tuple( + torch.zeros((x.size(0), r), dtype=torch.float16, device=x.device) + for _ in range(len(output_slices)) + ) + # [tensor.squeeze_(1) for tensor in lora_a_stacked] + new_lora_a_stacked = tuple(lora_a.squeeze(1) for lora_a in lora_a_stacked) + self.add_shrink( + buffer, + x, + new_lora_a_stacked, + block_statistic, + sorted_tokens_num_lod, + moe_index, + scale, + **kwargs, + ) + # [tensor.unsqueeze_(1) for tensor in lora_a_stacked] + + # [tensor.squeeze_(1) for tensor in lora_b_stacked] + new_lora_b_stacked = tuple(lora_b.squeeze(1) for lora_b in lora_b_stacked) + self.add_expand( + y, + buffer, + new_lora_b_stacked, + block_statistic, + sorted_tokens_num_lod, + moe_index, + None, + output_slices, + add_inputs=True, + **kwargs, + ) + # [tensor.unsqueeze_(1) for tensor in lora_b_stacked] + + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + + if lora_a_stacked.dim() == 2: + lora_a_stacked = lora_a_stacked.unsqueeze(0) + if lora_b_stacked.dim() == 2: + lora_b_stacked = lora_b_stacked.unsqueeze(0) + + r = lora_a_stacked.size(-1) + + if buffer is None: + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) + + indices = self.sampler_indices + if indices.max() >= lora_a_stacked.size(0): + indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1) + + lora_a_reshaped = lora_a_stacked.transpose(1, 2) + lora_b_reshaped = lora_b_stacked.transpose(1, 2) + + bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale) + bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True) + + y = y.view_as(y_org) \ No newline at end of file diff --git a/vllm_kunlun/platforms/kunlun.py b/vllm_kunlun/platforms/kunlun.py index b79cd62..4153712 100644 --- a/vllm_kunlun/platforms/kunlun.py +++ b/vllm_kunlun/platforms/kunlun.py @@ -311,7 +311,10 @@ class KunlunPlatform(Platform): @classmethod def get_punica_wrapper(cls): - return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU" + ''' + kunlun wrapper + ''' + return "vllm_kunlun.lora.punica_wrapper.punica_kunlun.PunicaWrapperKunlun" @classmethod def check_if_supports_dtype(cls, torch_dtype: torch.dtype):