diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp index 17fc94c..7ea231c 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp @@ -19,7 +19,7 @@ #include "../op_host/tiling/mla_preprocess_tiling.h" extern "C" __global__ __aicore__ void mla_preprocess( - GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv, + GM_ADDR hiddenState, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv, GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3, GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq, GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q, @@ -143,7 +143,7 @@ extern "C" __global__ __aicore__ void mla_preprocess( case KEY_FP16_CACHEMODE_0_QUANTMODE_0: { MLAPO_FP16::MLAOperation opFp16Cm0Qm0( mlaTilingData, tiling); - opFp16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + opFp16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, s1, s2, s3); @@ -158,7 +158,7 @@ extern "C" __global__ __aicore__ void mla_preprocess( case KEY_FP16_CACHEMODE_1_QUANTMODE_0: { MLAPO_FP16::MLAOperation opFp16Cm1Qm0(mlaTilingData, tiling); - opFp16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + opFp16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, s1, s2, s3); @@ -174,7 +174,7 @@ extern "C" __global__ __aicore__ void mla_preprocess( MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, QuantMode::PER_TENSOR_ASYMM_QUANT> opBf16Cm0Qm0(mlaTilingData, tiling); - opBf16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, s1, s2, s3, s4, s5); @@ -190,7 +190,7 @@ extern "C" __global__ __aicore__ void mla_preprocess( MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, QuantMode::PER_TENSOR_ASYMM_QUANT> opBf16Cm1Qm0(mlaTilingData, tiling); - opBf16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, s1, s2, s3, s4, s5); @@ -206,7 +206,7 @@ extern "C" __global__ __aicore__ void mla_preprocess( MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, QuantMode::PER_TENSOR_ASYMM_QUANT> opBf16Cm3Qm0(mlaTilingData, tiling); - opBf16Cm3Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + opBf16Cm3Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, s1, s2, s3, s4, s5); @@ -230,8 +230,6 @@ namespace vllm_ascend { extern void mla_preprocess_impl( void* stream, void* hidden_state, - void* gamma1, - void* beta1, void* quant_scale1, void* quant_offset1, void* wdqkv, @@ -264,8 +262,6 @@ extern void mla_preprocess_impl( { mla_preprocess<<>>( hidden_state, - gamma1, - beta1, quant_scale1, quant_offset1, wdqkv, diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp index f58f4aa..43d9509 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp @@ -2388,7 +2388,7 @@ public: this->mlaParams = mlaParams_; } - __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm, + __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm, GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm, GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, @@ -2426,7 +2426,6 @@ public: #endif hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm)); - gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma1Gm)); quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm)); quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm)); @@ -2444,7 +2443,6 @@ public: qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2)); bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm)); bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); - beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta1Gm)); beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm)); #ifdef __DAV_C220_VEC__ @@ -2711,7 +2709,6 @@ private: AscendC::GlobalTensor hiddenStateGmTensor; - AscendC::GlobalTensor gamma1GmTensor; AscendC::GlobalTensor quantScale1GmTensor; AscendC::GlobalTensor quantOffset1GmTensor; @@ -2741,7 +2738,6 @@ private: AscendC::GlobalTensor s5GmTensor; AscendC::GlobalTensor descale1gmTensor; AscendC::GlobalTensor descale2gmTensor; - AscendC::GlobalTensor beta1GmTensor; AscendC::GlobalTensor beta2GmTensor; AscendC::GlobalTensor bias1gmTensor; diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp index 097fbc2..73cb04d 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp @@ -294,8 +294,7 @@ class Quant public: __aicore__ inline Quant() {} - __aicore__ inline void Init(AscendC::GlobalTensor gammaGmTensor, AscendC::GlobalTensor betaGmTensor, - AscendC::GlobalTensor quantScaleGmTensor, + __aicore__ inline void Init(AscendC::GlobalTensor quantScaleGmTensor, AscendC::GlobalTensor quantOffsetGmTensor, AscendC::GlobalTensor inputGmTensor, AscendC::GlobalTensor outputGmTensor, uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset, @@ -2037,7 +2036,7 @@ public: this->mlaParams = mlaParams_; } - __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm, + __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm, GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm, GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, @@ -2057,7 +2056,6 @@ public: mm_w8a8_1.PreloadDoubleWeight(); #endif hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(hiddenStateGm)); - gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma1Gm)); quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale1Gm)); quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); @@ -2081,7 +2079,6 @@ public: qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(qGm2)); bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); - beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta1Gm)); beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta2Gm)); #ifdef __DAV_C220_CUBE__ mm_w8a8_2.Init(s1GmTensor, wuqGmTensor, bias2gmTensor, descale2gmTensor, s2GmTensor, mlaParams, 1); @@ -2105,7 +2102,7 @@ public: row_work_ = 0; } this->splitN = mlaParams.perTaskNum; - Quant1.Init(gamma1GmTensor, beta1GmTensor, quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor, + Quant1.Init(quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor, s1GmTensor, 0, num_col_1, 0.0001395089285, vectorBlockIdx * static_cast(row_work) * num_col_1, vectorBlockIdx * static_cast(row_work) * num_col_1, row_work_, mlaParams); @@ -2316,7 +2313,6 @@ private: AscendC::GlobalTensor hiddenStateGmTensor; - AscendC::GlobalTensor gamma1GmTensor; AscendC::GlobalTensor quantScale1GmTensor; AscendC::GlobalTensor quantOffset1GmTensor; @@ -2343,7 +2339,6 @@ private: AscendC::GlobalTensor s3GmTensor; AscendC::GlobalTensor descale1gmTensor; AscendC::GlobalTensor descale2gmTensor; - AscendC::GlobalTensor beta1GmTensor; AscendC::GlobalTensor beta2GmTensor; AscendC::GlobalTensor bias1gmTensor; diff --git a/csrc/ops.h b/csrc/ops.h index 6364005..c249bb5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,8 +128,6 @@ namespace vllm_ascend { extern void mla_preprocess_impl( void* stream, void* hidden_state, - void* gamma1, - void* beta1, void* quant_scale1, void* quant_offset1, void* wdqkv, diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 74614e5..9eaba72 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -108,7 +108,7 @@ std::tuple rotary_embedding(at::Tensor &positions, at::T } std::tuple mla_preprocess( - const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv, + const at::Tensor &hiddenState, const at::Tensor &wdqkv, const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq, const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping, @@ -135,8 +135,6 @@ std::tuple mla_preproces ); void *hidden_state_ptr = hiddenState.data_ptr(); - void *gamma0_ptr = gamma0.data_ptr(); - void *beta0_ptr = beta0.data_ptr(); void *quant_scale0_ptr = quant_scale0.data_ptr(); void *quant_offset0_ptr = quant_offset0.data_ptr(); void *wdqkv_ptr = wdqkv.data_ptr(); @@ -168,12 +166,12 @@ std::tuple mla_preproces at_npu::native::OpCommand cmd; cmd.Name("mla_preprocess"); - cmd.SetCustomHandler([stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, + cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr, qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr, tiling_ptr, block_dim]() -> int { - mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, + mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr, kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr, qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr, @@ -502,7 +500,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand); ops.def( - "mla_preprocess(Tensor hiddenState, Tensor gamma0, Tensor beta0, Tensor wdqkv," + "mla_preprocess(Tensor hiddenState, Tensor wdqkv," " Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1," " Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache," " Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0," diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index bf7ed01..dbb056b 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -83,8 +83,6 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_ std::tuple mla_preprocess( const at::Tensor &hiddenState, - const at::Tensor &gamma0, - const at::Tensor &beta0, const at::Tensor &wdqkv, const at::Tensor &descale0, const at::Tensor &gamma1, diff --git a/tests/e2e/singlecard/ops/test_mla_preprocess.py b/tests/e2e/singlecard/ops/test_mla_preprocess.py index c4de3cc..e73310f 100644 --- a/tests/e2e/singlecard/ops/test_mla_preprocess.py +++ b/tests/e2e/singlecard/ops/test_mla_preprocess.py @@ -18,8 +18,6 @@ def test_mla_preprocess_kernel(): dtype = torch.bfloat16 hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu() - gamma0 = torch.randn((N_7168), dtype=dtype).npu() - beta0 = torch.randn((N_7168), dtype=dtype).npu() quant_scale0 = torch.randn((1, ), dtype=dtype).npu() quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu() @@ -74,8 +72,6 @@ def test_mla_preprocess_kernel(): torch.ops._C_ascend.mla_preprocess( hidden_states, - gamma0, - beta0, wdqkv, de_scale0, gamma1, diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index b6d8b66..adb0e6a 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -1,5 +1,8 @@ from __future__ import annotations +import os +from unittest.mock import patch + import pytest from vllm import SamplingParams from vllm.config import CompilationConfig, CUDAGraphMode @@ -108,3 +111,19 @@ def test_mtp2_correctness_full_graph( model_name: str, ): mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"}) +def test_mtp_correctness_piecewise_graph_with_mlapo_kernel( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 1) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"}) +def test_mtp_correctness_full_graph_with_mlapo_kernel( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 7f6e7f8..b249741 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -716,17 +716,7 @@ class AscendMLAImpl(MLAAttentionImpl): self.qb_qt_bias = qb_qt_bias.reshape( self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) - device = self.q_proj.weight.device - self.gamma0 = torch.ones( - [self.fused_qkv_a_proj.weight.shape[-1]], - dtype=act_dtype, - device=device, - ) - self.beta0 = torch.zeros( - [self.fused_qkv_a_proj.weight.shape[-1]], - dtype=act_dtype, - device=device, - ) + device = self.q_a_proj.weight.device self.gamma1 = self.q_a_layernorm.weight.data self.beta1 = self.q_a_layernorm.bias.data self.gamma2 = self.kv_a_layernorm.weight.data @@ -1085,8 +1075,6 @@ class AscendMLAImpl(MLAAttentionImpl): torch.ops._C_ascend.mla_preprocess( hidden_states, - self.gamma0, - self.beta0, self.wd_qkv, self.deq_scale_qkv, self.gamma1,