diff --git a/vllm_kunlun/models/qwen3_next.py b/vllm_kunlun/models/qwen3_next.py index a665b18..d34589a 100644 --- a/vllm_kunlun/models/qwen3_next.py +++ b/vllm_kunlun/models/qwen3_next.py @@ -70,7 +70,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, from vllm_kunlun.ops.activation import SiluAndMul from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask -import xtorch_ops +import kunlun_ops @torch.compile(dynamic=True, backend="aot_eager") @@ -640,7 +640,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view( last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1]) cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1]) - xtorch_ops.reshape_and_cache_flash( + kunlun_ops.reshape_and_cache_flash( last_recurrent_state, last_recurrent_state, cast_ssm_state, diff --git a/vllm_kunlun/models/qwen3_vl.py b/vllm_kunlun/models/qwen3_vl.py index d695dcb..3cf8c5e 100644 --- a/vllm_kunlun/models/qwen3_vl.py +++ b/vllm_kunlun/models/qwen3_vl.py @@ -85,7 +85,7 @@ from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3Model from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix, merge_multimodal_embeddings) from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model -import xtorch_ops +import kunlun_ops from einops import repeat logger = init_logger(__name__) diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py index 1d86690..25495db 100644 --- a/vllm_kunlun/ops/_kunlun_ops.py +++ b/vllm_kunlun/ops/_kunlun_ops.py @@ -28,7 +28,7 @@ from vllm.logger import init_logger logger = init_logger(__name__) try: - import xtorch_ops + import kunlun_ops logger.info(f"Load custom ops library success!") 
except ImportError as e: logger.warning("Import error msg: %s", e.msg) @@ -71,7 +71,7 @@ class KunlunOps: ): """ PagedAttentionV1 """ # block_size = value_cache.shape[2] - xtorch_ops.paged_attention( + kunlun_ops.paged_attention( x=query, k_cache=key_cache, v_cache=value_cache, @@ -114,7 +114,7 @@ class KunlunOps: ): """ PagedAttentionV2 """ # block_size = value_cache.shape[2] - xtorch_ops.paged_attention( + kunlun_ops.paged_attention( x=query, k_cache=key_cache, v_cache=value_cache, @@ -133,7 +133,7 @@ class KunlunOps: def silu_and_mul(out: torch.Tensor, x: torch.Tensor): """ silu and mul """ - xtorch_ops.silu_and_mul( + kunlun_ops.silu_and_mul( x, axis=-1, turn=True, @@ -145,7 +145,7 @@ class KunlunOps: def quick_gelu(out: torch.Tensor, x: torch.Tensor): """ quick gelu """ - xtorch_ops.quick_gelu( + kunlun_ops.quick_gelu( x, out=out, ) @@ -159,7 +159,7 @@ class KunlunOps: epsilon, ): """rms_norm""" - xtorch_ops.rmsnorm( + kunlun_ops.rmsnorm( x, weight.to(torch.float32), epsilon, out=out ) @@ -172,7 +172,7 @@ class KunlunOps: ): """fused_add_rms_norm""" output = torch.empty_like(x) - xtorch_ops.add_rmsnorm( + kunlun_ops.add_rmsnorm( x, residual, weight.to(torch.float32), epsilon, out=output ) fused_input = x + residual @@ -222,7 +222,7 @@ class KunlunOps: key_x = key.contiguous() query_x_dim = query_x.dim() assert is_neox_style - xtorch_ops.mrotary_embedding_neox( + kunlun_ops.mrotary_embedding_neox( positions, query_x, key_x, @@ -240,7 +240,7 @@ class KunlunOps: dst, block_mapping): """ swap_blocks """ - xtorch_ops.swap_blocks( + kunlun_ops.swap_blocks( src, dst, block_mapping @@ -255,7 +255,7 @@ class KunlunOps: for i in range(len(key_caches)): key_caches[i] = key_caches[i].contiguous() value_caches[i] = value_caches[i].contiguous() - xtorch_ops.copy_blocks( + kunlun_ops.copy_blocks( key_caches, value_caches, block_mapping, @@ -272,7 +272,7 @@ class KunlunOps: ): """ reshape_and_cache """ # slot_mapping_cast = slot_mapping.to(torch.int32) - 
xtorch_ops.reshape_and_cache( + kunlun_ops.reshape_and_cache( key, value, key_cache, @@ -308,7 +308,7 @@ class KunlunOps: repeat = Qh // KVh key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd] value = value.repeat_interleave(repeat, dim=2) - xtorch_ops.attention( + kunlun_ops.attention( q=query, k_cache=key, v_cache=value, @@ -337,7 +337,7 @@ class KunlunOps: else: out_scale = torch.empty(12, device=x.device, dtype=torch.float) - xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps, + kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps, out=out, out_scale=out_scale , residual_tensor=residual) if residual is None: @@ -360,7 +360,7 @@ class KunlunOps: else: out_scale = torch.empty(12, device=x.device, dtype=torch.float) - xtorch_ops.quant_rmsnorm(x, weight, bias, eps, + kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale) return out, out_scale @@ -388,7 +388,7 @@ class KunlunOps: dtype=torch.float16, device=weight.device) output_bs_shape = [-1] - xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor, + kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor, weight, smoother, input_scale, weight_scale, @@ -642,7 +642,7 @@ class KunlunOps: """mla pa block""" output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device) - xtorch_ops.xft_multi_head_latent_page_attention_block( + kunlun_ops.xft_multi_head_latent_page_attention_block( hidden_states, q_lora_rank, kv_lora_rank, @@ -688,7 +688,7 @@ class KunlunOps: threshold: float = 20.0, ) -> torch.Tensor: """fused_gdn_gating""" - output = xtorch_ops.fused_gdn_gating( + output = kunlun_ops.fused_gdn_gating( A_log, a, dt_bias, @@ -713,7 +713,7 @@ class KunlunOps: 2. 
Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。 ''' - o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd( + o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd( q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel, cu_seqlens) return (o, final_state) \ No newline at end of file diff --git a/vllm_kunlun/ops/activation.py b/vllm_kunlun/ops/activation.py index 6719808..992a2d2 100644 --- a/vllm_kunlun/ops/activation.py +++ b/vllm_kunlun/ops/activation.py @@ -93,7 +93,7 @@ class SiluAndMul(CustomOp): def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: """forward_cuda""" - import xtorch_ops + import kunlun_ops d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) @@ -103,7 +103,7 @@ class SiluAndMul(CustomOp): def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor: """forward_kunlun""" - import xtorch_ops + import kunlun_ops d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) @@ -251,14 +251,14 @@ class GeluAndMul(CustomOp): 无。 """ # from vllm import _custom_ops as ops - import xtorch_ops + import kunlun_ops # d = x.shape[-1] // 2 # output_shape = (x.shape[:-1] + (d, )) out = torch.empty(x, dtype=x.dtype, device=x.device) if self.approximate == "none": # ops.gelu_and_mul(out, x) print(x,x.shape) - xtorch_ops.gelu(x, out) + kunlun_ops.gelu(x, out) elif self.approximate == "tanh": ops.gelu_tanh_and_mul(out, x) return out diff --git a/vllm_kunlun/ops/attention/flashmla.py b/vllm_kunlun/ops/attention/flashmla.py index 0eeb4a5..b643e65 100644 --- a/vllm_kunlun/ops/attention/flashmla.py +++ b/vllm_kunlun/ops/attention/flashmla.py @@ -7,7 +7,7 @@ import torch from vllm.logger import init_logger from vllm.platforms import current_platform -import xtorch_ops +import kunlun_ops logger = init_logger(__name__) @@ -104,7 +104,7 @@ def flash_mla_with_kvcache( is_context = False vo_head_dim = -1 - xtorch_ops.paged_attention(out, + kunlun_ops.paged_attention(out, q, k_cache, None, block_table, @@ 
-149,7 +149,7 @@ def kunlun_flash_mla_with_kvcache( p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32. """ assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache." - assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now." + assert q.shape[1] <= 2, "kunlun_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now." if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if indices is not None: diff --git a/vllm_kunlun/ops/attention/merge_attn_states.py b/vllm_kunlun/ops/attention/merge_attn_states.py index aaab8ad..bd5fff4 100644 --- a/vllm_kunlun/ops/attention/merge_attn_states.py +++ b/vllm_kunlun/ops/attention/merge_attn_states.py @@ -3,7 +3,7 @@ from typing import Optional import torch -import xtorch_ops +import kunlun_ops from vllm.platforms import current_platform @@ -16,7 +16,7 @@ def merge_attn_states( output_lse: Optional[torch.Tensor] = None, ) -> None: - return xtorch_ops.attention_merge_stage( + return kunlun_ops.attention_merge_stage( prefix_output, prefix_lse, suffix_output, diff --git a/vllm_kunlun/ops/fla/fused_recurrent.py b/vllm_kunlun/ops/fla/fused_recurrent.py index 3902bee..d3c81b1 100644 --- a/vllm_kunlun/ops/fla/fused_recurrent.py +++ b/vllm_kunlun/ops/fla/fused_recurrent.py @@ -11,7 +11,7 @@ from typing import Optional import torch -import xtorch_ops +import kunlun_ops class FusedRecurrentFunction(torch.autograd.Function): @@ -31,7 +31,7 @@ class FusedRecurrentFunction(torch.autograd.Function): num_accepted_tokens: Optional[torch.Tensor] = None, use_qk_l2norm_in_kernel: bool = False): - o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2( + o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2( q.contiguous(), k.contiguous(), v.contiguous(), diff --git a/vllm_kunlun/ops/fla/l2norm.py b/vllm_kunlun/ops/fla/l2norm.py index 8f0ed8b..55cc061 100644 --- a/vllm_kunlun/ops/fla/l2norm.py +++ b/vllm_kunlun/ops/fla/l2norm.py @@ -13,7 +13,7 @@ from 
typing import Optional import torch from vllm.triton_utils import tl, triton -import xtorch_ops +import kunlun_ops BT_LIST = [8, 16, 32, 64, 128] @@ -149,5 +149,5 @@ def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None): out = torch.empty_like(x) - xtorch_ops.l2norm(x, out, eps) + kunlun_ops.l2norm(x, out, eps) return out diff --git a/vllm_kunlun/ops/layernorm.py b/vllm_kunlun/ops/layernorm.py index 85b8124..b65dc76 100644 --- a/vllm_kunlun/ops/layernorm.py +++ b/vllm_kunlun/ops/layernorm.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm from vllm.model_executor.layers import layernorm from typing import Optional, Union -import xtorch_ops +import kunlun_ops def vllm_kunlun_forward_cuda( self, diff --git a/vllm_kunlun/ops/mamba/causal_conv1d.py b/vllm_kunlun/ops/mamba/causal_conv1d.py index 9080507..4bc6416 100644 --- a/vllm_kunlun/ops/mamba/causal_conv1d.py +++ b/vllm_kunlun/ops/mamba/causal_conv1d.py @@ -12,7 +12,7 @@ import torch.nn.functional as F from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.triton_utils import tl, triton -import xtorch_ops +import kunlun_ops @triton.jit() @@ -1212,7 +1212,7 @@ def torch_causal_conv1d_update( tmp_hidden_states = hidden_states_new[:, :, -state_len:] ori_shape = tmp_hidden_states.shape tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(ori_shape) - xtorch_ops.reshape_and_cache_flash( + kunlun_ops.reshape_and_cache_flash( tmp_hidden_states, tmp_hidden_states, cast_conv_state, diff --git a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py index d9d810d..122dab9 100644 --- a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py @@ -113,7 +113,7 @@ class 
KunlunCompressedTensorsMoEMethod(FusedMoEMethodBase): class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # NOTE: xtorch_ops use max as scale + # NOTE: kunlun_ops use max as scale with torch.no_grad(): layer.w13_weight_scale.mul_(127.0) layer.w2_weight_scale.mul_(127.0) diff --git a/vllm_kunlun/v1/attention/backends/kunlun_attn.py b/vllm_kunlun/v1/attention/backends/kunlun_attn.py index edf3935..93a4022 100644 --- a/vllm_kunlun/v1/attention/backends/kunlun_attn.py +++ b/vllm_kunlun/v1/attention/backends/kunlun_attn.py @@ -28,9 +28,9 @@ from typing import ( TypeVar, ) +import kunlun_ops import numpy as np import torch -import xtorch_ops from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -39,6 +39,7 @@ from vllm.attention.backends.abstract import ( AttentionType, ) from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, CommonAttentionMetadata, @@ -227,9 +228,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): def __post_init__(self): """__post_init__""" - self.attn_bias: Optional[List[AttentionBias]] = None - self.encoder_attn_bias: Optional[List[AttentionBias]] = None - self.cross_attn_bias: Optional[List[AttentionBias]] = None + self.attn_bias: Optional[List[AttentionBias]] = None # noqa: F821 + self.encoder_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821 + self.cross_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821 @property def is_all_encoder_attn_metadata_set(self): @@ -572,12 +573,11 @@ class KunlunAttentionMetadataBuilder: """build""" num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len + common_prefix_len = common_prefix_len block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = 
common_attn_metadata.slot_mapping - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to( self.device, non_blocking=True @@ -771,7 +771,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): # not cached. This happens during the initial memory value = value.contiguous() if key_cache.is_contiguous(): - xtorch_ops.reshape_and_cache( + kunlun_ops.reshape_and_cache( key[: attn_metadata.num_actual_tokens], value[: attn_metadata.num_actual_tokens], key_cache, @@ -781,7 +781,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): else: cast_key_cache = key_cache.squeeze(1).unsqueeze(-2) cast_value_cache = value_cache.squeeze(1).unsqueeze(-2) - xtorch_ops.reshape_and_cache_flash( + kunlun_ops.reshape_and_cache_flash( key, value, cast_key_cache, @@ -791,7 +791,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): assert attn_type == AttentionType.DECODER # Decoder self-attention supports chunked prefill. 
- num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens # Only enforce this shape-constraint for decoder # self-attention @@ -811,7 +810,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): # Prefix cache if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]: - xtorch_ops.prefill_attention( + kunlun_ops.prefill_attention( q=prefill_query, k=key_cache, # Key Cache [block_num, head, block_size, dim] v=value_cache, @@ -827,7 +826,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): softmax_lse=None, ) else: - xtorch_ops.prefill_attention( + kunlun_ops.prefill_attention( q=prefill_query, k=prefill_key, v=prefill_value, @@ -860,9 +859,9 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): decode_meta.block_tables * 2 ) # only test in Qwen3-Next - sig = inspect.signature(xtorch_ops.speculative_attention) + sig = inspect.signature(kunlun_ops.speculative_attention) if "max_window_size" in sig.parameters: - xtorch_ops.speculative_attention( + kunlun_ops.speculative_attention( out=output[:num_decode_tokens], # Only MLA support q len > 1 right now q=decode_query.unsqueeze(0), @@ -890,7 +889,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): ), ) elif not attn_metadata.is_speculative: - xtorch_ops.paged_attention( + kunlun_ops.paged_attention( x=decode_query, k_cache=key_cache, v_cache=value_cache, @@ -910,7 +909,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): out = output[:num_decode_tokens] assert out.is_contiguous() - xtorch_ops.speculative_attention( + kunlun_ops.speculative_attention( out=out.view(batch_size, qlen, head_num, self.head_size), q=decode_query.view(batch_size, qlen, head_num, head_dim), k_cache=key_cache, diff --git a/vllm_kunlun/v1/attention/backends/mla/common.py b/vllm_kunlun/v1/attention/backends/mla/common.py index 60c901d..fa9cec4 100644 --- a/vllm_kunlun/v1/attention/backends/mla/common.py +++ 
b/vllm_kunlun/v1/attention/backends/mla/common.py @@ -220,7 +220,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, infer_global_hyperparameters, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -import xtorch_ops +import kunlun_ops try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -1106,7 +1106,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): ) * q_len sorted_tokens_idx = torch.arange( self.num_heads * q_len, dtype=torch.int, device="cuda") - xtorch_ops.mla_bmm_I8( + kunlun_ops.mla_bmm_I8( x.contiguous(), # [1, 16, 512] torch.float16 self.W_UV, # [16, 128, 512] torch.int8 self.W_UV_SCALE, # [2048, 1] torch.float32 @@ -1220,7 +1220,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): tp_q_head_num=q.size(1) softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device) softmax_lse.fill_(float('-inf')) - xtorch_ops.attention( + kunlun_ops.attention( q=q, k_cache=k, v_cache=maybe_padded_v, @@ -1406,7 +1406,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): self.W_UK_T = W_UK.transpose(1, 2).contiguous() self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1], dtype=torch.float, device=kv_b_proj_weight.device) - xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE) + kunlun_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE) self.W_UV = W_UV.contiguous() self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1) else: @@ -1836,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # write the latent and rope to kv cache if kv_cache.numel() > 0: - xtorch_ops.concat_and_cache_mla( + kunlun_ops.concat_and_cache_mla( k_c_normed, k_pe.squeeze(1), attn_metadata.slot_mapping.flatten(), @@ -1885,7 +1885,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): sorted_tokens_idx = torch.arange( self.num_heads * q_len, dtype=torch.int, device="cuda") extra_params = {"trans": False} - xtorch_ops.mla_bmm_I8( + 
kunlun_ops.mla_bmm_I8( decode_q_nope.contiguous(), self.W_UK_T, self.W_UK_SCALE, diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py index 8d904e5..d29e7de 100644 --- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py +++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py @@ -10,7 +10,7 @@ from packaging import version from vllm import envs from vllm.logger import init_logger from vllm.platforms import current_platform -import xtorch_ops +import kunlun_ops import os logger = init_logger(__name__) @@ -200,16 +200,16 @@ def flashinfer_sample( probs = logits.softmax(dim=-1, dtype=torch.float32) if k is None: # Top-p only. - next_token_ids = xtorch_ops.top_p_sampling_from_probs( + next_token_ids = kunlun_ops.top_p_sampling_from_probs( probs,top_p=p, deterministic=True) elif p is None: # Top-k only. - next_token_ids = xtorch_ops.top_k_sampling_from_probs( + next_token_ids = kunlun_ops.top_k_sampling_from_probs( probs, top_k=k, deterministic=True) else: # Both top-k and top-p. 
k = k.to(torch.int32) - next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs( + next_token_ids = kunlun_ops.top_k_top_p_sampling_from_probs( probs, top_k=k, top_p=p, deterministic=True) return next_token_ids.view(-1) diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index c0606a7..d738fb9 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -405,7 +405,7 @@ def add_rmsnorm( residual_output: torch.Tensor = None, output_max: torch.Tensor = None, ) -> None: - xtorch_ops.add_rmsnorm( + kunlun_ops.add_rmsnorm( x, y, # 原来写 residual,这里其实是 y residual_output=residual_output, @@ -429,7 +429,7 @@ def add_rmsnorm_cuda( residual_output: torch.Tensor = None, output_max: torch.Tensor = None, ) -> None: - xtorch_ops.add_rmsnorm( + kunlun_ops.add_rmsnorm( x, y, residual_output=residual_output, @@ -451,7 +451,7 @@ def rmsnorm( residual_output: torch.Tensor = None, output_max: torch.Tensor = None, ) -> None: - xtorch_ops.rmsnorm( + kunlun_ops.rmsnorm( x, weight, output, @@ -471,7 +471,7 @@ def rmsnorm_cuda( residual_output: torch.Tensor = None, output_max: torch.Tensor = None, ) -> None: - xtorch_ops.rmsnorm( + kunlun_ops.rmsnorm( x, weight, output, @@ -541,7 +541,7 @@ def split_norm_rope_neox( rotary_dim: int, emb_batch_size: int = 1, ) -> None: - xtorch_ops.split_norm_rope_neox( + kunlun_ops.split_norm_rope_neox( q_emb, k_emb, v_out, @@ -577,7 +577,7 @@ def split_norm_rope_neox_cuda( rotary_dim: int, emb_batch_size: int = 1, ) -> None: - xtorch_ops.split_norm_rope_neox( + kunlun_ops.split_norm_rope_neox( q_emb, k_emb, v_out, @@ -649,7 +649,7 @@ if hasattr(torch.ops.custom_ops, "fc_fusion"): def silu_and_mul( out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True ) -> None: - xtorch_ops.swiglu( + kunlun_ops.swiglu( x=x, y=out, ) @@ -659,7 +659,7 @@ def silu_and_mul( def silu_and_mul_cuda( out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True ) -> None: - xtorch_ops.swiglu( + 
kunlun_ops.swiglu( x=x, y=out, ) @@ -736,7 +736,7 @@ def moe_softmax_topk( axis: int = -1, turn: bool = True, ) -> None: - xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic) + kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic) @impl("_C::moe_softmax_topk", "CUDA") @@ -748,7 +748,7 @@ def moe_softmax_topk_cuda( axis: int = -1, turn: bool = True, ) -> None: - xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic) + kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic) def _fake_moe_softmax_topk( @@ -781,7 +781,7 @@ def moe_ffn_block( w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> None: - xtorch_ops.moe_ffn_block( + kunlun_ops.moe_ffn_block( x=x, gate_w=gate_w, inter_w=inter_w, @@ -812,7 +812,7 @@ def moe_ffn_block_cuda( w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> None: - xtorch_ops.moe_ffn_block( + kunlun_ops.moe_ffn_block( x=x, gate_w=gate_w, inter_w=inter_w, @@ -863,7 +863,7 @@ def moe_ffn_per_token_block( ep_size: int = 1, ep_rank: int = 0, ) -> None: - xtorch_ops.moe_ffn_per_token_block( + kunlun_ops.moe_ffn_per_token_block( x=x, inter_weight=inter_weight, inter_scale=inter_scale, @@ -897,7 +897,7 @@ def moe_ffn_per_token_block_cuda( ep_size: int = 1, ep_rank: int = 0, ) -> None: - xtorch_ops.moe_ffn_per_token_block( + kunlun_ops.moe_ffn_per_token_block( x=x, inter_weight=inter_weight, inter_scale=inter_scale, @@ -948,7 +948,7 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - xtorch_ops.rotary_embedding( + kunlun_ops.rotary_embedding( positions=positions, query=query, key=key, @@ -967,7 +967,7 @@ def rotary_embedding_cuda( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - xtorch_ops.rotary_embedding( + kunlun_ops.rotary_embedding( positions=positions, query=query, key=key, @@ -999,7 +999,7 @@ def gemm_I8_I8_bf16_nt( weight_scale: torch.Tensor, out: torch.Tensor, ) 
-> None: - xtorch_ops.gemm_I8_I8_bf16_nt( + kunlun_ops.gemm_I8_I8_bf16_nt( lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out ) @@ -1012,7 +1012,7 @@ def gemm_I8_I8_bf16_nt_cuda( weight_scale: torch.Tensor, out: torch.Tensor, ) -> None: - xtorch_ops.gemm_I8_I8_bf16_nt( + kunlun_ops.gemm_I8_I8_bf16_nt( lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out ) @@ -1038,7 +1038,7 @@ def moe_softmax_topk_norm( block_statistic: torch.Tensor, stable: bool = True, ) -> None: - xtorch_ops.moe_softmax_topk_norm( + kunlun_ops.moe_softmax_topk_norm( x, normed_score, topk_index, block_statistic, stable ) @@ -1051,7 +1051,7 @@ def moe_softmax_topk_norm_cuda( block_statistic: torch.Tensor, stable: bool = True, ) -> None: - xtorch_ops.moe_softmax_topk_norm( + kunlun_ops.moe_softmax_topk_norm( x, normed_score, topk_index, block_statistic, stable ) @@ -1071,14 +1071,14 @@ moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm) @custom_op("_C::gen_block_statistic", mutates_args=()) def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None: - xtorch_ops.gen_block_statistic(topk_ids, block_statistic) + kunlun_ops.gen_block_statistic(topk_ids, block_statistic) @impl("_C::gen_block_statistic", "CUDA") def gen_block_statistic_cuda( topk_ids: torch.Tensor, block_statistic: torch.Tensor ) -> None: - xtorch_ops.gen_block_statistic(topk_ids, block_statistic) + kunlun_ops.gen_block_statistic(topk_ids, block_statistic) def fake_gen_block_statistic( @@ -1101,7 +1101,7 @@ def moe_pre_sorted( sorted_tokens_num_lod: torch.Tensor, index_have_neg: bool = False, ) -> None: - xtorch_ops.moe_pre_sorted( + kunlun_ops.moe_pre_sorted( x, topk_index, block_statistic, @@ -1123,7 +1123,7 @@ def moe_pre_sorted_cuda( sorted_tokens_num_lod: torch.Tensor, index_have_neg: bool = False, ) -> None: - xtorch_ops.moe_pre_sorted( + kunlun_ops.moe_pre_sorted( x, topk_index, block_statistic, @@ -1171,7 +1171,7 @@ def moe_fc( use_pack_int4: Optional[bool] = False, sort_mode: 
Optional[bool] = True, ) -> None: - xtorch_ops.moe_fc( + kunlun_ops.moe_fc( x=x, weight=weight, sorted_tokens_num_lod=sorted_tokens_num_lod, @@ -1214,7 +1214,7 @@ def moe_fc_cuda( use_pack_int4: Optional[bool] = False, sort_mode: Optional[bool] = True, ) -> None: - xtorch_ops.moe_fc( + kunlun_ops.moe_fc( x=x, weight=weight, sorted_tokens_num_lod=sorted_tokens_num_lod, @@ -1270,7 +1270,7 @@ def moe_post( dequant_scale: torch.Tensor, y: torch.Tensor, ) -> None: - xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y) + kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y) @impl("_C::moe_post", "CUDA") @@ -1281,7 +1281,7 @@ def moe_post_cuda( x: torch.Tensor, moe_index: torch.Tensor, normed_scale: torch.Tensor, dequant_scale: torch.Tensor, y: torch.Tensor, ) -> None: - xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y) + kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y) def fake_moe_post( @@ -1308,7 +1308,7 @@ def moe_sigmoid_group_topk_norm( n_group: int, topk_group: int, ) -> None: - xtorch_ops.moe_sigmoid_group_topk_norm( + kunlun_ops.moe_sigmoid_group_topk_norm( x=x, norm_score=norm_score, topk_index=topk_index, @@ -1331,7 +1331,7 @@ def moe_sigmoid_group_topk_norm_cuda( n_group: int, topk_group: int, ) -> None: - xtorch_ops.moe_sigmoid_group_topk_norm( + kunlun_ops.moe_sigmoid_group_topk_norm( x=x, norm_score=norm_score, topk_index=topk_index, @@ -1376,7 +1376,7 @@ def awq_dequantize( device=qweight.device, ) group_m = int(qweight.shape[0] / scales.shape[0]) - xtorch_ops.awq_dequantize( + kunlun_ops.awq_dequantize( qweight=qweight, scales=scales, zeros=zeros, @@ -1402,7 +1402,7 @@ def awq_dequantize_cuda( device=qweight.device, ) group_m = int(qweight.shape[0] / scales.shape[0]) - xtorch_ops.awq_dequantize( + kunlun_ops.awq_dequantize( qweight=qweight, scales=scales, zeros=zeros, @@ -1447,7 +1447,7 @@ def awq_gemm( (x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device ) group_size = int(qweight.shape[0] / scale.shape[0]) - xtorch_ops.awq_gemm( + 
kunlun_ops.awq_gemm( x=x, w=qweight, scale=scale, @@ -1471,7 +1471,7 @@ def awq_gemm_cuda( (x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device ) group_size = int(qweight.shape[0] / scale.shape[0]) - xtorch_ops.awq_gemm( + kunlun_ops.awq_gemm( x=x, w=qweight, scale=scale, @@ -1508,7 +1508,7 @@ def gptq_shuffle( q_perm: torch.Tensor, bit: int, ) -> None: - xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit) + kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit) @impl("_C::gptq_shuffle", "CUDA") @@ -1517,7 +1517,7 @@ def gptq_shuffle_cuda( q_perm: torch.Tensor, bit: int, ) -> None: - xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit) + kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit) def _fake_gptq_shuffle( @@ -1541,7 +1541,7 @@ def concat_and_cache_mla( kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)] slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens] ) -> None: - xtorch_ops.concat_and_cache_mla( + kunlun_ops.concat_and_cache_mla( kv_c=kv_c, k_pe=k_pe, slot_mapping=slot_mapping, @@ -1556,7 +1556,7 @@ def concat_and_cache_mla_cuda( kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)] slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens] ) -> None: - xtorch_ops.concat_and_cache_mla( + kunlun_ops.concat_and_cache_mla( kv_c=kv_c, k_pe=k_pe, slot_mapping=slot_mapping, @@ -1598,7 +1598,7 @@ def scaled_int8_quant( azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32) if symmetric: # NOTE: For quant2d ops, scale represents max. 
- xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True) + kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True) else: torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant( x_q, x.contiguous(), scale, azp @@ -1625,7 +1625,7 @@ def scaled_int8_quant_cuda( azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32) if symmetric: # NOTE: For quant2d ops, scale represents max. - xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True) + kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True) else: torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant( x_q, x.contiguous(), scale, azp @@ -1777,7 +1777,7 @@ def matmul( dtype=out_dtype, device=x.device, ) - xtorch_ops.matmul( + kunlun_ops.matmul( x=x.contiguous(), w=w.contiguous(), out=out, @@ -1814,7 +1814,7 @@ def matmul_cuda( dtype=out_dtype, device=x.device, ) - xtorch_ops.matmul( + kunlun_ops.matmul( x=x.contiguous(), w=w.contiguous(), out=out, @@ -1865,7 +1865,7 @@ def quant2d( max: torch.Tensor, force_sdnn: bool = False, ) -> None: - xtorch_ops.quant2d( + kunlun_ops.quant2d( x=x, y=x_q, max=max, @@ -1880,7 +1880,7 @@ def quant2d_cuda( max: torch.Tensor, force_sdnn: bool = False, ) -> None: - xtorch_ops.quant2d( + kunlun_ops.quant2d( x=x, y=x_q, max=max, @@ -1954,7 +1954,7 @@ def I8_mqa_logits( is_causal: Optional[bool] = False, use_xfa_boost: Optional[bool] = False, ) -> None: - xtorch_ops.I8_mqa_logits( + kunlun_ops.I8_mqa_logits( q=q, fused_kv_cache=fused_kv_cache, weights=weights, @@ -1984,7 +1984,7 @@ def I8_mqa_logits_cuda( is_causal: Optional[bool] = False, use_xfa_boost: Optional[bool] = False, ) -> None: - xtorch_ops.I8_mqa_logits( + kunlun_ops.I8_mqa_logits( q=q, fused_kv_cache=fused_kv_cache, weights=weights, @@ -2034,7 +2034,7 @@ def I8_paged_mqa_logits( out: torch.Tensor, use_xfa_boost: Optional[bool] = False, ) -> None: - xtorch_ops.I8_paged_mqa_logits( + kunlun_ops.I8_paged_mqa_logits( q=q, 
fused_kv_cache=fused_kv_cache, weights=weights, @@ -2060,7 +2061,7 @@ def I8_paged_mqa_logits_cuda( out: torch.Tensor, use_xfa_boost: Optional[bool] = False, ) -> None: - xtorch_ops.I8_paged_mqa_logits( + kunlun_ops.I8_paged_mqa_logits( q=q, fused_kv_cache=fused_kv_cache, weights=weights, @@ -2111,7 +2112,7 @@ def sparse_prefill_fwd_opt( is_causal: Optional[bool] = True, use_xfa_boost: Optional[bool] = False, ) -> None: - xtorch_ops.sparse_prefill_fwd_opt( + kunlun_ops.sparse_prefill_fwd_opt( q=q, kv=kv, indices=indices, @@ -2147,7 +2148,7 @@ def sparse_prefill_fwd_opt_cuda( is_causal: Optional[bool] = True, use_xfa_boost: Optional[bool] = False, ) -> None: - xtorch_ops.sparse_prefill_fwd_opt( + kunlun_ops.sparse_prefill_fwd_opt( q=q, kv=kv, indices=indices, @@ -2207,7 +2208,7 @@ def fwd_kvcache_mla( use_xfa_boost: Optional[bool] = False, kv_lod_xpu: Optional[torch.Tensor] = None, ) -> None: - xtorch_ops.fwd_kvcache_mla( + kunlun_ops.fwd_kvcache_mla( q_c=q_c, kv_cache=kv_cache, indices=indices, @@ -2241,7 +2242,7 @@ def fwd_kvcache_mla_cuda( use_xfa_boost: Optional[bool] = False, kv_lod_xpu: Optional[torch.Tensor] = None, ) -> None: - xtorch_ops.fwd_kvcache_mla( + kunlun_ops.fwd_kvcache_mla( q_c=q_c, kv_cache=kv_cache, indices=indices, @@ -2293,7 +2294,7 @@ def dequant_int4( int4_signed: bool = True, use_mode_fast: bool = False, ) -> None: - xtorch_ops.dequant_int4( + kunlun_ops.dequant_int4( x=x, scale=scale, zero=zero, @@ -2315,7 +2316,7 @@ def dequant_int4_cuda( int4_signed: bool = True, use_mode_fast: bool = False, ) -> None: - xtorch_ops.dequant_int4( + kunlun_ops.dequant_int4( x=x, scale=scale, zero=zero, @@ -2350,7 +2351,10 @@ def fast_topkv2( score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048 ) -> torch.Tensor: assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" - topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk) + topk_indices = kunlun_ops.fast_topkv2( + score=score, + lengths=lengths, + 
topk=topk) return topk_indices @@ -2359,7 +2363,10 @@ def fast_topkv2_cuda( score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048 ) -> torch.Tensor: assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" - topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk) + topk_indices = kunlun_ops.fast_topkv2( + score=score, + lengths=lengths, + topk=topk) return topk_indices