diff --git a/vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py b/vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py index e8a9fe0..8620628 100644 --- a/vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py +++ b/vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py @@ -1,86 +1,94 @@ """kunlun_ops for lora""" - + import torch -import xspeedgate_ops -import time -from torch._C import dtype -import os -from torch._dynamo import disable def sgmv_shrink( - inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - block_statistic: torch.Tensor, + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, sorted_tokens_num_lod: torch.Tensor, moe_index: torch.Tensor, expert_m: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - scaling: float, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, ): """ sgmv_shrink """ - - - return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling) - + return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn( + inputs, + lora_a_weights, + seq_len_tensor.to(torch.int32), + lora_indices_tensor.to(torch.int32), + output_tensor, + scaling, + ) -def sgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - block_statistic: torch.Tensor, - sorted_tokens_num_lod: torch.Tensor, - moe_index: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - add_inputs: bool = False): +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +): """ sgmv_expand """ - - return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0) - + return torch.ops.xspeedgate_ops.sgmv_expand_sdnn( + inputs, + lora_b_weights, + seq_len_tensor.to(torch.int32), + lora_indices_tensor.to(torch.int32), + output_tensor, + 0, + ) -def sgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - block_statistic: torch.Tensor, - sorted_tokens_num_lod: torch.Tensor, - moe_index: torch.Tensor, - normed_scale: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False): - +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): """ sgmv_expand_slice """ - - - 
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, slice_offset) - - - - + return torch.ops.xspeedgate_ops.sgmv_expand_sdnn( + inputs, + lora_b_weights, + seq_len_tensor.to(torch.int32), + lora_indices_tensor.to(torch.int32), + output_tensor, + slice_offset, + ) def bgmv_shrink( @@ -92,27 +100,33 @@ def bgmv_shrink( moe_index: torch.Tensor, expert_m: torch.Tensor, lora_indices_tensor: torch.Tensor, # [m] - scaling: float = 1.0 + scaling: float = 1.0, ) -> torch.Tensor: """ bgmv_shrink """ - return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling) + return torch.ops.xspeedgate_ops.bgmv_shrink_cluster( + inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling + ) -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - block_statistic: torch.Tensor, - sorted_tokens_num_lod: torch.Tensor, - moe_index: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): - """" - bgmv_expand +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +): + """ " + bgmv_expand """ - return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0) -# @my_wrapper + return torch.ops.xspeedgate_ops.bgmv_expand_cluster( + inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0 + ) + def bgmv_expand_slice( inputs: torch.Tensor, @@ -125,9 +139,11 @@ def bgmv_expand_slice( lora_indices_tensor: torch.Tensor, slice_offset: int, slice_size: int, - add_inputs: bool = True + add_inputs: bool = True, ): """ - bgmv_expand_slice + bgmv_expand_slice """ - return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset) \ No newline at end of file + return torch.ops.xspeedgate_ops.bgmv_expand_cluster( + inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset + ) diff --git a/vllm_kunlun/lora/punica_wrapper/punica_kunlun.py b/vllm_kunlun/lora/punica_wrapper/punica_kunlun.py index 87aaf19..196aeb7 100644 --- a/vllm_kunlun/lora/punica_wrapper/punica_kunlun.py +++ b/vllm_kunlun/lora/punica_wrapper/punica_kunlun.py @@ -22,16 +22,11 @@ Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ -from typing import TYPE_CHECKING, Optional, Union, final - -import torch -# Disable torchdynamo for all functions in this file -torch._dynamo.config.disable = True - - # SPDX-License-Identifier: Apache-2.0 from typing import Callable, Optional, Tuple, Union +import torch +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase from vllm_kunlun.lora.ops.kunlun_ops import ( bgmv_expand, @@ -42,7 +37,7 @@ from vllm_kunlun.lora.ops.kunlun_ops import ( sgmv_shrink, ) -from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase +# Disable torchdynamo for all functions in this file # The platforms that are compatible with the PyTorch-native implementation can @@ -545,4 +540,4 @@ class PunicaWrapperKunlun(PunicaWrapperBase): bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale) bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True) - y = y.view_as(y_org) \ No newline at end of file + y = y.view_as(y_org) diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index b4d15a4..ebe9713 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -1,16 +1,16 @@ """vllm_utils_wrapper.py""" -import vllm.distributed.parallel_state as parallel_state -import vllm.utils as _orig -from typing import Any, Callable, Optional, Union, get_origin, get_args, List, Tuple -from types import SimpleNamespace -import torch -from torch.library import Library import inspect +import socket import typing -from torch.library import register_fake -import vllm_kunlun._kunlun +from types import SimpleNamespace +from typing import Any, Callable, List, Optional, Tuple, Union, get_args, get_origin + +import torch +import vllm.distributed.parallel_state as parallel_state import vllm.envs as envs +import vllm.utils as _orig +from torch.library import Library, register_fake def patch_annotations_for_schema(func): @@ -87,7 +87,7 @@ def direct_register_custom_op( import torch.library if hasattr(torch.library, "infer_schema"): - patched_func = patch_annotations_for_schema(op_func) + patch_annotations_for_schema(op_func) schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) else: # for pytorch 2.4 @@ -153,7 +153,7 @@ _wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor _wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors _wrapped._get_open_port = _get_open_port -import sys +import sys # noqa: E402 sys.modules["vllm.utils"] = _wrapped @@ -204,11 +204,10 @@ parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather -from torch.library import custom_op, impl -import torch -from vllm import _custom_ops as ops -from typing import Optional, List -import os +from typing import Optional # noqa: E402 + +import torch # noqa: E402 +from torch.library import custom_op, impl # noqa: E402 @custom_op("_C::rms_norm", mutates_args=()) @@ -379,9 +378,9 @@ def silu_and_mul_quant_xpu( pass -import torch -import xtorch_ops -from torch.library import custom_op, impl +import torch # noqa: E402 +import xtorch_ops # noqa: E402 +from torch.library import custom_op, impl # noqa: E402 @custom_op("_C::add_rmsnorm", mutates_args=()) @@ -472,7 +471,7 @@ def rmsnorm_cuda( ) -import torch +import torch # noqa: E402 def _fake_rmsnorm( @@ -618,7 +617,6 @@ split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox) # register fake op impl here # for torch.dynamo -from torch.library import register_fake if hasattr(torch.ops.custom_ops, 
"fc_fusion"): @@ -1396,7 +1394,7 @@ def awq_dequantize_cuda( device=qweight.device, ) group_m = int(qweight.shape[0] / scales.shape[0]) - out = xtorch_ops.awq_dequantize( + xtorch_ops.awq_dequantize( qweight=qweight, scales=scales, zeros=zeros, @@ -1915,7 +1913,7 @@ def apply_repetition_penalties_( @impl("_C::apply_repetition_penalties_", "CUDA") -def apply_repetition_penalties_( +def apply_repetition_penalties_cuda( logits: torch.Tensor, prompt_mask: torch.Tensor, output_mask: torch.Tensor, @@ -2341,34 +2339,499 @@ dequant_int4.register_fake(_fake_dequant_int4) ################################################## @custom_op("_C::fast_topkv2", mutates_args=()) def fast_topkv2( - score: torch.Tensor, - lengths: torch.Tensor, - topk: Optional[int] = 2048) -> torch.Tensor: + score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048 +) -> torch.Tensor: assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" - topk_indices = xtorch_ops.fast_topkv2( - score=score, - lengths=lengths, - topk=topk) + topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk) return topk_indices + @impl("_C::fast_topkv2", "CUDA") def fast_topkv2_cuda( - score: torch.Tensor, - lengths: torch.Tensor, - topk: Optional[int] = 2048) -> torch.Tensor: + score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048 +) -> torch.Tensor: assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" - topk_indices = xtorch_ops.fast_topkv2( - score=score, - lengths=lengths, - topk=topk) + topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk) return topk_indices + def _fake_fast_topkv2( - score: torch.Tensor, - lengths: torch.Tensor, - topk: Optional[int] = 2048) -> torch.Tensor: + score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048 +) -> torch.Tensor: assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32) return topk_indices -fast_topkv2.register_fake(_fake_fast_topkv2) \ No newline at end of file + +fast_topkv2.register_fake(_fake_fast_topkv2) + +################################################## +# ----------------- LoRA ops -------------------- +################################################## + + +################################################## +# -------------- sgmv_shrink_lora ---------------- +################################################## +@custom_op("_C::sgmv_shrink_lora", mutates_args=()) +def sgmv_shrink_lora( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> torch.Tensor: + # return torch.ops.xspeedgate_ops.sgmv_shrink_cluster( + # inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling + # ) + return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn( + inputs, + lora_a_weights, + seq_len_tensor, + lora_indices_tensor, + output_tensor, + scaling, + ) + + +@impl("_C::sgmv_shrink_lora", "CUDA") +def sgmv_shrink_lora_cuda( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + b_seq_start_loc: torch.Tensor, + 
seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> torch.Tensor: + # return torch.ops.xspeedgate_ops.sgmv_shrink_cluster( + # inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling + # ) + return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn( + inputs, + lora_a_weights, + seq_len_tensor, + lora_indices_tensor, + output_tensor, + scaling, + ) + + +def _fake_sgmv_shrink_lora( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> torch.Tensor: + return output_tensor + + +sgmv_shrink_lora.register_fake(_fake_sgmv_shrink_lora) + + +################################################## +# -------------- sgmv_expand_lora ---------------- +################################################## +@custom_op("_C::sgmv_expand_lora", mutates_args=()) +def sgmv_expand_lora( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +) -> torch.Tensor: + # return torch.ops.xspeedgate_ops.sgmv_expand_cluster( + # inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0 + # ) + return torch.ops.xspeedgate_ops.sgmv_expand_sdnn( + inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0 + ) + + +@impl("_C::sgmv_expand_lora", "CUDA") +def sgmv_expand_lora_cuda( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +) -> torch.Tensor: + # return torch.ops.xspeedgate_ops.sgmv_expand_cluster( + # inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0 + # ) + return torch.ops.xspeedgate_ops.sgmv_expand_sdnn( + inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0 + ) + + +def _fake_sgmv_expand_lora( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +) -> torch.Tensor: + return output_tensor + + +sgmv_expand_lora.register_fake(_fake_sgmv_expand_lora) + + +################################################## +# ----------- sgmv_expand_slice_lora ------------- +################################################## +@custom_op("_C::sgmv_expand_slice_lora", mutates_args=()) +def sgmv_expand_slice_lora( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: 
torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> torch.Tensor: + return torch.ops.xspeedgate_ops.sgmv_expand_cluster( + inputs, + lora_b_weights, + seq_len_tensor, + lora_indices_tensor, + output_tensor, + slice_offset, + ) + + +@impl("_C::sgmv_expand_slice_lora", "CUDA") +def sgmv_expand_slice_lora_cuda( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> torch.Tensor: + return torch.ops.xspeedgate_ops.sgmv_expand_cluster( + inputs, + lora_b_weights, + seq_len_tensor, + lora_indices_tensor, + output_tensor, + slice_offset, + ) + + +def _fake_sgmv_expand_slice_lora( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> torch.Tensor: + return output_tensor + + +sgmv_expand_slice_lora.register_fake(_fake_sgmv_expand_slice_lora) + + +################################################## +# -------------- bgmv_shrink_lora ---------------- +################################################## +@custom_op("_C::bgmv_shrink_lora", mutates_args=()) +def bgmv_shrink_lora( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> torch.Tensor: + return torch.ops.xspeedgate_ops.bgmv_shrink_cluster( + inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling + ) + + +@impl("_C::bgmv_shrink_lora", "CUDA") +def bgmv_shrink_lora_cuda( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> torch.Tensor: + return torch.ops.xspeedgate_ops.bgmv_shrink_cluster( + inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling + ) + + +def _fake_bgmv_shrink_lora( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> torch.Tensor: + return output_tensor + + +bgmv_shrink_lora.register_fake(_fake_bgmv_shrink_lora) + + +################################################## +# -------------- bgmv_expand_lora ---------------- +################################################## +@custom_op("_C::bgmv_expand_lora", 
mutates_args=()) +def bgmv_expand_lora( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> torch.Tensor: + return torch.ops.xspeedgate_ops.bgmv_expand_cluster( + inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0 + ) + + +@impl("_C::bgmv_expand_lora", "CUDA") +def bgmv_expand_lora_cuda( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> torch.Tensor: + return torch.ops.xspeedgate_ops.bgmv_expand_cluster( + inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0 + ) + + +def _fake_bgmv_expand_lora( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> torch.Tensor: + return output_tensor + + +bgmv_expand_lora.register_fake(_fake_bgmv_expand_lora) + + +################################################## +# ----------- bgmv_expand_slice_lora ------------- +################################################## +@custom_op("_C::bgmv_expand_slice_lora", mutates_args=()) +def bgmv_expand_slice_lora( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> torch.Tensor: + return torch.ops.xspeedgate_ops.bgmv_expand_cluster( + inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset + ) + + +@impl("_C::bgmv_expand_slice_lora", "CUDA") +def bgmv_expand_slice_lora_cuda( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> torch.Tensor: + return torch.ops.xspeedgate_ops.bgmv_expand_cluster( + inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset + ) + + +def _fake_bgmv_expand_slice_lora( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + block_statistic: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> torch.Tensor: + return output_tensor + + +bgmv_expand_slice_lora.register_fake(_fake_bgmv_expand_slice_lora) + + +################################################## +# ----------- lora_matmul_inplace ---------------- +################################################## +@custom_op("_C::lora_matmul_inplace", mutates_args=()) +def lora_matmul_inplace( + x: torch.Tensor, + w: torch.Tensor, + output_tensor: torch.Tensor, + x_trans: bool = False, + w_trans: bool = True, + alpha: float = 1.0, + beta: float = 1.0, +) -> None: + xtorch_ops.matmul( + x=x.contiguous(), + w=w.contiguous(), + out=output_tensor, + x_trans=x_trans, + 
w_trans=w_trans, + alpha=alpha, + beta=beta, + ) + + +@impl("_C::lora_matmul_inplace", "CUDA") +def lora_matmul_inplace_cuda( + x: torch.Tensor, + w: torch.Tensor, + output_tensor: torch.Tensor, + x_trans: bool = False, + w_trans: bool = True, + alpha: float = 1.0, + beta: float = 1.0, +) -> None: + xtorch_ops.matmul( + x=x.contiguous(), + w=w.contiguous(), + out=output_tensor, + x_trans=x_trans, + w_trans=w_trans, + alpha=alpha, + beta=beta, + ) + + +def _fake_lora_matmul_inplace( + x: torch.Tensor, + w: torch.Tensor, + output_tensor: torch.Tensor, + x_trans: bool = False, + w_trans: bool = True, + alpha: float = 1.0, + beta: float = 1.0, +) -> None: + return None + + +lora_matmul_inplace.register_fake(_fake_lora_matmul_inplace)
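Note on the registration pattern used for the new `_C::*_lora` ops above: each op is defined three ways — an eager definition via `@custom_op`, a device binding via `@impl(..., "CUDA")`, and a shape-only fake registered with `register_fake` so torch.dynamo can trace the op without dispatching to the Kunlun kernels. The snippet below is a minimal, self-contained sketch of that same pattern using a toy op; the op name `_C::toy_scale`, its shapes, and its math are illustrative assumptions and are not part of this change.

# Illustrative sketch only: mirrors the custom_op / impl / register_fake pattern
# used in vllm_utils_wrapper.py; "_C::toy_scale" is NOT an op added by this diff.
import torch
from torch.library import custom_op, impl


@custom_op("_C::toy_scale", mutates_args=())
def toy_scale(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Default (eager) implementation.
    return x * scale


@impl("_C::toy_scale", "CUDA")
def toy_scale_cuda(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Device-specific binding; here it simply reuses the same math.
    return x * scale


def _fake_toy_scale(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Fake/meta function: produces a tensor with the right shape and dtype
    # without doing any real work, which is what torch.dynamo traces through.
    return torch.empty_like(x)


toy_scale.register_fake(_fake_toy_scale)

# Once registered, the op is reachable through the dispatcher like the LoRA ops above:
#   y = torch.ops._C.toy_scale(torch.randn(4, 8), 0.5)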