/* * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MLA_PREPROCESS_TORCH_ADPT_H #define MLA_PREPROCESS_TORCH_ADPT_H #include "op_host/mla_preprocess.h" namespace vllm_ascend { std::tuple mla_preprocess( const at::Tensor &hiddenState, const at::Tensor &wdqkv, const c10::optional &descale0, const at::Tensor &gamma1, const c10::optional &beta1, const at::Tensor &wuq, const c10::optional &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping, const c10::optional &quant_scale0, const c10::optional &quant_offset0, const c10::optional &bias0, const c10::optional &quant_scale1, const c10::optional &quant_offset1, const c10::optional &bias1, const c10::optional &ctkv_scale, const c10::optional &q_nope_scale, c10::optional cache_mode, c10::optional quant_mode, c10::optional enable_inner_out, at::Tensor &q_out0, at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out) { at::Tensor Descale0 = descale0.has_value() ? descale0.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor Descale1 = descale1.has_value() ? descale1.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor Beta1 = beta1.has_value() ? beta1.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor Quant_scale0 = quant_scale0.has_value() ? quant_scale0.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor Quant_scale1 = quant_scale1.has_value() ? quant_scale1.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor Quant_offset0 = quant_offset0.has_value() ? quant_offset0.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor Quant_offset1 = quant_offset1.has_value() ? quant_offset1.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor Bias0 = bias0.has_value() ? bias0.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor Bias1 = bias1.has_value() ? bias1.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor CtkvScale = ctkv_scale.has_value() ? ctkv_scale.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); at::Tensor QnopeScale = q_nope_scale.has_value() ? q_nope_scale.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); bool enableInnerOut = enable_inner_out.has_value() ? enable_inner_out.value() : false; auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling( hiddenState, wdqkv, wuk, gamma1, kv_cache_rope, cache_mode, quant_mode, enableInnerOut ); void *hidden_state_ptr = hiddenState.data_ptr(); void *quant_scale0_ptr = Quant_scale0.data_ptr(); void *quant_offset0_ptr = Quant_offset0.data_ptr(); void *wdqkv_ptr = wdqkv.data_ptr(); void *bias0_ptr = Bias0.data_ptr(); void *gamma1_ptr = gamma1.data_ptr(); void *beta1_ptr = Beta1.data_ptr(); void *quant_scale1_ptr = Quant_scale1.data_ptr(); void *quant_offset1_ptr = Quant_offset1.data_ptr(); void *gamma2_ptr = gamma2.data_ptr(); void *sin_ptr = sin.data_ptr(); void *cos_ptr = cos.data_ptr(); void *kv_cache_ptr = kv_cache.data_ptr(); void *slotmapping_ptr = slotmapping.data_ptr(); void *wuq_ptr = wuq.data_ptr(); void *bias1_ptr = Bias1.data_ptr(); void *wuk_ptr = wuk.data_ptr(); void *descale0_ptr = Descale0.data_ptr(); void *descale1_ptr = Descale1.data_ptr(); void *ctkv_scale_ptr = CtkvScale.data_ptr(); void *qnope_scale_ptr = QnopeScale.data_ptr(); void *q_out0_ptr = q_out0.data_ptr(); void *kv_cache_out0_ptr = kv_cache_out0.data_ptr(); void *q_out1_ptr = q_out1.data_ptr(); void *kv_cache_out1_ptr = kv_cache_out1.data_ptr(); void *inner_out_ptr = inner_out.data_ptr(); void *workspace_ptr = workspace_tensor.data_ptr(); void *tiling_ptr = tiling.data_ptr(); aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("mla_preprocess"); cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr, qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr, tiling_ptr, block_dim]() -> int { mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr, kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr, qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr, tiling_ptr, block_dim); return 0; }); cmd.Run(); return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1, inner_out); } } #endif