This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,58 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/moe/add_bias_activation.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
/**
 * Validates the arguments and launches the grouped add-bias + activation kernel.
 *
 * @param input             Tensor with at least 2 dims; its last dim is the channel dim.
 * @param output            Destination tensor; its second-to-last stride is forwarded to the kernel.
 * @param bias              Optional bias tensor, forwarded as a raw pointer.
 * @param cusum_token_count Optional tensor; when present the expert count is size(0) - 1.
 *                          Its data is handed to the kernel as `int *`.
 * @param act_mode          One of "silu", "gelu", "quick_gelu" or "swish".
 * @param is_gated          When true the channel dim must be even and only half of it is
 *                          reported to the kernel as the inner size.
 * @param start_expert_id   Forwarded unchanged to the kernel.
 * @param expert_size       Forwarded unchanged to the kernel.
 * @param active_coef       Activation coefficient; overridden for "quick_gelu" and "silu".
 */
void active(const torch::Tensor &input,
            const torch::Tensor &output,
            const c10::optional<torch::Tensor> &bias,
            const c10::optional<torch::Tensor> &cusum_token_count,
            const std::string &act_mode,
            bool is_gated,
            int64_t start_expert_id,
            int64_t expert_size,
            double active_coef) {
  // ---- activation mode -------------------------------------------------
  const bool mode_is_gelu = (act_mode == "gelu");
  const bool mode_is_quick_gelu = (act_mode == "quick_gelu");
  const bool mode_is_silu = (act_mode == "silu");
  TORCH_CHECK(mode_is_silu || mode_is_gelu || mode_is_quick_gelu || act_mode == "swish",
              "act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.")

  // Only "gelu" selects the GELU mode; every other accepted name runs as SWISH and differs
  // solely in the coefficient passed along.
  const cnnlActivationMode_t act_type =
      mode_is_gelu ? CNNL_ACTIVATION_GELU : CNNL_ACTIVATION_SWISH;
  if (mode_is_quick_gelu) {
    active_coef = 1.702f;
  } else if (mode_is_silu) {
    active_coef = 1.0f;
  }

  // ---- shape validation ------------------------------------------------
  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2")
  const int64_t in_channel = input.size(-1);
  TORCH_CHECK(in_channel > 0, "in_channel > 0")
  if (is_gated) {
    TORCH_CHECK(in_channel % 2 == 0, "in_channel % 2 == 0 if is_gated is true")
  }

  // ---- kernel arguments ------------------------------------------------
  const int64_t total_tokens = input.numel() / in_channel;
  const int64_t inner_size = is_gated ? in_channel / 2 : in_channel;
  int64_t num_expert = 0;
  if (cusum_token_count.has_value()) {
    num_expert = cusum_token_count.value().size(0) - 1;
  }

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  const int64_t output_stride = output.stride(-2);
  auto data_dtype = getCnnlDataType(input.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();

  tmo::invokeGroupAddBiasActivationKernel(
      queue, getAtTensorPtr(output), getAtTensorPtr(input), getAtTensorPtr(bias),
      (int *)getAtTensorPtr(cusum_token_count), num_expert, total_tokens, inner_size,
      output_stride, data_dtype, is_gated, act_type, start_expert_id, expert_size, active_coef);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,145 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/rotary_embedding.mluh"
#include "torch_ops_api.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
/**
 * Validates the arguments and launches the rotary-embedding kernel in place on `input`.
 *
 * Layout is selected by the rank of `input`:
 *   - 3 dims (total_seq_len, head_num, head_size): pack mode, `cu_seqlens` is required.
 *   - 4 dims (batch_size, seq_len, head_num, head_size): pad mode, `cu_seqlens` must be absent
 *     and `max_seqlen` must equal input.size(1).
 *
 * @param input        Tensor to rotate; its last dim must be contiguous and at most 256.
 * @param sin_cache    (rope_seqlen, rope_dim), or (batch_size, rope_seqlen, rope_dim) when
 *                     `dynamic_ntk` is true.
 * @param cos_cache    Same shape as `sin_cache`; the sequence strides of both must match.
 * @param position_ids Optional contiguous int32 tensor of shape (total_seqlen) when `discrete`
 *                     is true, otherwise (batch_size).
 * @param cu_seqlens   Optional contiguous int32 tensor; batch_size is size(0) - 1 in pack mode.
 * @param interleaved  Forwarded unchanged to the kernel.
 * @param discrete     Requires `position_ids`; forwarded to the kernel.
 * @param dynamic_ntk  Selects the batched sin/cos cache layout; forwarded to the kernel.
 * @param max_seqlen   Must not exceed rope_seqlen unless discrete position ids are given.
 */
void apply_rotary(const torch::Tensor &input,
                  const torch::Tensor &sin_cache,
                  const torch::Tensor &cos_cache,
                  const c10::optional<torch::Tensor> &position_ids,
                  const c10::optional<torch::Tensor> &cu_seqlens,
                  bool interleaved,
                  bool discrete,
                  bool dynamic_ntk,
                  int64_t max_seqlen) {
  // `output` is a second handle to `input`, so the kernel receives the same data pointer for
  // source and destination.
  auto output = input;

  // 1. check device and tensor type
  checkTensorSameAttr<TensorAttr::ALL>(input, sin_cache, cos_cache);
  const bool has_position_ids = position_ids.has_value();
  const bool has_cu_seqlens = cu_seqlens.has_value();
  const int origin_device_id = input.get_device();
  const void *position_ids_ptr = nullptr;
  const void *cu_seqlens_ptr = nullptr;
  if (has_position_ids) {
    TORCH_CHECK(position_ids.value().dtype() == torch::kInt32,
                "position_ids type need be torch::kInt32");
    // The message pieces are concatenated verbatim, so the separator must be spelled out.
    TORCH_CHECK(position_ids.value().get_device() == origin_device_id,
                "Tensor device index is not the same, original index: ", origin_device_id,
                ", now index is: ", position_ids.value().get_device());
    position_ids_ptr = position_ids.value().data_ptr();
  }
  if (has_cu_seqlens) {
    TORCH_CHECK(cu_seqlens.value().dtype() == torch::kInt32,
                "cu_seqlens type need be torch::kInt32");
    TORCH_CHECK(cu_seqlens.value().get_device() == origin_device_id,
                "Tensor device index is not the same, original index: ", origin_device_id,
                ", now index is: ", cu_seqlens.value().get_device());
    cu_seqlens_ptr = cu_seqlens.value().data_ptr();
  }

  // 2. check shape
  int total_seqlen = 0;
  int batch_size = 0;
  int head_size = input.size(-1);
  if (input.dim() == 3) {  // pack mode
    TORCH_CHECK(has_cu_seqlens,
                "input has 3 dims: (total_seq_len, head_num, head_size),"
                " which means pack mode, cu_seqlens should not be None");
    total_seqlen = input.size(0);
    batch_size = cu_seqlens.value().size(0) - 1;
  } else if (input.dim() == 4) {  // pad mode
    TORCH_CHECK(!has_cu_seqlens,
                "input has 4 dims: (batch_size, seq_len, head_num, head_size),"
                " which means pad mode, cu_seqlens should be None");
    TORCH_CHECK(max_seqlen == input.size(1),
                "input has 4 dims: (batch_size, seq_len, head_num, head_size),"
                " which means pad mode, max_seqlen must be equal to input.size(1)");
    batch_size = input.size(0);
    total_seqlen = batch_size * input.size(1);
  } else {
    TORCH_CHECK(false, "input only support 3 or 4 dims");
  }
  TORCH_CHECK(head_size <= 256, "only support input head_size <= 256");

  // With dynamic NTK the caches carry a leading batch dim, shifting seqlen/dim by one axis.
  const int rope_seqlen = dynamic_ntk ? sin_cache.size(1) : sin_cache.size(0);
  const int rope_dim = dynamic_ntk ? sin_cache.size(2) : sin_cache.size(1);
  if (has_position_ids) {
    if (discrete) {
      CHECK_SHAPE(position_ids.value(), total_seqlen);
    } else {
      CHECK_SHAPE(position_ids.value(), batch_size);
    }
  } else {
    TORCH_CHECK(!discrete, "discrete must be false if position ids is null.")
  }
  // The bound on max_seqlen is skipped only when discrete position ids are supplied.
  if (!(has_position_ids && discrete)) {
    TORCH_CHECK(max_seqlen <= rope_seqlen, "max_seqlen must less than or equal to rope_seqlen.")
  }
  if (dynamic_ntk) {
    CHECK_SHAPE(sin_cache, batch_size, rope_seqlen, rope_dim);
    CHECK_SHAPE(cos_cache, batch_size, rope_seqlen, rope_dim);
  } else {
    CHECK_SHAPE(sin_cache, rope_seqlen, rope_dim);
    CHECK_SHAPE(cos_cache, rope_seqlen, rope_dim);
  }

  // 3. check strides
  TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous");
  // Only one cache stride is passed to the kernel, so sin and cos must agree on it.
  if (dynamic_ntk) {
    TORCH_CHECK(sin_cache.stride(1) == cos_cache.stride(1),
                "sin_cache second stride must be equal to cos_cache second stride");
  } else {
    TORCH_CHECK(sin_cache.stride(0) == cos_cache.stride(0),
                "sin_cache first stride must be equal to cos_cache first stride");
  }
  if (has_position_ids) {
    TORCH_CHECK(position_ids.value().is_contiguous(), "position_ids must be contiguous");
  }
  if (has_cu_seqlens) {
    TORCH_CHECK(cu_seqlens.value().is_contiguous(), "cu_seqlens must be contiguous");
  }

  // prepare inputs
  auto dims = input.dim();
  const int64_t num_heads = input.size(dims - 2);
  const int64_t head_dim = input.size(dims - 1);
  const int64_t input_seq_stride = input.strides()[dims - 3];
  const int64_t input_head_stride = input.strides()[dims - 2];
  const int64_t output_seq_stride = output.strides()[dims - 3];
  const int64_t output_head_stride = output.strides()[dims - 2];
  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  auto queue = torch_mlu::getCurMLUStream();
  auto data_type = getCnnlDataType(input.scalar_type());
  invokeRotaryEmbedding(queue, output.data_ptr(), input.data_ptr(), sin_cache.data_ptr(),
                        cos_cache.data_ptr(), (int *)position_ids_ptr, (int *)cu_seqlens_ptr,
                        batch_size, max_seqlen, num_heads, head_dim, rope_seqlen, rope_dim,
                        dynamic_ntk ? sin_cache.strides()[1] : sin_cache.strides()[0],
                        input_seq_stride, input_head_stride, output_seq_stride,
                        output_head_stride, interleaved, discrete, dynamic_ntk, data_type);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,151 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
/**
 * Runs the fused CNNL attention projection (optional layernorm or residual, followed by the
 * Q and optional K/V linear projections).
 *
 * @param input        Contiguous 2-D or 3-D tensor. A 2-D input is treated as a batch of one
 *                     and the leading dim is removed again from the returned tensors.
 * @param q_weight     Q projection weight; size(0) is the Q hidden size.
 * @param q_bias       Optional Q bias; its presence is what enables bias for the whole op.
 * @param k_weight     Optional K projection weight; enables the K output.
 * @param k_bias       Optional K bias.
 * @param v_weight     Optional V projection weight; enables the V output.
 * @param v_bias       Optional V bias.
 * @param norm_weight  Optional layernorm weight; its presence selects the pre-layernorm mode.
 * @param norm_bias    Optional layernorm bias.
 * @param residual     Optional contiguous residual; cannot be combined with layernorm.
 * @param out_layout   "nthc", or "nhtc" (3-D input only) for head-major transposed outputs.
 * @param head_size    Per-head size; only used to shape the outputs when out_layout is "nhtc".
 * @param eps          Forwarded to the op descriptor.
 * @param alpha        Forwarded to the op descriptor.
 * @param beta         Forwarded to the op descriptor.
 * @param norm_out     When true the layernorm output is also produced and returned.
 * @return out_q, then out_k (if k_weight given), out_v (if v_weight given) and the layernorm
 *         output (if norm_out), in that order.
 */
std::vector<at::Tensor> attention_project(const at::Tensor &input,
                                          const at::Tensor &q_weight,
                                          const c10::optional<at::Tensor> &q_bias,
                                          const c10::optional<at::Tensor> &k_weight,
                                          const c10::optional<at::Tensor> &k_bias,
                                          const c10::optional<at::Tensor> &v_weight,
                                          const c10::optional<at::Tensor> &v_bias,
                                          const c10::optional<at::Tensor> &norm_weight,
                                          const c10::optional<at::Tensor> &norm_bias,
                                          const c10::optional<at::Tensor> &residual,
                                          const std::string &out_layout,
                                          int64_t head_size,
                                          double eps,
                                          double alpha,
                                          double beta,
                                          bool norm_out) {
  // check device and dtype
  checkTensorSameAttr<TensorAttr::ALL>(input, q_weight, q_bias, k_weight, k_bias, v_weight,
                                       v_bias, norm_weight, norm_bias, residual);
  // check contiguous
  CHECK_TENSOR_CONTIGUOUS(input)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(residual)

  // check size
  const int64_t nDim = input.dim();
  at::Tensor input_view = input;
  if (nDim == 2) {
    input_view = input.unsqueeze(0);
  }
  const bool has_k = k_weight.has_value();
  const bool has_v = v_weight.has_value();
  const int64_t n = input_view.size(0);
  const int64_t t = input_view.size(1);
  const int64_t hidden_size_q = q_weight.size(0);
  const int64_t hidden_size_k = has_k ? k_weight.value().size(0) : 0;
  const int64_t hidden_size_v = has_v ? v_weight.value().size(0) : 0;

  // check params
  bool has_bias = q_bias.has_value();
  bool has_ln = norm_weight.has_value();
  bool has_residual = residual.has_value();
  TORCH_CHECK(!(has_ln && has_residual), "cannot support layernorm and residual at the same time.")
  TORCH_CHECK(out_layout == "nthc" || (out_layout == "nhtc" && nDim == 3),
              "input must be 3-D if out_layout is 'nhtc'")
  bool trans_out = (out_layout == "nhtc");

  // Head counts are only needed for the transposed layout; computing them lazily avoids an
  // integer division by zero when head_size is left at 0 for the 'nthc' layout.
  TORCH_CHECK(!trans_out || head_size > 0, "head_size must be positive if out_layout is 'nhtc'")
  const int64_t head_num_q = trans_out ? hidden_size_q / head_size : 0;
  const int64_t head_num_k = trans_out ? hidden_size_k / head_size : 0;
  const int64_t head_num_v = trans_out ? hidden_size_v / head_size : 0;

  const torch_mlu::mlu::MLUGuard device_guard(input_view.device());
  auto q_shape = trans_out ? std::vector<int64_t>({n, head_num_q, t, head_size})
                           : std::vector<int64_t>({n, t, hidden_size_q});
  auto k_shape = trans_out ? std::vector<int64_t>({n, head_num_k, t, head_size})
                           : std::vector<int64_t>({n, t, hidden_size_k});
  auto v_shape = trans_out ? std::vector<int64_t>({n, head_num_v, t, head_size})
                           : std::vector<int64_t>({n, t, hidden_size_v});
  auto out_q = at::empty(q_shape, input_view.options());
  auto out_k = has_k ? at::empty(k_shape, input_view.options()) : at::Tensor();
  // V must use its own shape: its hidden size can differ from Q's.
  auto out_v = has_v ? at::empty(v_shape, input_view.options()) : at::Tensor();
  // The layernorm output keeps the caller's original (possibly 2-D) input shape.
  auto out_ln = norm_out ? at::empty(input.sizes(), input_view.options()) : at::Tensor();

  // create tensor descs; the indices below are relied on when calling CNNL:
  //   0 input, 1 q_weight, 2 q_bias, 3 k_weight, 4 k_bias, 5 v_weight, 6 v_bias,
  //   7 norm_weight, 8 norm_bias, 9 residual, 10 out_q, 11 out_k, 12 out_v, 13 out_ln
  auto descs =
      createTensorDescs({input_view, q_weight, q_bias.value_or(at::Tensor()),
                         k_weight.value_or(at::Tensor()), k_bias.value_or(at::Tensor()),
                         v_weight.value_or(at::Tensor()), v_bias.value_or(at::Tensor()),
                         norm_weight.value_or(at::Tensor()), norm_bias.value_or(at::Tensor()),
                         residual.value_or(at::Tensor()), out_q, out_k, out_v, out_ln});

  // create and set attn_proj_desc
  cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode =
      has_ln ? CNNL_TRANSFORMER_PRE_LAYERNORM_NO_RESIDUAL
             : (has_residual ? CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL
                             : CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL);
  auto compute_dtype = getCnnlDataType(input_view.scalar_type());
  auto attn_proj_desc = tmo::op_desc::AttnProjDesc();
  attn_proj_desc.setDesc(layernorm_residual_mode, /*layernorm_residual_mode*/
                         compute_dtype, true,     /*has_q*/
                         has_k,                   /*has_k*/
                         has_v,                   /*has_v*/
                         has_bias, false, 0,      /*no packed*/
                         trans_out,               /*trans_out*/
                         norm_out,                /*store_layernorm out*/
                         alpha, beta,             /*alpha && beta */
                         eps /*eps*/);
  auto handle = torch_mlu::getCurrentHandle();

  // get workspace size
  size_t workspace_size = 0;
  CNNL_CHECK_FATAL(cnnlGetTransformerAttnProjWorkspaceSize(handle, attn_proj_desc, nullptr,
                                                           descs[0].get(), descs[1].get(),
                                                           descs[10].get(), &workspace_size));
  auto workspace =
      at::empty({static_cast<int64_t>(workspace_size)}, input.options().dtype(at::kByte));

  // run forward
  CNNL_CHECK_FATAL(cnnlTransformerAttnProj(
      handle, attn_proj_desc, nullptr,             /*quant_desc*/
      descs[0].get(), getAtTensorPtr(input_view),  /* input */
      descs[9].get(), getAtTensorPtr(residual),    /* residual */
      descs[1].get(), getAtTensorPtr(q_weight),    /* q weight */
      descs[3].get(), getAtTensorPtr(k_weight),    /* k weight */
      descs[5].get(), getAtTensorPtr(v_weight),    /* v weight */
      descs[2].get(), getAtTensorPtr(q_bias),      /* q bias */
      descs[4].get(), getAtTensorPtr(k_bias),      /* k bias */
      descs[6].get(), getAtTensorPtr(v_bias),      /* v bias */
      nullptr, nullptr,                            /*no valid token*/
      descs[7].get(), getAtTensorPtr(norm_weight), /* layernorm weight */
      descs[8].get(), getAtTensorPtr(norm_bias),   /* layernorm bias */
      getAtTensorPtr(workspace), workspace_size,   /* workspace */
      descs[10].get(), getAtTensorPtr(out_q),      /* q out */
      descs[11].get(), getAtTensorPtr(out_k),      /* k out */
      descs[12].get(), getAtTensorPtr(out_v),      /* v out */
      descs[13].get(), getAtTensorPtr(out_ln)      /* layernorm out */
      ));

  // return
  if (nDim == 2) {
    // Drop the batch dim that was added for the 2-D input. out_ln is deliberately left alone:
    // it was allocated with the original 2-D shape, and squeezing it would wrongly collapse a
    // single-token (1, C) result to (C).
    out_q.squeeze_(0);
    if (has_k) out_k.squeeze_(0);
    if (has_v) out_v.squeeze_(0);
  }
  std::vector<at::Tensor> output_list;
  output_list.emplace_back(out_q);
  if (has_k) output_list.emplace_back(out_k);
  if (has_v) output_list.emplace_back(out_v);
  if (norm_out) output_list.emplace_back(out_ln);
  return output_list;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,91 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
/**
 * Launches cnnlBatchMatMulEx on `a` and `b`, writing into the pre-allocated tensor `c`.
 *
 * @param a        Contiguous left operand; size(0) is taken as the batch size.
 * @param b        Contiguous right operand with the same dtype as `a`.
 * @param c        Contiguous output of shape (batch, m, n), on the same device as `a` and `b`.
 * @param alpha    Passed to CNNL as a float scaling factor.
 * @param beta     Passed to CNNL as a float; a non-zero value enables `use_beta` in the op desc.
 * @param a_scale  Quantization scale of `a`; only read when the inputs are int8.
 * @param b_scale  Quantization scale of `b`; only read when the inputs are int8.
 * @param trans_a  When true, m is read from the last dim of `a` instead of the second-to-last.
 * @param trans_b  When true, n is read from the second-to-last dim of `b` instead of the last.
 */
void batch_matmul(const at::Tensor &a,
                  const at::Tensor &b,
                  const at::Tensor &c,
                  double alpha,
                  double beta,
                  double a_scale,
                  double b_scale,
                  bool trans_a,
                  bool trans_b) {
  const bool use_beta = (beta != 0);

  // Problem sizes, taken from the operands before any validation runs.
  const int batch = a.size(0);
  const int m = a.size(trans_a ? -1 : -2);
  const int n = b.size(trans_b ? -2 : -1);

  checkTensorSameAttr<TensorAttr::DEVICE>(a, b, c);
  // check contiguous
  TORCH_CHECK(a.is_contiguous(), "a must be contiguous.")
  TORCH_CHECK(b.is_contiguous(), "b must be contiguous.")
  TORCH_CHECK(c.is_contiguous(), "c must be contiguous.")
  CHECK_SHAPE(c, batch, m, n);

  // Resolve CNNL data types; bfloat16 inputs use float as the on-chip type of c.
  const auto a_dtype = getCnnlDataType(a.scalar_type());
  const auto b_dtype = getCnnlDataType(b.scalar_type());
  TORCH_CHECK(a_dtype == b_dtype, "a, b must be same dtype.");
  const auto c_tensor_dtype = getCnnlDataType(c.scalar_type());
  const auto c_dtype = (a_dtype == CNNL_DTYPE_BFLOAT16) ? CNNL_DTYPE_FLOAT : c_tensor_dtype;

  // create tensor desc
  auto descs = createTensorDescs({a, b, c});
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), a_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), b_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[2].get(), c_dtype));

  if (a_dtype == CNNL_DTYPE_INT8) {
    const float int_max = 127.0;
    const int quant_bit = 8;
    // Fixed-point position derived from a scale: floor(log2(int_max / scale) - (quant_bit - 2)).
    const auto position_of = [&](double scale) -> int {
      const float max_value = int_max / scale;
      return std::floor(std::log2(max_value) - (quant_bit - 2));
    };
    // Scale re-expressed relative to the chosen position: 2^pos * scale.
    const auto rescale = [](int pos, double scale) -> float {
      return std::pow(2.0f, pos) * scale;
    };
    const int pos_a = position_of(a_scale);
    const int pos_b = position_of(b_scale);
    const float scale_a = rescale(pos_a, a_scale);
    const float scale_b = rescale(pos_b, b_scale);
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[0].get(), pos_a, scale_a));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[1].get(), pos_b, scale_b));
    TORCH_CHECK(c_dtype != CNNL_DTYPE_BFLOAT16,
                "output dtype cannot be bfloat16 when a/b is fixed-point")
  }

  // get && set Op desc
  cnnlMatMulHeuristicResult_t heuristic_result;
  cnnlMatMulAlgo_t algo;
  auto bmm_ex_desc =
      tmo::op_desc::BatchMatMulDesc(heuristic_result, algo, use_beta, trans_a, trans_b);
  const torch_mlu::mlu::MLUGuard device_guard(a.device());
  auto handle = torch_mlu::getCurrentHandle();

  // Ask CNNL for a single algorithm candidate and the workspace it needs.
  size_t workspace_size = 0;
  int requested_algo_count = 1;
  int returned_algo_count = 0;
  CNNL_CHECK_FATAL(cnnlGetBatchMatMulExAlgoHeuristic(
      handle, bmm_ex_desc, descs[0].get(), descs[1].get(), descs[2].get(), nullptr,
      requested_algo_count, &heuristic_result, &returned_algo_count));
  CNNL_CHECK_FATAL(cnnlGetBatchMatMulExHeuristicResult(heuristic_result, algo, &workspace_size));
  auto workspace =
      at::empty({static_cast<int64_t>(workspace_size)}, a.options().dtype(at::kByte));

  // run forward; CNNL takes the scaling factors as float pointers.
  float alpha_f = alpha;
  float beta_f = beta;
  CNNL_CHECK_FATAL(cnnlBatchMatMulEx(handle, bmm_ex_desc, algo, &alpha_f, descs[0].get(),
                                     getAtTensorPtr(a), descs[1].get(), getAtTensorPtr(b),
                                     &beta_f, descs[2].get(), getAtTensorPtr(c),
                                     getAtTensorPtr(workspace), workspace_size));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,45 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "cnpx.h"
#include <sstream>
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// CNPX domain that cnpxPush()/cnpxPop() open and close their ranges in.
static cnpxDomainHandle_t domain = cnpxDomainCreate("CNPERF_KERNEL_TMO");
// True when the CNPERF_KERNEL_ANALYSIS environment variable exists, whatever its value.
static bool cnperf_kernel_analysis = (getenv("CNPERF_KERNEL_ANALYSIS") != nullptr);
void cnpxPush(const OpTheory &op) {
if (cnperf_kernel_analysis) {
size_t calc = op.getTheoryCalc();
cnnlDataType_t calc_dtype = op.getCalcDtype();
size_t io = op.getTheoryIO();
auto op_name = op.getOpName();
std::ostringstream jsonStream;
jsonStream << "{\"name\":\"" << op_name << "\", \"theo_calc\":" << calc
<< ", \"theo_bytes\":" << io << ", \"calc_type\":" << calc_dtype << "}";
std::string json = jsonStream.str();
// std::cout << json.c_str() << std::endl;
cnpxDomainRangePush(domain, json.c_str());
}
}
/** Closes the most recent range opened by cnpxPush(); a no-op when kernel analysis is off. */
void cnpxPop() {
  if (!cnperf_kernel_analysis) {
    return;
  }
  cnpxDomainRangePop(domain);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,39 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "comm_overlap.h"
namespace tmo {
namespace torch_api {
// X-macro table pairing every supported CNCL data type with its ATen scalar type.
#define CNCL_TYPE_AND_SCALAR_TYPE(X) \
  X(cnclFloat, at::kFloat)           \
  X(cnclBfloat16, at::kBFloat16)     \
  X(cnclHalf, at::kHalf)             \
  X(cnclInt32, at::kInt)             \
  X(cnclInt64, at::kLong)            \
  X(cnclInt8, at::kChar)             \
  X(cnclUint8, at::kByte)            \
  X(cnclInt16, at::kShort)

/**
 * Maps an ATen scalar type to the matching CNCL data type.
 *
 * @param data_type Scalar type to translate.
 * @return The CNCL type listed for `data_type` in CNCL_TYPE_AND_SCALAR_TYPE.
 * @throws std::runtime_error if `data_type` is not in the table.
 */
cnclDataType_t getCnclDataType(const at::ScalarType &data_type) {
#define CNCL_DTYPE_CASE(cncl_dtype, scalar_type) \
  case scalar_type:                              \
    return cncl_dtype;
  switch (data_type) {
    CNCL_TYPE_AND_SCALAR_TYPE(CNCL_DTYPE_CASE)
    default:
      break;
  }
#undef CNCL_DTYPE_CASE
  // Reached only for scalar types without a table entry.
  throw std::runtime_error(std::string("getCnclDataType() not supported for ") +
                           c10::toString(data_type));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,69 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_
#define CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>
#include "cncl.h"
#include "framework/core/MLUEvent.h"
#include "framework/core/MLUStream.h"
#include "torch/extension.h"
#include "torch_api/utils.h"
namespace tmo {
namespace torch_api {
cnclDataType_t getCnclDataType(const at::ScalarType &data_type);
template <typename MmOp>
struct ParallelAllReduce {
static std::vector<torch_mlu::MLUEvent> event_;
static std::optional<torch_mlu::MLUStream> stream_;
cnclComm_t cncl_comm_;
std::vector<at::Tensor> d_list_;
ParallelAllReduce(int64_t cncl_comm) { cncl_comm_ = reinterpret_cast<cnclComm_t>(cncl_comm); }
at::Tensor operator()(MmOp &mm) {
auto compute_stream_ = torch_mlu::getCurrentMLUStream();
if (!stream_.has_value()) {
stream_.emplace(torch_mlu::getStreamFromPool());
}
auto comm_stream_ = stream_.value();
uint64_t loop = mm.getLoopNum();
if (event_.size() < loop + 1) {
event_.resize(loop + 1);
}
for (uint64_t i = 0; i < loop; i++) {
d_list_.push_back(mm.forward(i));
event_[i].place(compute_stream_);
event_[i].wait(comm_stream_);
// reduce_sum d_send
CNCL_CHECK(cnclAllReduce(getAtTensorPtr(d_list_[i]), getAtTensorPtr(d_list_[i]),
d_list_[i].numel(), getCnclDataType(d_list_[i].scalar_type()),
cnclSum, cncl_comm_, comm_stream_.stream()));
}
event_[loop].place(comm_stream_);
event_[loop].wait(compute_stream_);
return mm.getOutput();
}
};
template <typename MmOp>
std::vector<torch_mlu::MLUEvent> ParallelAllReduce<MmOp>::event_;
template <typename MmOp>
std::optional<torch_mlu::MLUStream> ParallelAllReduce<MmOp>::stream_;
} // namespace torch_api
} // namespace tmo
#endif // CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_

View File

@@ -0,0 +1,180 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "comm_overlap.h"
#include "torch_api/torch_ops_api.h"
namespace tmo {
namespace torch_api {
using namespace torch::indexing;
/**
 * Blocked "flash attention -> smooth quant -> quantized matmul" pipeline for a single batch.
 *
 * The query sequence is cut into `block_seq_` blocks (see `cumsum_seq_`); `forward(i)`
 * processes block `i` and writes its rows into the shared `output_` tensor, so a caller can
 * overlap communication of finished blocks with computation of the next ones.
 *
 * All tensor/string arguments are stored by value (tensor copies share storage), so the
 * object stays valid even when it was constructed from temporaries such as `c10::nullopt`.
 */
struct FlashAttnSmoothQuantMatmul {
  at::Tensor q_;
  at::Tensor k_;
  at::Tensor v_;
  c10::optional<at::Tensor> cu_seq_lens_q_;
  c10::optional<at::Tensor> cu_seq_lens_kv_;
  at::Tensor smooth_;
  at::Tensor weight_;
  at::Tensor weight_scale_;
  c10::optional<at::Tensor> bias_;
  std::string compute_dtype_;
  std::string input_dtype_str_;
  int64_t max_seq_len_q_;
  int64_t max_seq_len_kv_;
  double softmax_scale_;
  bool is_causal_;
  int64_t block_seq_ = 1;         // number of sequence blocks; chosen in the constructor
  std::vector<int> cumsum_seq_;   // block boundaries along the sequence, size block_seq_ + 1
  at::Tensor output_;             // (total_seq_q, weight.size(0)), filled block by block

  /**
   * Validates the layout and prepares the block partition.
   *
   * Pack mode (cu_seq_lens_q given): q/k/v are 3-D, both cu_seq_lens tensors must have two
   * entries (one batch) and a leading batch dim of 1 is added to q/k/v.
   * Pad mode (cu_seq_lens_q absent): q must be 4-D with q.size(0) == 1.
   *
   * NOTE(review): `block_seq` is accepted but not used; the block count is derived from the
   * total query length (4 blocks from 4096 tokens on, 1 otherwise) — confirm this is intended.
   */
  FlashAttnSmoothQuantMatmul(const at::Tensor &q,
                             const at::Tensor &k,
                             const at::Tensor &v,
                             const c10::optional<at::Tensor> &cu_seq_lens_q,
                             const c10::optional<at::Tensor> &cu_seq_lens_kv,
                             const at::Tensor &smooth,
                             const at::Tensor &weight,
                             const at::Tensor &weight_scale,
                             const c10::optional<at::Tensor> &bias,
                             const int64_t max_seq_len_q,
                             const int64_t max_seq_len_kv,
                             const double softmax_scale,
                             const bool is_causal,
                             const std::string &compute_dtype,
                             const int64_t block_seq)
      : q_(q),
        k_(k),
        v_(v),
        cu_seq_lens_q_(cu_seq_lens_q),
        cu_seq_lens_kv_(cu_seq_lens_kv),
        smooth_(smooth),
        weight_(weight),
        weight_scale_(weight_scale),
        bias_(bias),
        compute_dtype_(compute_dtype),
        max_seq_len_q_(max_seq_len_q),
        max_seq_len_kv_(max_seq_len_kv),
        softmax_scale_(softmax_scale),
        is_causal_(is_causal) {
    // Name of the input dtype, later handed to quant_matmul.
    input_dtype_str_ = q.scalar_type() == at::kBFloat16 ? "bfloat16"
                       : q.scalar_type() == at::kHalf   ? "half"
                                                        : "float";
    const bool is_pack = cu_seq_lens_q.has_value();
    if (is_pack) {
      // Check presence first so a missing tensor yields a clear message instead of a
      // bad-optional-access error.
      TORCH_CHECK(cu_seq_lens_kv.has_value(), "cu_seq_lens_kv must be provided in pack mode.")
      TORCH_CHECK(cu_seq_lens_q.value().size(0) == 2 && cu_seq_lens_kv.value().size(0) == 2,
                  "only support 1 batch.")
      TORCH_CHECK(q.dim() == 3 && k.dim() == 3 && v.dim() == 3, "q,k,v must be 3-d in pack mode.")
      // 1 batch pack to pad
      q_ = q_.unsqueeze(0);
      k_ = k_.unsqueeze(0);
      v_ = v_.unsqueeze(0);
    } else {
      // Validate the rank before reading a size so a wrongly shaped q reports the rank problem.
      TORCH_CHECK(q.dim() == 4, "q must be 4-d in pad mode.")
      TORCH_CHECK(q.size(0) == 1, "only support 1 batch.")
    }
    int total_seq_q = q_.size(0) * q_.size(1);
    output_ = at::empty({total_seq_q, weight.size(0)}, q.options());
    if (total_seq_q >= 4096) {
      block_seq_ = 4;
      split_4(total_seq_q);
    } else {
      block_seq_ = 1;
      split_1(total_seq_q);
    }
  }

  /** Single block covering the whole sequence: boundaries {0, seq}. */
  void split_1(int seq) {
    cumsum_seq_.resize(2);
    cumsum_seq_[0] = 0;
    cumsum_seq_[1] = seq;
  }

  /**
   * Four blocks: the last two are seq/8 rounded up to a multiple of 256, the second is seq/4
   * rounded up to a multiple of 256, and the first takes whatever remains.
   */
  void split_4(int seq) {
    auto pad_up = [](int x, int y) -> int { return (x + y - 1) / y * y; };
    const int eighth = pad_up(seq / 8, 256);
    const int quarter = pad_up(seq / 4, 256);
    const int first = seq - quarter - 2 * eighth;
    cumsum_seq_.resize(5);
    cumsum_seq_[0] = 0;
    cumsum_seq_[1] = cumsum_seq_[0] + first;
    cumsum_seq_[2] = cumsum_seq_[1] + quarter;
    cumsum_seq_[3] = cumsum_seq_[2] + eighth;
    cumsum_seq_[4] = cumsum_seq_[3] + eighth;
  }

  /** View of `a` restricted to [start, end) along `dim`. */
  auto split_tensor(const at::Tensor &a, int64_t dim, int64_t start, int64_t end) {
    return a.narrow(dim, start, end - start);
  }

  /**
   * Processes sequence block `block_id` and returns the result of quant_matmul for it.
   * The rows of `output_` belonging to this block are passed to quant_matmul as well.
   */
  at::Tensor forward(const int64_t block_id) {
    // flash attn
    // With causal masking a query block only needs keys/values up to its own end; otherwise
    // it sees the whole sequence.
    int64_t end_kv = is_causal_ ? cumsum_seq_[block_id + 1] : cumsum_seq_[block_seq_];
    auto q_i = split_tensor(q_, 1, cumsum_seq_[block_id], cumsum_seq_[block_id + 1]);
    auto k_i = split_tensor(k_, 1, 0, end_kv);
    auto v_i = split_tensor(v_, 1, 0, end_kv);
    auto attn_out = at::empty_like(q_i);
    flash_attention(q_i, k_i, v_i, attn_out, c10::nullopt, c10::nullopt, c10::nullopt,
                    c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt,
                    q_i.size(1), k_i.size(1), softmax_scale_, is_causal_, -1, -1, compute_dtype_,
                    false);
    // smooth quant
    auto smooth_quant_input =
        attn_out.flatten(-2, -1).flatten(0, 1);  // (b, s, hn, hs) -> (b*s, hn*hs)
    auto quant_out =
        at::empty(smooth_quant_input.sizes(),
                  torch::TensorOptions().dtype(torch::kInt8).device(smooth_quant_input.device()));
    auto quant_out_scale = at::empty({smooth_quant_input.size(0)}, smooth_.options());
    smooth_quant(smooth_quant_input, smooth_, quant_out, quant_out_scale, c10::nullopt,
                 c10::nullopt, c10::nullopt, c10::nullopt, "per_token", true);
    // quant matmul
    auto d_i = quant_matmul(
        quant_out, quant_out_scale, c10::nullopt, weight_, weight_scale_, c10::nullopt, bias_,
        c10::nullopt, c10::nullopt, c10::nullopt, weight_scale_, c10::nullopt, input_dtype_str_,
        split_tensor(output_, 0, cumsum_seq_[block_id], cumsum_seq_[block_id + 1]),
        "smooth_quant", "quantize_per_token", "quantize_per_channel", 8, "none", false, 1.0, 1.0,
        1.0, false, true);
    return d_i;
  }

  /** Number of blocks `forward` must be called for. */
  int64_t getLoopNum() const { return block_seq_; }
  /** The full output tensor shared by all blocks. */
  at::Tensor getOutput() const { return output_; }
};
/**
 * Runs flash attention, smooth quantization and the quantized matmul block by block, and
 * all-reduces each block over `cncl_comm` via ParallelAllReduce.
 *
 * `alibi_slope`, `attn_bias`, `window_size_left` and `window_size_right` are part of the
 * signature but are not read by this function.
 *
 * @param cncl_comm CNCL communicator handle passed as an integer.
 * @return The output tensor of the fused operator after all blocks have been issued.
 */
at::Tensor flash_attn_sq_mm_allreduce(const int64_t cncl_comm,
                                      const at::Tensor &q,
                                      const at::Tensor &k,
                                      const at::Tensor &v,
                                      const c10::optional<at::Tensor> &cu_seq_lens_q,
                                      const c10::optional<at::Tensor> &cu_seq_lens_kv,
                                      const c10::optional<at::Tensor> &alibi_slope,
                                      const c10::optional<at::Tensor> &attn_bias,
                                      const at::Tensor &smooth,
                                      const at::Tensor &weight,
                                      const at::Tensor &weight_scale,
                                      const c10::optional<at::Tensor> &bias,
                                      const int64_t max_seq_len_q,
                                      const int64_t max_seq_len_kv,
                                      const double softmax_scale,
                                      const bool is_causal,
                                      const int64_t window_size_left,
                                      const int64_t window_size_right,
                                      const std::string &compute_dtype,
                                      const int64_t block_seq) {
  // Compute side: validates the inputs and splits the sequence into blocks.
  FlashAttnSmoothQuantMatmul fused_op(q, k, v, cu_seq_lens_q, cu_seq_lens_kv, smooth, weight,
                                      weight_scale, bias, max_seq_len_q, max_seq_len_kv,
                                      softmax_scale, is_causal, compute_dtype, block_seq);
  // Communication side: drives the blocks and reduces each one as it becomes available.
  using OverlappedAllReduce = ParallelAllReduce<FlashAttnSmoothQuantMatmul>;
  OverlappedAllReduce overlapped_allreduce(cncl_comm);
  return overlapped_allreduce(fused_op);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,192 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "comm_overlap.h"
#include "kernels/moe/combine_result.mluh"
#include "torch_api/torch_ops_api.h"
#include "torch_api/utils.h"
namespace tmo {
namespace torch_api {
using namespace torch::indexing;
/**
 * Splits a grouped GEMM followed by the MoE combine-result kernel into `block_n_` slices along
 * the N axis (dimension 1) of `b`. forward(block_id) produces one [num_token, n / block_n]
 * slice into an internal buffer and getOutput() stitches all slices into [num_token, n].
 *
 * The interface (forward / getLoopNum / getOutput) is consumed by ParallelAllReduce in
 * group_gemm_combine_result_allreduce.
 *
 * NOTE: every tensor / optional input is stored as a reference, so the arguments handed to the
 * constructor must outlive this object.
 */
struct GroupGemmCombineResultSplitN {
  // Borrowed inputs (see lifetime note above).
  const at::Tensor &a_;
  const at::Tensor &b_;
  const at::Tensor &m_list_;
  const at::Tensor &combine_idx_;
  const at::Tensor &combine_weight_;
  const c10::optional<at::Tensor> &c_;
  const c10::optional<at::Tensor> &alpha_;
  const c10::optional<at::Tensor> &beta_;
  const c10::optional<at::Tensor> &a_scale_;
  const c10::optional<at::Tensor> &b_scale_;
  const c10::optional<std::string> &data_type_;
  int64_t num_token_;
  int64_t topk_;
  int64_t block_n_;     // number of slices along N; values < 1 request the heuristic in set_block
  int64_t expert_num_;  // b.size(0) for 3-D b, b.size(0) * b.size(2) otherwise
  int64_t n_;           // b.size(1)
  // Cached has_value() flags of the optional inputs.
  bool has_c_;
  bool has_alpha_;
  bool has_beta_;
  bool has_a_scale_;
  bool has_b_scale_;
  bool has_data_type_;
  int64_t n_per_blk_;  // n_ / block_n_
  cnrtQueue_t queue_;
  at::Tensor output_buff_;    // [block_n_, num_token_, n_per_blk_]
  at::Tensor b_scale_trans_;  // b_scale viewed as [expert_num_, block_n_, -1], then transposed(0, 1)
  c10::optional<at::Tensor> b_offset_ = c10::nullopt;  // CPU int64 per-expert byte offsets into b
  cnnlDataType_t cnnl_dtype_;

  /**
   * Captures the inputs, chooses the N split, allocates the output buffer and pre-computes the
   * per-expert byte offsets that are passed to group_gemm.
   *
   * @param data_type optional output dtype name; must be 'float', 'half' or 'bfloat16' when
   *                  given, otherwise the dtype of `a_tensor` is used.
   * @param block_n   number of N slices; < 1 selects it automatically from a.size(0) and n.
   * @throws c10::Error if `b_tensor` has fewer than 3 dimensions, if `data_type` is not one of
   *         the accepted names, or if n is not divisible by an explicit `block_n`.
   */
  GroupGemmCombineResultSplitN(const at::Tensor &a_tensor,
                               const at::Tensor &b_tensor,
                               const at::Tensor &m_list,
                               const at::Tensor &combine_idx,
                               const at::Tensor &combine_weight,
                               const c10::optional<at::Tensor> &c_tensor,
                               const c10::optional<at::Tensor> &alpha,
                               const c10::optional<at::Tensor> &beta,
                               const c10::optional<at::Tensor> &a_scale,
                               const c10::optional<at::Tensor> &b_scale,
                               const c10::optional<std::string> &data_type,
                               const int64_t num_token,
                               const int64_t topk,
                               const int64_t block_n)
      : a_(a_tensor),
        b_(b_tensor),
        m_list_(m_list),
        combine_idx_(combine_idx),
        combine_weight_(combine_weight),
        c_(c_tensor),
        alpha_(alpha),
        beta_(beta),
        a_scale_(a_scale),
        b_scale_(b_scale),
        data_type_(data_type),
        num_token_(num_token),
        topk_(topk),
        block_n_(block_n) {
    has_c_ = c_.has_value();
    has_alpha_ = alpha_.has_value();
    has_beta_ = beta_.has_value();
    has_a_scale_ = a_scale_.has_value();
    has_b_scale_ = b_scale_.has_value();
    has_data_type_ = data_type_.has_value();

    // b_shape[1] (and b_shape[2] for non 3-D b) is read below and IntArrayRef::operator[] does
    // not bounds-check, so reject lower-rank tensors explicitly.
    TORCH_CHECK(b_.dim() >= 3, "b tensor must have at least 3 dimensions, but got ", b_.dim(),
                ".");
    auto b_shape = b_.sizes();
    expert_num_ = b_.dim() == 3 ? b_shape[0] : b_shape[0] * b_shape[2];
    n_ = b_shape[1];
    auto k = b_.size(-1);
    auto m = a_.size(0);
    set_block(m);
    queue_ = torch_mlu::getCurMLUStream();

    auto a_options = a_.options();
    if (has_data_type_) {
      auto dtype = data_type_.value();
      TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16",
                  "data type must be 'float', 'half' or 'bfloat16'");
      auto torch_dtype = str2TorchDtype(dtype);
      cnnl_dtype_ = str2CnnlDtype(dtype);
      output_buff_ = at::empty({block_n_, num_token_, n_per_blk_}, a_options.dtype(torch_dtype));
    } else {
      output_buff_ = at::empty({block_n_, num_token_, n_per_blk_}, a_options);
      cnnl_dtype_ = getCnnlDataType(a_.scalar_type());
    }

    if (has_b_scale_) {
      // Regroup the scale so that b_scale_trans_[block_id] is the contiguous scale of one slice.
      b_scale_trans_ = b_scale_.value().view({expert_num_, block_n_, -1});
      b_scale_trans_ = b_scale_trans_.transpose(0, 1).contiguous();
    }

    // Per-expert byte offsets into b, kept on the CPU. Written through the raw pointer instead
    // of `tensor[i] = value`, which would dispatch a select + fill for every element.
    auto b_offset = at::empty({expert_num_}, a_options.dtype(at::kLong).device(at::kCPU));
    int64_t *offsets = b_offset.data_ptr<int64_t>();
    const auto element_size = static_cast<int64_t>(b_.element_size());
    if (block_n_ > 1 && b_.dim() == 3) {
      // 3-D b that is narrowed along N: expert i starts n_ * k elements after expert i - 1.
      for (int64_t i = 0; i < expert_num_; i++) {
        offsets[i] = n_ * k * i * element_size;
      }
      b_offset_ = b_offset;
    } else if (b_.dim() == 4) {
      // 4-D b: expert i maps to index (i / b_shape[2]) on dim 0 and (i % b_shape[2]) on dim 2.
      for (int64_t i = 0; i < expert_num_; i++) {
        offsets[i] = (i / b_shape[2] * b_.stride(0) + i % b_shape[2] * b_shape[3]) * element_size;
      }
      b_offset_ = b_offset;
    }
  }

  /** Returns the view a[..., start : start + block, ...] along `dim`. */
  auto split(const at::Tensor &a, int64_t dim, int64_t start, int64_t block) {
    return a.narrow(dim, start, block);
  }

  /**
   * Computes slice `block_id`: runs group_gemm on columns
   * [block_id * n_per_blk_, (block_id + 1) * n_per_blk_) of b (and of c when present), then
   * launches the combine-result kernel writing into output_buff_[block_id].
   *
   * @return the view output_buff_[block_id].
   */
  at::Tensor forward(const int64_t block_id) {
    auto gg_o =
        group_gemm(a_, split(b_, 1, block_id * n_per_blk_, n_per_blk_), m_list_, c10::nullopt,
                   (has_c_) ? split(c_.value(), 1, block_id * n_per_blk_, n_per_blk_) : c_, alpha_,
                   beta_, a_scale_, has_b_scale_ ? b_scale_trans_[block_id] : b_scale_,
                   c10::nullopt, data_type_, c10::nullopt, b_offset_, num_token_);
    tmo::invokeMoeCombineResultKernel(
        queue_, getAtTensorPtr(output_buff_[block_id]), getAtTensorPtr(gg_o), nullptr, nullptr,
        (float *)getAtTensorPtr(combine_weight_), nullptr, (int *)getAtTensorPtr(combine_idx_),
        num_token_, topk_, expert_num_, n_per_blk_, 0, expert_num_, cnnl_dtype_);
    return output_buff_[block_id];
  }

  /**
   * Fixes block_n_ / n_per_blk_.
   *
   * When block_n_ < 1 a heuristic on the row count `m` of a is used:
   *   - 4096 <= m < 8192 and n divisible by 2048 -> slices of 2048 columns
   *   - m >= 8192 and n divisible by 1024        -> slices of 1024 columns
   *   - otherwise                                -> a single slice
   * An explicit block_n_ must divide n.
   */
  void set_block(int64_t m) {
    if (block_n_ < 1) {
      if (m >= 4096 && m < 8192 && n_ % 2048 == 0) {
        n_per_blk_ = 2048;
        block_n_ = n_ / n_per_blk_;
      } else if (m >= 8192 && n_ % 1024 == 0) {
        n_per_blk_ = 1024;
        block_n_ = n_ / n_per_blk_;
      } else {
        n_per_blk_ = n_;
        block_n_ = 1;
      }
    } else {
      TORCH_CHECK(n_ % block_n_ == 0, "n must be divisible by block_n");
      n_per_blk_ = n_ / block_n_;
    }
  }

  /** Number of slices, i.e. how many times forward() has to be called. */
  int64_t getLoopNum() const { return block_n_; }

  /** Reassembles the slices: [block_n, num_token, n_per_blk] -> [num_token, n]. */
  at::Tensor getOutput() const { return output_buff_.transpose(0, 1).reshape({num_token_, n_}); }
};
/**
 * Grouped GEMM + MoE combine-result, executed slice by slice along N and handed to
 * ParallelAllReduce on the communicator identified by `cncl_comm`.
 *
 * All tensor arguments are forwarded unchanged to GroupGemmCombineResultSplitN; `block_n` < 1
 * lets that helper pick the number of slices.
 *
 * @return whatever ParallelAllReduce yields for the split operation.
 */
at::Tensor group_gemm_combine_result_allreduce(int64_t cncl_comm,
                                               const at::Tensor &a_tensor,
                                               const at::Tensor &b_tensor,
                                               const at::Tensor &m_list,
                                               const at::Tensor &combine_idx,
                                               const at::Tensor &combine_weight,
                                               const c10::optional<at::Tensor> &c_tensor,
                                               const c10::optional<at::Tensor> &alpha,
                                               const c10::optional<at::Tensor> &beta,
                                               const c10::optional<at::Tensor> &a_scale,
                                               const c10::optional<at::Tensor> &b_scale,
                                               const c10::optional<std::string> &data_type,
                                               const int64_t num_token,
                                               const int64_t topk,
                                               const int64_t block_n) {
  // The split operation only borrows its inputs; it lives for the duration of this call.
  GroupGemmCombineResultSplitN split_n_op(a_tensor, b_tensor, m_list, combine_idx, combine_weight,
                                          c_tensor, alpha, beta, a_scale, b_scale, data_type,
                                          num_token, topk, block_n);
  ParallelAllReduce<GroupGemmCombineResultSplitN> overlap_runner(cncl_comm);
  return overlap_runner(split_n_op);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,94 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "comm_overlap.h"
#include "torch_api/torch_ops_api.h"
namespace tmo {
namespace torch_api {
using namespace torch::indexing;
/**
 * Row-wise (dimension 0 of `a`) partitioning of a matmul into `block_m_` chunks.
 * forward(block_id) computes one chunk directly into the matching rows of the output tensor.
 *
 * Tensor / optional inputs are held by reference and must outlive this object.
 */
struct MatmulSplitM {
  const at::Tensor &a_;
  const at::Tensor &b_;
  const c10::optional<at::Tensor> &bias_;
  const c10::optional<at::Tensor> &c_;
  const c10::optional<at::Tensor> &d_;
  double alpha_;
  double beta_;
  int64_t block_m_;  // number of row chunks
  bool has_res_;     // c was provided
  bool has_output_;  // d was provided and is used as the output
  at::Tensor output_;

  /**
   * @param d       optional pre-allocated output; when absent an [a.size(0), b.size(0)] tensor
   *                with the options of `a` is allocated.
   * @param block_m requested chunk count. Values larger than a.size(0) collapse to 1; values
   *                < 1 select 8 / 4 / 1 chunks for more than 8192 / at least 2048 / fewer rows.
   */
  MatmulSplitM(const at::Tensor &a,
               const at::Tensor &b,
               const c10::optional<at::Tensor> &bias,
               const c10::optional<at::Tensor> &c,
               const c10::optional<at::Tensor> &d,
               const double alpha,
               const double beta,
               const int64_t block_m)
      : a_(a), b_(b), bias_(bias), c_(c), d_(d), alpha_(alpha), beta_(beta), block_m_(block_m) {
    const int64_t rows = a_.size(0);
    if (block_m_ > rows) {
      block_m_ = 1;
    } else if (block_m_ < 1) {
      if (rows > 8192) {
        block_m_ = 8;
      } else if (rows < 2048) {
        block_m_ = 1;
      } else {
        block_m_ = 4;
      }
    }

    has_res_ = c_.has_value();
    has_output_ = d_.has_value();
    output_ = has_output_ ? d_.value() : at::empty({a.size(0), b.size(0)}, a.options());
  }

  /**
   * Returns the rows of `a` that belong to chunk `block_id`. Rows are distributed as evenly as
   * possible: the first (rows % block_m_) chunks receive one extra row.
   */
  auto split(const at::Tensor &a, const int64_t block_id) {
    const int64_t rows = a.size(0);
    const int64_t base = rows / block_m_;
    const int64_t extra = rows % block_m_;
    // First row of chunk `idx`.
    const auto first_row = [&](const int64_t idx) { return idx * base + std::min(idx, extra); };
    const int64_t begin = first_row(block_id);
    const int64_t finish = first_row(block_id + 1);
    return a.narrow(0, begin, finish - begin);
  }

  /** Runs matmul for chunk `block_id`, writing into (and returning) the output rows. */
  at::Tensor forward(const int64_t block_id) {
    auto out_rows = split(output_, block_id);
    const at::Tensor lhs_rows = split(a_, block_id);
    // The residual is sliced like `a`; an absent residual is forwarded as-is.
    const c10::optional<at::Tensor> residual_rows =
        has_res_ ? split(c_.value(), block_id) : c_;
    matmul(lhs_rows, b_, out_rows, bias_, residual_rows, None, "none", alpha_, beta_, true, true,
           1.0, 1.0, false, true);
    return out_rows;
  }

  /** Number of row chunks. */
  int64_t getLoopNum() const { return block_m_; }

  /** The full output tensor. */
  at::Tensor getOutput() const { return output_; }

  /** Output rows of chunk `block_id`. */
  at::Tensor getDSplit(const int64_t block_id) { return split(output_, block_id); }
};
/**
 * Matmul split into row chunks (see MatmulSplitM) and handed to ParallelAllReduce on the
 * communicator identified by `cncl_comm`.
 *
 * @param d       optional pre-allocated output tensor.
 * @param block_m requested number of row chunks; < 1 selects it automatically.
 * @return whatever ParallelAllReduce yields for the split operation.
 */
at::Tensor matmul_allreduce(const int64_t cncl_comm,
                            const at::Tensor &a,
                            const at::Tensor &b,
                            const c10::optional<at::Tensor> &bias,
                            const c10::optional<at::Tensor> &c,
                            const c10::optional<at::Tensor> &d,
                            const double alpha,
                            const double beta,
                            const int64_t block_m) {
  // The split operation only borrows its inputs; it lives for the duration of this call.
  MatmulSplitM split_m_op(a, b, bias, c, d, alpha, beta, block_m);
  ParallelAllReduce<MatmulSplitM> overlap_runner(cncl_comm);
  return overlap_runner(split_m_op);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,198 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <atomic>
#include <vector>
#include "comm_overlap.h"
#include "torch_api/torch_ops_api.h"
namespace tmo {
namespace torch_api {
using namespace torch::indexing;
/**
 * Row-wise (dimension 0 of `a`) partitioning of a quantized matmul into `block_m_` chunks.
 * forward(block_id) slices `a`, the optional a/c scale and zero tensors, `c` and the output by
 * rows and calls quant_matmul on that chunk.
 *
 * Tensor / optional / string inputs are held by reference and must outlive this object.
 */
struct QuantMatmulSplitM {
  const at::Tensor &a_;
  const c10::optional<at::Tensor> &a_scale_;
  const c10::optional<at::Tensor> &a_zero_;
  const at::Tensor &b_;
  const c10::optional<at::Tensor> &b_scale_;
  const c10::optional<at::Tensor> &b_zero_;
  const c10::optional<at::Tensor> &bias_;
  const c10::optional<at::Tensor> &c_;
  const c10::optional<at::Tensor> &c_scale_;
  const c10::optional<at::Tensor> &c_zero_;
  const c10::optional<at::Tensor> &output_scale_;
  const c10::optional<at::Tensor> &output_zero_;
  const c10::optional<std::string> &data_type_;
  const c10::optional<at::Tensor> &d_;
  const std::string &quant_algo_;
  const std::string &a_quant_layout_;
  const std::string &b_quant_layout_;
  int64_t quant_bit_size_;
  double alpha_;
  double beta_;
  bool trans_a_;
  bool trans_b_;
  int64_t block_m_;  // number of row chunks
  // Cached has_value() flags of the optional inputs.
  bool has_a_scale_;
  bool has_a_zero_;
  bool has_b_scale_;
  bool has_b_zero_;
  bool has_bias_;
  bool has_c_;
  bool has_c_scale_;
  bool has_c_zero_;
  bool has_output_scale_;
  bool has_output_zero_;
  bool has_output_;
  bool has_dtype_;
  at::Tensor output_;

  /**
   * @param data_type optional output dtype name ('float', 'half' or 'bfloat16'); only consulted
   *                  when `d` is absent.
   * @param d         optional pre-allocated output; otherwise an [a.size(0), b.size(0)] tensor
   *                  is allocated with the options of `a` (dtype overridden by `data_type`).
   * @param block_m   requested chunk count. Values larger than a.size(0) collapse to 1; values
   *                  < 1 select 8 / 4 / 1 chunks for more than 8192 / at least 2048 / fewer rows.
   */
  QuantMatmulSplitM(const at::Tensor &a_tensor,
                    const c10::optional<at::Tensor> &a_scale,
                    const c10::optional<at::Tensor> &a_zero,
                    const at::Tensor &b_tensor,
                    const c10::optional<at::Tensor> &b_scale,
                    const c10::optional<at::Tensor> &b_zero,
                    const c10::optional<at::Tensor> &bias,
                    const c10::optional<at::Tensor> &c_tensor,
                    const c10::optional<at::Tensor> &c_scale,
                    const c10::optional<at::Tensor> &c_zero,
                    const c10::optional<at::Tensor> &gemm_output_scale,
                    const c10::optional<at::Tensor> &gemm_output_zero,
                    const c10::optional<std::string> &data_type,
                    const c10::optional<at::Tensor> &d,
                    const std::string &quant_algo,
                    const std::string &a_quant_layout,
                    const std::string &b_quant_layout,
                    int64_t quant_bit_size,
                    double alpha,
                    double beta,
                    bool trans_a,
                    bool trans_b,
                    const int64_t block_m)
      : a_(a_tensor),
        a_scale_(a_scale),
        a_zero_(a_zero),
        b_(b_tensor),
        b_scale_(b_scale),
        b_zero_(b_zero),
        bias_(bias),
        c_(c_tensor),
        c_scale_(c_scale),
        c_zero_(c_zero),
        output_scale_(gemm_output_scale),
        output_zero_(gemm_output_zero),
        data_type_(data_type),
        d_(d),
        quant_algo_(quant_algo),
        a_quant_layout_(a_quant_layout),
        b_quant_layout_(b_quant_layout),
        quant_bit_size_(quant_bit_size),
        alpha_(alpha),
        beta_(beta),
        trans_a_(trans_a),
        trans_b_(trans_b),
        block_m_(block_m) {
    const int64_t rows = a_.size(0);
    if (block_m_ > rows) {
      block_m_ = 1;
    } else if (block_m_ < 1) {
      if (rows > 8192) {
        block_m_ = 8;
      } else if (rows < 2048) {
        block_m_ = 1;
      } else {
        block_m_ = 4;
      }
    }

    has_a_scale_ = a_scale_.has_value();
    has_a_zero_ = a_zero_.has_value();
    has_b_scale_ = b_scale_.has_value();
    has_b_zero_ = b_zero_.has_value();
    has_bias_ = bias_.has_value();
    has_c_ = c_.has_value();
    has_c_scale_ = c_scale_.has_value();
    has_c_zero_ = c_zero_.has_value();
    has_output_scale_ = output_scale_.has_value();
    has_output_zero_ = output_zero_.has_value();
    has_output_ = d_.has_value();
    has_dtype_ = data_type_.has_value();

    if (has_output_) {
      // A caller-provided output wins; data_type is not inspected in this case.
      output_ = d_.value();
    } else {
      auto out_options = a_.options();
      if (has_dtype_) {
        const std::string &dtype = data_type_.value();
        TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16",
                    "data type must be 'float', 'half' or 'bfloat16'");
        out_options = out_options.dtype(str2TorchDtype(dtype));
      }
      output_ = at::empty({a_.size(0), b_.size(0)}, out_options);
    }
  }

  /**
   * Returns the rows of `a` that belong to chunk `block_id`. Rows are distributed as evenly as
   * possible: the first (rows % block_m_) chunks receive one extra row.
   */
  auto split(const at::Tensor &a, const int64_t block_id) {
    const int64_t rows = a.size(0);
    const int64_t base = rows / block_m_;
    const int64_t extra = rows % block_m_;
    // First row of chunk `idx`.
    const auto first_row = [&](const int64_t idx) { return idx * base + std::min(idx, extra); };
    const int64_t begin = first_row(block_id);
    const int64_t finish = first_row(block_id + 1);
    return a.narrow(0, begin, finish - begin);
  }

  /** Runs quant_matmul for chunk `block_id` and returns its result. */
  at::Tensor forward(const int64_t block_id) {
    // Row-slice an optional tensor when it is present; an absent one is forwarded unchanged.
    const auto rows_of = [&](const bool present,
                             const c10::optional<at::Tensor> &opt) -> c10::optional<at::Tensor> {
      if (present) {
        return split(opt.value(), block_id);
      }
      return opt;
    };

    return quant_matmul(split(a_, block_id), rows_of(has_a_scale_, a_scale_),
                        rows_of(has_a_zero_, a_zero_), b_, b_scale_, b_zero_, bias_,
                        rows_of(has_c_, c_), rows_of(has_c_scale_, c_scale_),
                        rows_of(has_c_zero_, c_zero_), output_scale_, output_zero_, data_type_,
                        split(output_, block_id), quant_algo_, a_quant_layout_, b_quant_layout_,
                        quant_bit_size_, "none", false, 1.0, alpha_, beta_, trans_a_, trans_b_);
  }

  /** Number of row chunks. */
  int64_t getLoopNum() const { return block_m_; }

  /** The full output tensor. */
  at::Tensor getOutput() const { return output_; }
};
/**
 * Quantized matmul split into row chunks (see QuantMatmulSplitM) and handed to
 * ParallelAllReduce on the communicator identified by `cncl_comm`.
 *
 * Only the layout trans_a == false, trans_b == true is accepted.
 *
 * @param d       optional pre-allocated output tensor.
 * @param block_m requested number of row chunks; < 1 selects it automatically.
 * @return whatever ParallelAllReduce yields for the split operation.
 * @throws c10::Error when trans_a is true or trans_b is false.
 */
at::Tensor quant_matmul_allreduce(const int64_t cncl_comm,
                                  const at::Tensor &a_tensor,
                                  const c10::optional<at::Tensor> &a_scale,
                                  const c10::optional<at::Tensor> &a_zero,
                                  const at::Tensor &b_tensor,
                                  const c10::optional<at::Tensor> &b_scale,
                                  const c10::optional<at::Tensor> &b_zero,
                                  const c10::optional<at::Tensor> &bias,
                                  const c10::optional<at::Tensor> &c_tensor,
                                  const c10::optional<at::Tensor> &c_scale,
                                  const c10::optional<at::Tensor> &c_zero,
                                  const c10::optional<at::Tensor> &gemm_output_scale,
                                  const c10::optional<at::Tensor> &gemm_output_zero,
                                  const c10::optional<std::string> &data_type,
                                  const c10::optional<at::Tensor> &d,
                                  const std::string &quant_algo,
                                  const std::string &a_quant_layout,
                                  const std::string &b_quant_layout,
                                  int64_t quant_bit_size,
                                  double alpha,
                                  double beta,
                                  bool trans_a,
                                  bool trans_b,
                                  const int64_t block_m) {
  TORCH_CHECK(!trans_a && trans_b, "trans_a must be false and trans_b must be true");

  // The split operation only borrows its inputs; it lives for the duration of this call.
  QuantMatmulSplitM split_m_op(a_tensor, a_scale, a_zero, b_tensor, b_scale, b_zero, bias,
                               c_tensor, c_scale, c_zero, gemm_output_scale, gemm_output_zero,
                               data_type, d, quant_algo, a_quant_layout, b_quant_layout,
                               quant_bit_size, alpha, beta, trans_a, trans_b, block_m);
  ParallelAllReduce<QuantMatmulSplitM> overlap_runner(cncl_comm);
  return overlap_runner(split_m_op);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,89 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/copy_blocks.mluh"
#include <map>
#include <vector>
#include "torch_ops_api.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
/**
 * Copies cache blocks inside every layer's key cache (and value cache, when given) according
 * to `block_mapping_dict`.
 *
 * @param k_caches           one 4-D key-cache tensor per layer; must not be empty. Device and
 *                           dtype are validated on the first layer.
 * @param v_caches           either empty, or one value-cache tensor per layer with the same
 *                           dtype, rank and per-block element count as the key caches.
 * @param block_mapping_dict source block number -> list of destination block numbers; must
 *                           yield at least one (source, destination) pair.
 * @throws c10::Error when any of the checks below fails.
 */
void copy_blocks(const std::vector<torch::Tensor> &k_caches,
                 const std::vector<torch::Tensor> &v_caches,
                 const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping_dict) {
  // Create block mapping array: flattened as [src0, dst0, src1, dst1, ...].
  std::vector<int32_t> block_mapping_vec;
  for (const auto &item : block_mapping_dict) {
    int64_t src_block_number = item.key();
    auto value_vec = item.value().vec();
    for (int64_t dst_block_number : value_vec) {
      block_mapping_vec.push_back(static_cast<int32_t>(src_block_number));
      block_mapping_vec.push_back(static_cast<int32_t>(dst_block_number));
    }
  }

  TORCH_CHECK(!k_caches.empty() && !block_mapping_vec.empty(),
              "k_caches and block_mapping_vec can not be empty.")
  int32_t num_layers = k_caches.size();
  for (auto i = 0; i < num_layers; i++) {
    TORCH_CHECK(k_caches[i].dim() == 4, "every layer k_cache must be 4d.")
  }

  // check same device and tensor type
  TORCH_CHECK(isMlu(k_caches[0]), "k_caches must on mlu.");
  TORCH_CHECK(k_caches[0].dtype() == torch::kInt8 || k_caches[0].dtype() == torch::kUInt8 ||
                  k_caches[0].dtype() == torch::kInt16 || k_caches[0].dtype() == torch::kInt32 ||
                  k_caches[0].dtype() == torch::kLong || k_caches[0].dtype() == torch::kFloat16 ||
                  k_caches[0].dtype() == torch::kFloat32 || k_caches[0].dtype() == torch::kBFloat16,
              "data type only supports torch::kInt8, torch::kUInt8, torch::kInt16, torch::kInt32, "
              "torch::kLong, torch::kFloat16, torch::kFloat32 and torch::kBFloat16");

  if (!v_caches.empty()) {
    TORCH_CHECK(k_caches.size() == v_caches.size(),
                "k_caches size must equal to "
                "v_caches size if v_caches is not none.")
    TORCH_CHECK(isMlu(v_caches[0]), "v_caches must on mlu.");
    for (auto i = 0; i < num_layers; i++) {
      TORCH_CHECK(k_caches[i].dtype() == v_caches[i].dtype(),
                  "the data type of k_caches and v_caches are not the same.")
      TORCH_CHECK(k_caches[i].dim() == v_caches[i].dim(),
                  "every layer k_cache dim must equal to v_cache dim.")
      // check shape: a single block_size_bytes is passed to the kernel for every layer, so the
      // key AND the value cache of each layer must agree on the per-block element count.
      // Comparing only against v_caches[0] would leave v_caches[i] (i > 0) unchecked.
      const auto ref_block_numel = v_caches[0][0].numel();
      TORCH_CHECK(k_caches[i][0].numel() == ref_block_numel &&
                      v_caches[i][0].numel() == ref_block_numel,
                  "the block_size of k_caches and v_caches are not the same.")
    }
  }

  const torch_mlu::mlu::MLUGuard device_guard(k_caches[0].device());
  auto queue = torch_mlu::getCurMLUStream();
  // Size in bytes of one block (one slice along dim 0) of the first key cache.
  size_t block_size_bytes = k_caches[0][0].numel() * k_caches[0].element_size();

  // Raw base pointers of every layer; new_value_caches stays empty when v_caches is empty.
  std::vector<void *> new_key_caches, new_value_caches;
  new_key_caches.reserve(k_caches.size());
  new_value_caches.reserve(v_caches.size());
  for (auto i = 0; i < num_layers; ++i) {
    new_key_caches.push_back(k_caches[i].data_ptr());
    if (!v_caches.empty()) {
      new_value_caches.push_back(v_caches[i].data_ptr());
    }
  }

  TMO_KERNEL_CHECK_FATAL(invokeCopyBlocksKernel(queue, new_key_caches, new_value_caches,
                                                block_mapping_vec, block_size_bytes));
}
/**
 * Variant of copy_blocks that also returns its cache arguments.
 *
 * The copy itself is performed by copy_blocks on the given tensors; the returned tuple holds
 * copies of the two input vectors (key caches first, value caches second).
 */
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> copy_blocks_out_of_place(
    const std::vector<at::Tensor> &k_caches,
    const std::vector<at::Tensor> &v_caches,
    const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping) {
  using CacheLists = std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>>;

  copy_blocks(k_caches, v_caches, block_mapping);
  return CacheLists(k_caches, v_caches);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,176 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/dequant_from_linear_cache.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Raises via TORCH_CHECK unless tensor `x` has scalar type `expected_type`. The error message
// contains the spelled-out argument expression (#x) and the expected dtype formatted with
// torchDtype2Str. Note that the expansion already ends with a ';'.
#define CHECK_TENSOR_DTYPE(x, expected_type) \
TORCH_CHECK(x.scalar_type() == expected_type, "Tensor " #x " type should be ", \
torchDtype2Str(expected_type), ".");
// Validates its arguments and launches invokeDequantFromLinearCache on the current MLU stream.
//
// Requirements enforced below:
//   - key: 3-D [total_seqlen, head_num, head_size], half or bfloat16, last dim contiguous.
//   - key_cache: contiguous int8, [max_batch_size, head_num, cache_mem_len, head_size];
//     the last dim is head_size / 2 when quant_bit is 4.
//   - key_quant_scale: contiguous float32, [head_num, head_size] for quant_mode 0 and
//     [max_batch_size, head_num, cache_mem_len] for quant_mode 1.
//   - value / value_cache / value_quant_scale: all present or all absent. They must match the
//     key tensors in dtype; value must also match key's shape and strides. For quant_bit 4 the
//     value cache is [max_batch_size, head_num, cache_mem_len / 2, head_size].
//   - context_lengths: 1-D int32 of length batch_size, batch_size <= max_batch_size.
//   - context_seq_offset / cache_bs_id / cache_seq_offset: optional, shape [batch_size], same
//     dtype as context_lengths. Without context_seq_offset, batch_size must be <= 1024.
//   - quant_mode in {0, 1}, quant_bit in {4, 8}, cache_mem_len even.
//
// NOTE(review): key/value are presumably the dequantized outputs written by the kernel -
// confirm against kernels/dequant_from_linear_cache.mluh.
void dequant_from_linear_cache(
at::Tensor &key, // [total_seqlen, head_num, head_size]
const c10::optional<at::Tensor> &value, // same as above
const at::Tensor &key_cache, // [max_batch_size, head_num, cache_mem_len, head_size]
const c10::optional<at::Tensor> &value_cache, // same as above
const at::Tensor &key_quant_scale, // quant_mode is 0: [head_num, head_size]
// quant_mode is 1:
// [max_batch_size, head_num, cache_mem_len]
const c10::optional<at::Tensor> &value_quant_scale, // same as above
const at::Tensor &context_lengths,
const int64_t max_context_len,
const c10::optional<at::Tensor> &context_seq_offset,
const c10::optional<at::Tensor> &cache_bs_id,
const c10::optional<at::Tensor> &cache_seq_offset,
const int64_t quant_mode, // 0:per_channel, 1:per_head
const int64_t quant_bit) { // 4 or 8
// check same attr for tensors
checkTensorSameAttr<TensorAttr::DEVICE>(key, key_cache, key_quant_scale, value, value_cache,
value_quant_scale, context_lengths, context_seq_offset,
cache_bs_id, cache_seq_offset);
checkTensorSameAttr<TensorAttr::DTYPE>(context_lengths, context_seq_offset, cache_bs_id,
cache_seq_offset);
// check quant parameters first
TORCH_CHECK(quant_mode >= 0 && quant_mode <= 1, "quantization mode support 0 and 1.");
TORCH_CHECK(quant_bit == 4 || quant_bit == 8, "quantization bit width support 4 and 8.");
/****************************************check key***************************************/
// check dtype
TORCH_CHECK(key.scalar_type() == torch::kFloat16 || key.scalar_type() == torch::kBFloat16,
"Tensor key type should be half or bfloat16.");
CHECK_TENSOR_DTYPE(key_cache, torch::kInt8);
CHECK_TENSOR_DTYPE(key_quant_scale, torch::kFloat32);
// check_contiguous
TORCH_CHECK(key.stride(-1) == 1, "Tensor key last dim must be contiguous.");
CHECK_TENSOR_CONTIGUOUS(key_cache);
CHECK_TENSOR_CONTIGUOUS(key_quant_scale);
// check shape
TORCH_CHECK(key.dim() == 3, "The dimensions of tensor key only support 3.");
TORCH_CHECK(key_cache.dim() == 4, "The dimensions of tensor key_cache only supports 4.");
TORCH_CHECK(context_lengths.dim() == 1,
"The dimensions of tensor context_lengths only supports 1.");
// Sizes are narrowed to int32_t here; the kernel is invoked with these 32-bit values.
const int32_t total_seqlen = key.size(0);
const int32_t head_num = key.size(1);
const int32_t head_size = key.size(2);
const int32_t max_batch_size = key_cache.size(0);
const int32_t cache_mem_len = key_cache.size(2);
const int32_t batch_size = context_lengths.size(0);
CHECK_SHAPE(context_lengths, batch_size);
TORCH_CHECK(max_batch_size >= batch_size,
"max_batch_size should be greater than or equal to batch_size.");
TORCH_CHECK(cache_mem_len % 2 == 0, "cache_mem_len should be a multiply of 2.");
if (quant_mode == 0) {
CHECK_SHAPE(key_quant_scale, head_num, head_size);
} else if (quant_mode == 1) {
CHECK_SHAPE(key_quant_scale, max_batch_size, head_num, cache_mem_len);
}
// For 4-bit quantization the key cache's last dimension is half of head_size.
if (quant_bit == 4) {
TORCH_CHECK(head_size % 2 == 0, "head_size should be a multiply of 2 if quant_bit is 4.");
CHECK_SHAPE(key_cache, max_batch_size, head_num, cache_mem_len, head_size >> 1);
} else {
CHECK_SHAPE(key_cache, max_batch_size, head_num, cache_mem_len, head_size);
}
/***************************************check value***************************************/
// The three value-side tensors are optional as a group: all present or all absent.
if (value.has_value() || value_cache.has_value() || value_quant_scale.has_value()) {
TORCH_CHECK(value.has_value() && value_cache.has_value() && value_quant_scale.has_value(),
"value, value_cache, and value_quant_scale must all exists.")
}
if (value_cache.has_value()) {
checkTensorSameAttr<TensorAttr::DTYPE>(key, value);
checkTensorSameAttr<TensorAttr::DTYPE>(key_cache, value_cache);
checkTensorSameAttr<TensorAttr::DTYPE>(key_quant_scale, value_quant_scale);
CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache);
CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_quant_scale);
TORCH_CHECK(key_quant_scale.dim() == value_quant_scale.value().dim(),
"value_cache_quant_scale dim should keep same with key_cache_quant_scale.");
CHECK_SHAPE(value.value(), total_seqlen, head_num, head_size);
if (quant_mode == 0) {
CHECK_SHAPE(value_quant_scale.value(), head_num, head_size);
} else {
CHECK_SHAPE(value_quant_scale.value(), max_batch_size, head_num, cache_mem_len);
}
// Unlike the key cache, the 4-bit value cache is halved along cache_mem_len (dim 2).
if (quant_bit == 4) {
CHECK_SHAPE(value_cache.value(), max_batch_size, head_num, cache_mem_len >> 1, head_size);
} else {
CHECK_SHAPE(value_cache.value(), max_batch_size, head_num, cache_mem_len, head_size);
}
// Only key's strides are handed to the kernel, so value has to share them.
for (int i = 0; i < key.dim(); i++) {
TORCH_CHECK(value.value().stride(i) == key.stride(i),
"key and value must have same stride along axi ", i, ".");
}
}
TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32,
"context_lengths type need be torch::kInt32.");
/*********************************check optional tensor***********************************/
// Absent optional index tensors are passed to the kernel as nullptr.
const void *context_seq_offset_ptr = nullptr;
if (context_seq_offset.has_value()) {
CHECK_SHAPE(context_seq_offset.value(), batch_size);
context_seq_offset_ptr = context_seq_offset.value().data_ptr();
} else {
TORCH_CHECK(batch_size <= 1024,
"batch_size greater than 1024 not support when context_seq_offset is None.");
}
const void *cache_bs_id_ptr = nullptr;
if (cache_bs_id.has_value()) {
CHECK_SHAPE(cache_bs_id.value(), batch_size);
cache_bs_id_ptr = cache_bs_id.value().data_ptr();
}
const void *cache_seq_offset_ptr = nullptr;
if (cache_seq_offset.has_value()) {
CHECK_SHAPE(cache_seq_offset.value(), batch_size);
cache_seq_offset_ptr = cache_seq_offset.value().data_ptr();
}
// prepare parameters before calling invokeDequantFromLinearCache
const int32_t key_group_num = 1;
const int32_t value_group_num = 1;
const size_t context_head_stride = key.stride(1);
const size_t context_seq_stride = key.stride(0);
const size_t cache_bs_stride = key_cache.stride(0);
const size_t cache_head_stride = key_cache.stride(1);
const size_t key_cache_seq_stride = key_cache.stride(2);
const size_t value_cache_seq_stride = value_cache.has_value() ? value_cache.value().stride(2) : 0;
// A 2-D scale (quant_mode 0) has no batch dimension: batch stride 0, head stride = stride(0).
const size_t cache_scale_bs_stride = key_quant_scale.dim() == 3 ? key_quant_scale.stride(0) : 0;
const size_t cache_scale_head_stride =
key_quant_scale.dim() == 3 ? key_quant_scale.stride(1) : key_quant_scale.stride(0);
const torch_mlu::mlu::MLUGuard device_guard(key.device());
auto data_dtype = getCnnlDataType(key.scalar_type());
auto queue = torch_mlu::getCurMLUStream();
// run forward
TMO_KERNEL_CHECK_FATAL(invokeDequantFromLinearCache(
queue, getAtTensorPtr(key), getAtTensorPtr(value), getAtTensorPtr(key_cache),
getAtTensorPtr(value_cache), getAtTensorPtr(key_quant_scale),
getAtTensorPtr(value_quant_scale), getAtTensorPtr(context_lengths), context_seq_offset_ptr,
cache_bs_id_ptr, cache_seq_offset_ptr, (int)max_context_len, batch_size, head_num,
key_group_num, value_group_num, cache_mem_len, head_size, quant_mode, quant_bit,
context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride,
key_cache_seq_stride, value_cache_seq_stride, cache_scale_bs_stride, cache_scale_head_stride,
data_dtype));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,162 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/dequant_from_paged_cache.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Raises via TORCH_CHECK unless tensor `x` has scalar type `expected_type`. The error message
// contains the spelled-out argument expression (#x) and the expected dtype formatted with
// torchDtype2Str. Note that the expansion already ends with a ';'.
#define CHECK_TENSOR_DTYPE(x, expected_type) \
TORCH_CHECK(x.scalar_type() == expected_type, "Tensor " #x " type should be ", \
torchDtype2Str(expected_type), ".");
// Validates its arguments and launches invokeDequantFromPagedCache on the current MLU stream.
//
// Requirements enforced below:
//   - key: 3-D [token_num, head_num, head_size], half or bfloat16, last dim contiguous.
//   - key_cache: contiguous int8, [block_num, head_num, block_size, head_size], and its total
//     byte size must not exceed UINT32_MAX.
//   - key_cache_quant_scale: contiguous float32, [head_num, head_size] for quant_mode 0 and
//     [block_num, head_num, block_size] for quant_mode 1.
//   - value / value_cache / value_cache_quant_scale: all present or all absent. They must match
//     the key tensors in dtype and shape; value must also share key's strides.
//   - context_lengths: 1-D [batch_size]; block_tables: contiguous int32 [batch_size,
//     max_block_num]; context_seq_offset: optional [batch_size]. All three share one dtype.
//     Without context_seq_offset, batch_size must be <= 1024.
//   - max_context_len <= block_size * max_block_num, quant_mode in {0, 1}, quant_bit == 8.
//
// NOTE(review): key/value are presumably the dequantized outputs written by the kernel -
// confirm against kernels/dequant_from_paged_cache.mluh.
void dequant_from_paged_cache(at::Tensor &key, // [total_seqlen, head_num, head_size]
const c10::optional<at::Tensor> &value, // same as above
const at::Tensor &key_cache, // [block_num, head_num, block_size, head_size]
const c10::optional<at::Tensor> &value_cache, // same as above
// [block_num, head_num, block_size] for per-token
// [head_num, head_size] for per-channel
const at::Tensor &key_cache_quant_scale,
const c10::optional<at::Tensor> &value_cache_quant_scale,
const at::Tensor &context_lengths, // [batch_size]
int64_t max_context_len,
const c10::optional<at::Tensor> &context_seq_offset, // [batch_size]
const at::Tensor &block_tables, // [batch_size, max_block_num]
int64_t quant_mode, // 0 is per-channel, 1 is per-token
int64_t quant_bit) { // quantization bit only support 8
// check same attr for tensors
checkTensorSameAttr<TensorAttr::DEVICE>(key, key_cache, key_cache_quant_scale, value, value_cache,
value_cache_quant_scale, context_lengths,
context_seq_offset, block_tables);
checkTensorSameAttr<TensorAttr::DTYPE>(context_lengths, context_seq_offset, block_tables);
// check quant parameters first
TORCH_CHECK(quant_bit == 8, "quantization bit width only supports 8.");
TORCH_CHECK(quant_mode >= 0 && quant_mode <= 1, "quantization mode support 0 and 1.");
/***************************************check key***************************************/
// check dtype
TORCH_CHECK(key.scalar_type() == torch::kFloat16 || key.scalar_type() == torch::kBFloat16,
"Tensor key type should be half or bfloat16.");
CHECK_TENSOR_DTYPE(key_cache, torch::kInt8);
CHECK_TENSOR_DTYPE(key_cache_quant_scale, torch::kFloat32);
CHECK_TENSOR_DTYPE(block_tables, torch::kInt32);
// check contiguous
TORCH_CHECK(key.stride(-1) == 1, "Tensor key last dim must be contiguous.")
CHECK_TENSOR_CONTIGUOUS(key_cache);
CHECK_TENSOR_CONTIGUOUS(key_cache_quant_scale);
CHECK_TENSOR_CONTIGUOUS(block_tables);
// check shape
TORCH_CHECK(key.dim() == 3, "The dimensions of tensor key only supports 3.");
TORCH_CHECK(key_cache.dim() == 4, "The dimensions of tensor key_cache only supports 4.");
TORCH_CHECK(context_lengths.dim() == 1,
"The dimensions of tensor context_lengths only supports 1.");
TORCH_CHECK(block_tables.dim() == 2, "The dimensions of tensor block_tables only supports 2.");
// Sizes are narrowed to int32_t here; the kernel is invoked with these 32-bit values.
const int32_t token_num = key.size(0);
const int32_t head_num = key.size(1);
const int32_t head_size = key.size(2);
const int32_t block_num = key_cache.size(0);
const int32_t block_size = key_cache.size(2);
const int32_t batch_size = context_lengths.size(0);
const int32_t max_block_num = block_tables.size(1);
CHECK_SHAPE(key_cache, block_num, head_num, block_size, head_size);
// Total byte size of the key cache, computed in 64 bits to avoid overflow in the product.
int64_t kv_cache_range =
(int64_t)block_num * head_num * block_size * head_size * key_cache.element_size();
// kernel use uint32_t to calculate offsets
TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of key_cache cannot exceed 4GB.");
if (quant_mode == 0) {
CHECK_SHAPE(key_cache_quant_scale, head_num, head_size);
} else if (quant_mode == 1) {
CHECK_SHAPE(key_cache_quant_scale, block_num, head_num, block_size);
}
/***************************************check value***************************************/
// The three value-side tensors are optional as a group: all present or all absent.
if (value.has_value() || value_cache.has_value() || value_cache_quant_scale.has_value()) {
TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_quant_scale.has_value(),
"value, value_cache, and value_cache_quant_scale must all exists.")
}
if (value_cache.has_value()) {
checkTensorSameAttr<TensorAttr::DTYPE>(key, value);
checkTensorSameAttr<TensorAttr::DTYPE>(key_cache, value_cache);
checkTensorSameAttr<TensorAttr::DTYPE>(key_cache_quant_scale, value_cache_quant_scale);
TORCH_CHECK(value.value().stride(-1) == 1, "Tensor value last dim must be contiguous.")
CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache);
CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache_quant_scale);
CHECK_SHAPE(value.value(), token_num, head_num, head_size);
TORCH_CHECK(value_cache_quant_scale.value().dim() == key_cache_quant_scale.dim(),
"value_cache_quant_scale dim should keep same with key_cache_quant_scale.");
if (quant_mode == 0) {
CHECK_SHAPE(value_cache_quant_scale.value(), head_num, head_size);
} else {
CHECK_SHAPE(value_cache_quant_scale.value(), block_num, head_num, block_size);
}
CHECK_SHAPE(value_cache.value(), block_num, head_num, block_size, head_size);
// Only key's strides are handed to the kernel, so value has to share them.
for (int i = 0; i < key.dim(); i++) {
TORCH_CHECK(value.value().stride(i) == key.stride(i),
"key and value must have same stride along axi ", i, ".");
}
}
/*********************************check index tensor***********************************/
CHECK_SHAPE(context_lengths, batch_size);
CHECK_SHAPE(block_tables, batch_size, max_block_num);
TORCH_CHECK(max_context_len <= block_size * max_block_num,
"max_context_len should smaller than or equal to block_size * max_block_num.");
/*********************************check optional tensor***********************************/
// An absent context_seq_offset is passed to the kernel as nullptr.
const void *context_seq_offset_ptr = nullptr;
if (context_seq_offset.has_value()) {
CHECK_SHAPE(context_seq_offset.value(), batch_size);
context_seq_offset_ptr = context_seq_offset.value().data_ptr();
} else {
TORCH_CHECK(batch_size <= 1024,
"batch_size greater than 1024 not support when context_seq_offset is None.");
}
// prepare parameters before calling invokeDequantFromPagedCache
const int32_t key_group_num = 1;
const int32_t value_group_num = 1;
const size_t context_head_stride = key.stride(1);
const size_t context_seq_stride = key.stride(0);
const size_t cache_bs_stride = key_cache.stride(0);
const size_t cache_head_stride = key_cache.stride(1);
const size_t key_cache_seq_stride = key_cache.stride(2);
const size_t value_cache_seq_stride = value_cache.has_value() ? value_cache.value().stride(2) : 0;
// A 2-D scale (quant_mode 0) has no block dimension: block stride 0, head stride = stride(0).
const size_t cache_scale_bs_stride =
key_cache_quant_scale.dim() == 3 ? key_cache_quant_scale.stride(0) : 0;
const size_t cache_scale_head_stride = key_cache_quant_scale.dim() == 3
? key_cache_quant_scale.stride(1)
: key_cache_quant_scale.stride(0);
const torch_mlu::mlu::MLUGuard device_guard(key.device());
auto queue = torch_mlu::getCurMLUStream();
auto data_dtype = getCnnlDataType(key.scalar_type());
// run forward
TMO_KERNEL_CHECK_FATAL(invokeDequantFromPagedCache(
queue, getAtTensorPtr(key), getAtTensorPtr(value), getAtTensorPtr(key_cache),
getAtTensorPtr(value_cache), getAtTensorPtr(key_cache_quant_scale),
getAtTensorPtr(value_cache_quant_scale), getAtTensorPtr(context_lengths),
context_seq_offset_ptr, getAtTensorPtr(block_tables), (int)max_context_len, max_block_num,
batch_size, head_num, key_group_num, value_group_num, block_size, head_size, quant_mode,
quant_bit, context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride,
key_cache_seq_stride, value_cache_seq_stride, cache_scale_bs_stride, cache_scale_head_stride,
data_dtype));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,156 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common/utils.h"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
/*
torch_ffn_api API function in pytorch op is:
class FeedForward(torch.nn.Module):
def __init__(self, hidden_size: int, inner_size: int, act_mode: str,
bias = False, gated = False):
super(FeedForward, self).__init__()
self.up_linear = torch.nn.Linear(hidden_size, inner_size, bias)
self.gated = gated
if self.gated:
self.gated_linear = torch.nn.Linear(hidden_size, inner_size, bias)
self.down_linear = torch.nn.Linear(inner_size, hidden_size, bias)
self.act = act_mode_dict[act_mode]
def forward(self, x):
act_out = self.act(self.up_linear(x).float()).to(x.dtype)
return self.down_linear(act_out * self.gated_linear(x)) \
if self.gated else self.down_linear(act_out)
Demo of pytorch python module in TGI:
class LlamaBtMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.act = config.hidden_act
self.gate_proj_weight = weights.get_multi_weights_col(
f"{prefix}.gate_proj", quantize=config.quantize, dim=dim
)
self.gate_proj_bias = None
self.up_proj_weight = weights.get_multi_weights_col(
f"{prefix}.up_proj", quantize=config.quantize, dim=dim
)
self.up_proj_bias = None
self.down_proj_weight = weights.get_multi_weights_col(
f"{prefix}.down_proj", quantize=config.quantize, dim=dim
)
self.down_proj_bias = None
def forward(self, hidden_states):
return tmo.ffn(hidden_states, self.up_proj_weight, self.up_proj_bias,
self.down_proj_weight, self.down_proj_bias, self.gate_proj_weight,
self.gate_proj_bias, self.act)
*/
// act_mode currently supports only silu, gelu and relu.
// std::string pytorch func
// silu --> nn.SiLU
// gelu --> nn.functional.gelu
// relu --> nn.ReLU
// Maybe aligned with transformer act mode later.
// https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L200
// Dimensions of ffn .
// Dimension of input: [batch_num, seq_len, hidden_size]
// Dimension of up_fc_filters: [filter_size, hidden_size]
// Dimension of up_fc_bias: [filter_size]
// Dimension of down_fc_filters: [hidden_size, filter_size]
// Dimension of down_fc_bias: [hidden_size]
// Dimension of gated_fc_filters: [filter_size, hidden_size] only for gated ffn
// Dimension of gated_fc_bias: [filter_size] only for gated ffn
// Dimension of layer norm weight: [seq_len, hidden_size] only for layer norm fused
// Dimension of layer norm bias: [seq_len, hidden_size] only for layer norm fused
// Dimension of output is same as input.
// Dimension of \b output: [batch_num, seq_len, hidden_size]
// Fused layer norm is not supported now, and only fused
// per layer norm will be supported later.
// Runs the fused feed-forward operator (cnnlTransformerFeedForward) on `input` and returns a
// newly allocated tensor with the same shape as `input`.
//
// A 2-D input is temporarily given a leading dimension of size 1 and the result is squeezed
// back before it is returned. `residual_is` must be "input", "normed_input" or "none"; together
// with the presence of `layernorm_weight` it selects the layernorm/residual mode. The gate
// weight and bias are registered on the descriptor only when `gate_up_proj_weight` holds a
// defined tensor.
at::Tensor ffn(const at::Tensor &input,
               const at::Tensor &up_fc_weight,
               const c10::optional<at::Tensor> &up_fc_bias,
               const at::Tensor &down_proj_weight,
               const c10::optional<at::Tensor> &down_proj_bias,
               const c10::optional<at::Tensor> &gate_up_proj_weight,
               const c10::optional<at::Tensor> &gate_up_proj_bias,
               const c10::optional<at::Tensor> &layernorm_weight,
               const c10::optional<at::Tensor> &layernorm_bias,
               const std::string &act_mode,
               const std::string &residual_is,
               double eps,
               double alpha,
               double beta) {
  // Every tensor argument must agree on the attributes covered by TensorAttr::ALL.
  checkTensorSameAttr<TensorAttr::ALL>(input, up_fc_weight, up_fc_bias, down_proj_weight,
                                       down_proj_bias, gate_up_proj_weight, gate_up_proj_bias,
                                       layernorm_weight, layernorm_bias);

  // Promote a 2-D input to 3-D so the rest of the function deals with one layout only.
  const bool squeeze_result = input.dim() == 2;
  at::Tensor input_view = squeeze_result ? input.unsqueeze(0) : input;
  CHECK_TENSOR_CONTIGUOUS(input_view)

  const torch_mlu::mlu::MLUGuard device_guard(input_view.device());
  at::Tensor output = at::empty_like(input_view);

  // Descriptor slots: 0 input, 1 up weight, 2 up bias, 3 down weight, 4 down bias,
  // 5 gate weight, 6 gate bias, 7 layernorm weight, 8 layernorm bias, 9 output.
  // Absent optionals are represented by a default-constructed tensor.
  auto descs = createTensorDescs(
      {input_view, up_fc_weight, up_fc_bias.value_or(torch::Tensor()), down_proj_weight,
       down_proj_bias.value_or(torch::Tensor()), gate_up_proj_weight.value_or(torch::Tensor()),
       gate_up_proj_bias.value_or(torch::Tensor()), layernorm_weight.value_or(torch::Tensor()),
       layernorm_bias.value_or(torch::Tensor()), output});

  TORCH_CHECK(residual_is == "input" || residual_is == "normed_input" || residual_is == "none",
              "residual_is must be 'input' or 'normed_input' or 'none'.")

  // Build the operator descriptor from the layernorm/residual selection and scalar parameters.
  const bool with_layernorm = layernorm_weight.has_value();
  const bool with_residual = residual_is != "none";
  const bool residual_from_input = residual_is == "input";
  auto lnres_mode = tmo::lnres::makeLnresEnum(with_layernorm, with_residual, residual_from_input);
  auto cnnl_dtype = getCnnlDataType(input_view.scalar_type());
  tmo::op_desc::FeedForwardDesc ffn_desc(lnres_mode, act_mode, cnnl_dtype, eps, alpha, beta);

  // Gated variant: attach the gate weight/bias when a defined gate weight was supplied.
  const bool has_gate = gate_up_proj_weight.has_value() && gate_up_proj_weight->defined();
  if (has_gate) {
    CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptorGateFiltersBias(
        ffn_desc, descs[5].get(), getAtTensorPtr(gate_up_proj_weight), descs[6].get(),
        getAtTensorPtr(gate_up_proj_bias)));
  }

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();

  // Ask the library how much scratch memory it needs and allocate it as a byte tensor.
  size_t workspace_size = 0;
  CNNL_CHECK_FATAL(cnnlGetTransformerFeedForwardWorkspaceSize_v2(
      handle, ffn_desc, descs[0].get(), descs[1].get(), nullptr, ffn_desc, &workspace_size));
  auto workspace =
      at::empty({static_cast<int64_t>(workspace_size)}, input.options().dtype(at::kByte));

  // Launch the fused operator.
  CNNL_CHECK_FATAL(cnnlTransformerFeedForward(
      handle, ffn_desc, ffn_desc, nullptr, descs[0].get(), getAtTensorPtr(input_view),
      descs[1].get(), getAtTensorPtr(up_fc_weight), descs[2].get(), getAtTensorPtr(up_fc_bias),
      descs[3].get(), getAtTensorPtr(down_proj_weight), descs[4].get(),
      getAtTensorPtr(down_proj_bias), descs[7].get(), getAtTensorPtr(layernorm_weight),
      descs[8].get(), getAtTensorPtr(layernorm_bias), getAtTensorPtr(workspace), workspace_size,
      descs[9].get(), getAtTensorPtr(output)));

  // Drop the leading dimension that was added for a 2-D input.
  if (squeeze_result) {
    output.squeeze_(0);
  }
  return output;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,117 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches cnnlScaledDotProductAttn_v3 on the current MLU handle.
//
// q, k, v, output : must pass checkTensorSameAttr<TensorAttr::ALL>. q must be 3-D or 4-D; the
//                   3-D form is the packed layout.
// output_lse      : optional tensor forwarded to the kernel together with `return_lse`.
// cu_seq_lens_q   : when present the call is treated as packed and batch = size(0) - 1;
//                   otherwise batch = q.size(0).
// cu_seq_lens_kv  : must be present with shape [batch + 1] when k is 3-D without a block table,
//                   or when block_tables has more than one block per sequence.
// alibi_slope, attn_bias : optional tensors forwarded to the kernel.
// k_cache_quant_scale, v_cache_quant_scale : reserved, must be absent.
// block_tables    : optional 2-D tensor; when present k and v must be 4-D and are checked as
//                   [num_blocks, k_head_num, block_size, head_size].
// compute_dtype   : "float", "half" or "bfloat16".
// The remaining scalars are forwarded to the kernel unchanged.
void flash_attention(const at::Tensor &q,
                     const at::Tensor &k,
                     const at::Tensor &v,
                     const at::Tensor &output,
                     const c10::optional<at::Tensor> &output_lse,
                     const c10::optional<at::Tensor> &cu_seq_lens_q,
                     const c10::optional<at::Tensor> &cu_seq_lens_kv,
                     const c10::optional<at::Tensor> &alibi_slope,
                     const c10::optional<at::Tensor> &attn_bias,
                     const c10::optional<at::Tensor> &k_cache_quant_scale,
                     const c10::optional<at::Tensor> &v_cache_quant_scale,
                     const c10::optional<at::Tensor> &block_tables,
                     const int64_t max_seq_len_q,
                     const int64_t max_seq_len_kv,
                     const double softmax_scale,
                     const bool is_causal,
                     const int64_t window_size_left,
                     const int64_t window_size_right,
                     const std::string &compute_dtype,
                     bool return_lse) {
  TORCH_CHECK(compute_dtype == "float" || compute_dtype == "half" || compute_dtype == "bfloat16",
              "compute_dtype must be 'float', 'half' or 'bfloat16'.");
  // The cache quantization scales are reserved arguments and must not be supplied.
  TORCH_CHECK(k_cache_quant_scale.has_value() == false, "k_cache_scale for reserve.");
  TORCH_CHECK(v_cache_quant_scale.has_value() == false, "v_cache_scale for reserve.");
  bool has_block_table = block_tables.has_value();
  bool is_pack = cu_seq_lens_q.has_value();
  int64_t batch = is_pack ? cu_seq_lens_q.value().size(0) - 1 : q.size(0);
  int qk_head_size = q.size(-1);
  int v_head_size = v.size(-1);
  // Check tensor type and tensor device.
  checkTensorSameAttr<TensorAttr::ALL>(q, k, v, output);
  // 3d for packed
  TORCH_CHECK(q.dim() == 3 || q.dim() == 4, "query must be 3d or 4d.");
  if (has_block_table) {
    TORCH_CHECK(block_tables.value().dim() == 2, "block_tables must be 2d.");
    TORCH_CHECK(k.dim() == 4, "with block table, key_cache must be 4d.");
    TORCH_CHECK(v.dim() == 4, "with block table, value_cache must be 4d.");
    int max_num_blocks_per_seq = block_tables.value().size(1);
    int num_blocks = k.size(0);
    int block_size = k.size(2);
    int k_head_num = k.size(1);
    CHECK_SHAPE(k, num_blocks, k_head_num, block_size, qk_head_size);
    CHECK_SHAPE(v, num_blocks, k_head_num, block_size, v_head_size);
    if (max_num_blocks_per_seq > 1) {  // paged
      // Fail with a readable message instead of a bad optional access on .value().
      TORCH_CHECK(cu_seq_lens_kv.has_value(),
                  "cu_seq_lens_kv must be provided when block_tables has more than one block "
                  "per sequence.");
      CHECK_SHAPE(cu_seq_lens_kv.value(), batch + 1);
    }
  } else {
    // 3d for packed
    TORCH_CHECK(k.dim() == 3 || k.dim() == 4, "key_cache must be 3d or 4d.");
    TORCH_CHECK(v.dim() == 3 || v.dim() == 4, "value_cache must be 3d or 4d.");
    if (k.dim() == 3) {  // packed_kv
      // Fail with a readable message instead of a bad optional access on .value().
      TORCH_CHECK(cu_seq_lens_kv.has_value(),
                  "cu_seq_lens_kv must be provided when key is packed (3d).");
      CHECK_SHAPE(cu_seq_lens_kv.value(), batch + 1);
    }
  }
  // Convert torch tensor to tensor descs.
  // Slots: 0 q, 1 k, 2 v, 3 cu_seq_lens_q, 4 cu_seq_lens_kv, 5 alibi_slope, 6 attn_bias,
  // 7 block_tables, 8 output, 9 output_lse. Absent optionals use a default-constructed tensor.
  auto descs = createTensorDescs(
      {q, k, v, cu_seq_lens_q.value_or(at::Tensor()), cu_seq_lens_kv.value_or(at::Tensor()),
       alibi_slope.value_or(at::Tensor()), attn_bias.value_or(at::Tensor()),
       block_tables.value_or(at::Tensor()), output, output_lse.value_or(at::Tensor())});
  // Get current handle.
  const torch_mlu::mlu::MLUGuard device_guard(q.device());
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  // Get workspace size and malloc workspace.
  cnnlDataType_t cnnl_compute_dtype = compute_dtype == "float"  ? CNNL_DTYPE_FLOAT
                                      : compute_dtype == "half" ? CNNL_DTYPE_HALF
                                                                : CNNL_DTYPE_BFLOAT16;
  size_t workspace_size = 0;
  CNNL_CHECK_FATAL(cnnlGetScaledDotProductAttnWorkspaceSize_v2(
      handle, nullptr /*op_desc*/, nullptr /*quant_desc*/, descs[0].get(), descs[1].get(),
      descs[2].get(), descs[3].get(), descs[4].get(), descs[6].get(), descs[5].get(),
      descs[7].get(), max_seq_len_q, max_seq_len_kv, is_causal, window_size_left, window_size_right,
      return_lse, CNNL_ACTIVATION_FAST, cnnl_compute_dtype, &workspace_size));
  auto workspace = at::empty({static_cast<int64_t>(workspace_size)}, q.options().dtype(at::kByte));
  // Sizes handed to FlashAttnTheory, which is pushed as a cnpx range around the kernel call.
  int64_t total_q = is_pack ? q.size(0) : q.size(0) * q.size(1);
  int64_t total_k =
      has_block_table ? k.size(0) * k.size(2) : (k.dim() == 3 ? k.size(0) : k.size(0) * k.size(1));
  int64_t head_q = q.size(-2);
  int64_t head_k = has_block_table ? k.size(1) : k.size(-2);
  cnnlDataType_t data_dtype = getCnnlDataType(q.scalar_type());
  FlashAttnTheory obj(batch, total_q, total_k, head_q, head_k, qk_head_size, v_head_size, is_causal,
                      data_dtype);
  cnpxPush(obj);
  // call cnnl extra op.
  CNNL_CHECK_FATAL(cnnlScaledDotProductAttn_v3(
      handle, nullptr, nullptr, descs[0].get(), getAtTensorPtr(q), descs[1].get(),
      getAtTensorPtr(k), descs[2].get(), getAtTensorPtr(v), descs[3].get(),
      getAtTensorPtr(cu_seq_lens_q), descs[4].get(), getAtTensorPtr(cu_seq_lens_kv), nullptr,
      nullptr, descs[6].get(), getAtTensorPtr(attn_bias), descs[5].get(),
      getAtTensorPtr(alibi_slope), descs[7].get(), getAtTensorPtr(block_tables), max_seq_len_q,
      max_seq_len_kv, is_causal, window_size_left, window_size_right, softmax_scale,
      CNNL_ACTIVATION_FAST, cnnl_compute_dtype, getAtTensorPtr(workspace), workspace_size,
      return_lse, descs[9].get(), getAtTensorPtr(output_lse), descs[8].get(),
      getAtTensorPtr(output)));
  cnpxPop();
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,134 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches cnnlFuseNorm_v3 (layernorm or rmsnorm, optionally with
// residual, bias and quantized output) on the current MLU handle.
//
// input / output     : same shape, at least 2-D; the last dim of input must have stride 1.
// residual           : optional, must have the same shape as input.
// gamma / beta       : "layernorm" needs both, "rmsnorm" needs gamma; each is 1-D of
//                      length hidden_size (= input.size(-1)).
// bias               : optional tensor forwarded to the kernel.
// quant_scale        : optional, shape [hidden_size]; its presence enables quantized output
//                      (per-channel scheme, or per-token scheme when dynamic_quant is true).
// residual_out       : required when store_output_before_norm is true.
// smooth_quant_scale : required when dynamic_quant is true; passed flattened to the kernel.
// norm_mode          : "layernorm" or "rmsnorm".
void fused_layernorm(const at::Tensor &input,
                     const at::Tensor &output,
                     const c10::optional<at::Tensor> &residual,
                     const c10::optional<at::Tensor> &gamma,
                     const c10::optional<at::Tensor> &beta,
                     const c10::optional<at::Tensor> &bias,
                     const c10::optional<at::Tensor> &quant_scale,
                     const c10::optional<at::Tensor> &residual_out,
                     const c10::optional<at::Tensor> &smooth_quant_scale,
                     const std::string &norm_mode,
                     double eps,
                     bool store_output_before_norm,
                     bool dynamic_quant) {
  // check device and dtype
  checkTensorSameAttr<TensorAttr::ALL>(input, residual, gamma, beta, bias);
  checkTensorSameAttr<TensorAttr::DEVICE>(input, output, residual_out, smooth_quant_scale);
  cnnlQuantizeScheme_t output_quant_scheme = CNNL_QUANTIZE_NONE;
  if (dynamic_quant) {
    TORCH_CHECK(quant_scale.has_value(), "dynamic_quant output, must have quant_scale");
    // smooth_quant_scale is dereferenced below whenever dynamic_quant is set.
    TORCH_CHECK(smooth_quant_scale.has_value(),
                "dynamic_quant output, must have smooth_quant_scale");
    output_quant_scheme = CNNL_QUANTIZE_PER_TOKEN;
  } else if (quant_scale.has_value()) {
    output_quant_scheme = CNNL_QUANTIZE_PER_CHANNEL;
  }
  // check params
  bool has_residual = residual.has_value();
  bool quant_out = quant_scale.has_value();
  // Check the rank before reading the last dimension.
  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2.");
  int hidden_size = input.size(-1);
  TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous.");
  TORCH_CHECK(input.sizes() == output.sizes(), "input and output must have the same shape");
  if (has_residual) {
    TORCH_CHECK(input.sizes() == residual.value().sizes(),
                "input and residual must have the same shape");
  }
  if (store_output_before_norm) {
    // residual_out is dereferenced below whenever store_output_before_norm is set.
    TORCH_CHECK(residual_out.has_value(),
                "store_output_before_norm is true, must have residual_out");
  }
  TORCH_CHECK(norm_mode == "layernorm" || norm_mode == "rmsnorm",
              "norm_mode must be 'layernorm' or 'rmsnorm'.");
  cnnlTransformerNormType_t mode =
      norm_mode == "layernorm" ? CNNL_TRANSFORMER_LAYERNORM : CNNL_TRANSFORMER_RMSNORM;
  if (norm_mode == "layernorm") {
    TORCH_CHECK(gamma.has_value() && beta.has_value(), "layernorm mode need gamma and beta.");
    TORCH_CHECK(gamma.value().sizes() == beta.value().sizes(),
                "gamma and beta must have the same shape")
    TORCH_CHECK(gamma.value().dim() == 1 && gamma.value().size(0) == hidden_size,
                "layernorm mode, gamma and beta size must be hidden_size.");
  } else {
    TORCH_CHECK(gamma.has_value(), "rmsnorm mode need gamma.");
    TORCH_CHECK(gamma.value().dim() == 1 && gamma.value().size(0) == hidden_size,
                "rmsnorm mode, gamma size must be hidden_size.");
  }
  if (quant_out) {
    TORCH_CHECK(quant_scale.value().dim() == 1 && quant_scale.value().size(0) == hidden_size,
                "quant_scale shape must be [hidden_size]");
  }
  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  at::Tensor smooth_quant_scale_flat =
      dynamic_quant ? smooth_quant_scale.value().flatten() : at::Tensor();
  // The *_flat tensors are what the kernel sees; absent optionals stay default-constructed.
  at::Tensor input_flat;
  at::Tensor output_flat;
  at::Tensor residual_flat;
  at::Tensor residual_out_flat;
  if (quant_out) {
    // input must be 2-dim, output dim must be same as input
    TORCH_CHECK(input.is_contiguous(), "quant_out is not support when input has stride.");
    input_flat = input.flatten(0, -2);
    output_flat = output.flatten(0, -2);
    if (has_residual) {
      residual_flat = residual.value().flatten(0, -2);
    }
    if (store_output_before_norm) {
      residual_out_flat = residual_out.value().flatten(0, -2);
    }
  } else if (input.dim() > 3) {
    // Collapse the leading dimensions so that a 3-D tensor is passed on.
    input_flat = input.flatten(0, -3);
    output_flat = output.flatten(0, -3);
    if (has_residual) {
      residual_flat = residual.value().flatten(0, -3);
    }
    if (store_output_before_norm) {
      residual_out_flat = residual_out.value().flatten(0, -3);
    }
  } else {  // 2-dim or 3-dim
    input_flat = input;
    output_flat = output;
    if (has_residual) {
      residual_flat = residual.value();
    }
    if (store_output_before_norm) {
      residual_out_flat = residual_out.value();
    }
  }
  // A changed data pointer means flatten() produced a different buffer than the caller's
  // tensor, which the in-place kernel call below cannot work with.
  TORCH_CHECK(input.data_ptr() == input_flat.data_ptr(), "check the strides of input.");
  TORCH_CHECK(output_flat.data_ptr() == output.data_ptr(), "check the strides of output.");
  if (has_residual) {
    TORCH_CHECK(residual.value().data_ptr() == residual_flat.data_ptr(),
                "check the strides of residual.");
  }
  if (store_output_before_norm) {
    TORCH_CHECK(residual_out.value().data_ptr() == residual_out_flat.data_ptr(),
                "check the strides of residual_out.");
  }
  // create tensor desc
  // Slots: 0 input, 1 gamma, 2 beta, 3 bias, 4 residual, 5 quant_scale, 6 residual_out,
  // 7 output, 8 smooth_quant_scale.
  auto descs = createTensorDescs({input_flat, gamma.value_or(at::Tensor()),
                                  beta.value_or(at::Tensor()), bias.value_or(at::Tensor()),
                                  residual_flat, quant_scale.value_or(at::Tensor()),
                                  residual_out_flat, output_flat, smooth_quant_scale_flat});
  auto compute_dtype = getCnnlDataType(input_flat.scalar_type());
  // forward
  auto handle = torch_mlu::getCurrentHandle();
  FusedNormTheory obj(input_flat.size(0), input_flat.size(-1), has_residual, bias.has_value(),
                      quant_out, dynamic_quant, store_output_before_norm, compute_dtype, norm_mode);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlFuseNorm_v3(handle, descs[0].get(), getAtTensorPtr(input_flat),  // input
                                   descs[5].get(), getAtTensorPtr(quant_scale),  // input_scale
                                   descs[1].get(), getAtTensorPtr(gamma),        // norm_scale
                                   descs[2].get(), getAtTensorPtr(beta),         // norm_bias
                                   descs[4].get(), getAtTensorPtr(residual_flat),  // residual
                                   descs[3].get(), getAtTensorPtr(bias),           // bias
                                   eps, output_quant_scheme, store_output_before_norm, mode,
                                   compute_dtype, nullptr, 0,                    // set workspace
                                   descs[7].get(), getAtTensorPtr(output_flat),  // output
                                   descs[6].get(),
                                   getAtTensorPtr(residual_out_flat),  // residual_out
                                   descs[8].get(), getAtTensorPtr(smooth_quant_scale_flat)));
  cnpxPop();
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,373 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cstdint>
#include <vector>
#include "kernels/moe/moe.mluh"
#include "torch_ops_api.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
const std::string arch_370 = "MLU370";
using GroupGemmDesc = tmo::op_desc::GroupGemmDesc;
using QuantMode = tmo::op_desc::GroupGemmDesc::QuantMode;
at::Tensor fused_moe(const at::Tensor &hidden_states,
const at::Tensor &gating_output,
const at::Tensor &w1,
const at::Tensor &w2,
const c10::optional<at::Tensor> &bias1,
const c10::optional<at::Tensor> &bias2,
const c10::optional<at::Tensor> &residual,
const c10::optional<at::Tensor> &input_smooth,
const c10::optional<at::Tensor> &act_smooth,
const c10::optional<at::Tensor> &w1_scale,
const c10::optional<at::Tensor> &w2_scale,
const c10::optional<at::List<int64_t>> &w1_quant_flag,
const c10::optional<at::List<int64_t>> &w2_quant_flag,
const int64_t topk,
const bool renormalize,
const bool gated,
const std::string &act_mode,
const int64_t start_expert_id,
const int64_t block_n,
const int64_t cncl_comm) {
auto sizes_1 = hidden_states.sizes();
auto sizes_2 = gating_output.sizes();
auto w2_shape = w2.sizes();
TORCH_CHECK(sizes_1.size() == sizes_2.size(),
"hidden_states and gating_output must have the same rank.")
TORCH_CHECK(sizes_1.size() == 2 || sizes_1.size() == 3, "hidden_states must be 2-D or 3-D.")
TORCH_CHECK(sizes_1[0] == sizes_2[0],
"hidden_states and gating_output must have the same batch.");
if (sizes_1.size() == 3) {
TORCH_CHECK(sizes_1[1] == sizes_2[1],
"hidden_states and gating_output must have the same seq.");
}
if (residual.has_value()) {
TORCH_CHECK(residual.value().sizes() == sizes_1,
"hidden_states and residual must have the same shape.");
}
TORCH_CHECK(!bias1.has_value() && !bias2.has_value(), "Currently not support bias1 and bias2.")
TORCH_CHECK(w1.dim() == 3 || w1.dim() == 1, "w1 should be 1-D or 3-D.")
TORCH_CHECK(w2.dim() == 3 || w2.dim() == 4 || w2.dim() == 1, "w2 should be 1-D or 3-D or 4-D.")
TORCH_CHECK((w1.dim() == 1 && w2.dim() == 1) || (w1.dim() != 1 && w2.dim() != 1),
"w1 and w2 should be both 1-D or not both 1-D.")
if (w1.dim() == 1) {
TORCH_CHECK(w1_quant_flag.has_value() && w2_quant_flag.has_value(),
"w1_quant_flag and w2_quant_flag need to exist simultaneously.");
}
// check contiguous
CHECK_TENSOR_CONTIGUOUS(hidden_states)
CHECK_TENSOR_CONTIGUOUS(gating_output)
CHECK_TENSOR_CONTIGUOUS(w1)
CHECK_TENSOR_CONTIGUOUS(w2)
CHECK_OPTIONAL_TENSOR_CONTIGUOUS(act_smooth)
CHECK_OPTIONAL_TENSOR_CONTIGUOUS(w1_scale)
const int64_t hidden_size = hidden_states.size(-1);
const int64_t num_expert = gating_output.size(-1);
const int64_t expert_size = w1_quant_flag.has_value() ? w1_scale.value().size(1) : w1.size(0);
auto hidden_states_ = hidden_states.view({-1, hidden_size});
auto gating_output_ = gating_output.view({-1, num_expert});
const int64_t num_token = hidden_states_.size(0);
const int64_t num_expand_token = num_token * topk;
const int64_t gemm1_co = w1_quant_flag.has_value() ? w1_scale.value().size(2) : w1.size(1);
const int64_t inner_size = gemm1_co / (1 + gated);
bool has_input_smooth = input_smooth.has_value();
bool has_act_smooth = act_smooth.has_value();
bool has_w1_scale = w1_scale.has_value();
bool has_w2_scale = w2_scale.has_value();
int opt_num = has_input_smooth + has_act_smooth + has_w1_scale + has_w2_scale;
QuantMode quant_mode = QuantMode::noQuant;
bool quant_grouped = false;
bool per_token_sq = opt_num == 4;
TORCH_CHECK(opt_num == 0 || opt_num == 4,
"input_smooth, act_smooth, w1_scale and w2_scale must be present and absent at the "
"same time.")
if (per_token_sq) {
TORCH_CHECK(input_smooth.value().dtype() == torch::kFloat32,
"the data type of input_smooth must be float.");
checkTensorSameAttr<TensorAttr::DTYPE>(input_smooth, act_smooth, w1_scale, w2_scale);
CHECK_SHAPE(input_smooth.value(), expert_size, hidden_size);
CHECK_SHAPE(act_smooth.value(), expert_size, inner_size);
quant_grouped = w1_scale.value().dim() == 3 ? true : false;
if (quant_grouped) {
CHECK_SHAPE(w1_scale.value(), w1_scale.value().size(0), expert_size, gemm1_co);
CHECK_SHAPE(w2_scale.value(), w2_scale.value().size(0), expert_size, hidden_size);
} else {
CHECK_SHAPE(w1_scale.value(), expert_size, gemm1_co);
CHECK_SHAPE(w2_scale.value(), expert_size, hidden_size);
}
if (w1_quant_flag.has_value()) {
quant_mode = QuantMode::W4W8;
quant_grouped = true;
} else {
TORCH_CHECK(hidden_size == w1.size(-1) || hidden_size == w1.size(-1) * 2,
"hidden_size == w1.size(-1) || hidden_size == w1.size(-1) * 2.");
quant_mode = hidden_size == w1.size(-1) ? QuantMode::W8 : QuantMode::W4;
}
}
if (quant_mode == QuantMode::W8 || quant_mode == QuantMode::noQuant) {
CHECK_SHAPE(w1, expert_size, gemm1_co, hidden_size);
if (w2.dim() == 3) {
CHECK_SHAPE(w2, expert_size, hidden_size, inner_size);
} else {
TORCH_CHECK(w2_shape[0] * w2_shape[2] == expert_size,
"w2_shape[0] * w2_shape[2] == expert_size");
CHECK_SHAPE(w2, w2_shape[0], hidden_size, w2_shape[2], inner_size);
}
} else if (quant_mode == QuantMode::W4) {
CHECK_SHAPE(w1, expert_size, gemm1_co, hidden_size / 2);
CHECK_SHAPE(w2, expert_size, hidden_size, inner_size / 2);
}
if (bias1.has_value()) {
CHECK_SHAPE(bias1.value(), num_expert, gemm1_co);
}
if (bias2.has_value()) {
CHECK_SHAPE(bias2.value(), num_expert, hidden_size);
}
TORCH_CHECK(topk <= num_expert, "topk <= num_expert.")
TORCH_CHECK(act_mode == "silu" || act_mode == "gelu",
"act_mode must be 'silu' or 'gelu', but got ", act_mode)
checkTensorSameAttr<TensorAttr::DTYPE>(hidden_states, bias1, bias2);
checkTensorSameAttr<TensorAttr::DEVICE>(hidden_states, gating_output, w1, w2, residual,
input_smooth, act_smooth, w1_scale, w2_scale);
const torch_mlu::mlu::MLUGuard device_guard(hidden_states_.device());
torch_mlu::DeviceProp *dev_prop = torch_mlu::getDeviceProperties(hidden_states.get_device());
std::string dev_name = dev_prop->name;
bool is_mlu370 = arch_370.compare(3, 3, dev_name, 3, 3) >= 0 ? true : false;
auto handle = torch_mlu::getCurrentHandle();
auto gating_output_dtype = getCnnlDataType(gating_output.scalar_type());
auto queue = torch_mlu::getCurMLUStream();
auto tensor_options = hidden_states_.options();
auto data_dtype = getCnnlDataType(hidden_states.scalar_type());
auto weight_dtype = getCnnlDataType(w1.scalar_type());
auto input_dtype = per_token_sq ? CNNL_DTYPE_INT8 : weight_dtype;
if (quant_mode == QuantMode::W4) {
weight_dtype = CNNL_DTYPE_INT4X2;
}
// create tensors
auto reduce_weight = at::empty({num_token, topk}, tensor_options.dtype(torch::kFloat));
auto expert_id = at::empty({num_token, topk}, tensor_options.dtype(torch::kInt32));
auto int32_idx = at::empty({2, num_token, topk}, tensor_options.dtype(torch::kInt32));
auto gather_expand_idx = int32_idx[0];
auto gather_combine_idx = int32_idx[1];
auto token_count = at::empty({num_expert}, tensor_options.dtype(torch::kInt32));
auto cumsum_token_count = at::empty({num_expert + 1}, tensor_options.dtype(torch::kInt32));
auto gen_idx_workspace =
at::empty({num_expert + 1 + num_expand_token}, tensor_options.dtype(torch::kInt32));
int64_t mid_dim = (1 + gated) * (1 + per_token_sq);
auto gemm1_output = at::empty({num_expand_token, mid_dim, inner_size}, w1.options());
auto gemm2_output = at::empty({num_expand_token, hidden_size}, tensor_options);
auto quant_input = at::Tensor();
auto input_scale = at::Tensor();
auto act_scale = at::Tensor();
auto expand_hidden_state = at::Tensor();
if (is_mlu370 || !per_token_sq) {
expand_hidden_state = at::empty({num_expand_token, hidden_size}, tensor_options);
}
if (per_token_sq) {
quant_input = at::empty({num_expand_token, hidden_size}, tensor_options.dtype(torch::kInt8));
input_scale = at::empty({num_expand_token}, tensor_options.dtype(torch::kFloat));
act_scale = at::empty({num_expand_token}, tensor_options.dtype(torch::kFloat));
}
//=========================================1:topk_softmax=========================================
int normalize_mode = renormalize ? 1 : 0;
tmo::invokeMoeSoftmaxTopkKernel(queue, (float *)getAtTensorPtr(reduce_weight),
(int *)getAtTensorPtr(expert_id), getAtTensorPtr(gating_output_),
nullptr, num_token, num_expert, 1, topk, -1, 0,
gating_output_dtype, normalize_mode);
//=========================================2:generate_idx=========================================
tmo::invokeMoeGenIdxKernel(
queue, (int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(gather_combine_idx),
(int *)getAtTensorPtr(token_count), (int *)getAtTensorPtr(cumsum_token_count),
getAtTensorPtr(gen_idx_workspace), getAtTensorPtr(expert_id), num_token, num_expert, topk);
if (per_token_sq) {
//========================================3:pertoken_sq(optinal)==================================
if (is_mlu370) {
tmo::invokeMoeExpandInputKernel(
queue, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(hidden_states_),
(int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(cumsum_token_count),
num_token, hidden_size, topk, data_dtype, num_expert, start_expert_id, expert_size);
SmoothQuantTheory sm_obj1(num_expand_token, hidden_size, data_dtype, "per_token");
cnpxPush(sm_obj1);
tmo::ops::SmoothQuant(
handle, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(input_smooth),
getAtTensorPtr(token_count[start_expert_id]), nullptr, nullptr,
getAtTensorPtr(quant_input), getAtTensorPtr(input_scale), num_expand_token, hidden_size,
expert_size, hidden_size, hidden_size, 1, data_dtype);
cnpxPop();
} else {
SmoothQuantTheory sm_obj1(num_expand_token, hidden_size, data_dtype, "per_token");
cnpxPush(sm_obj1);
tmo::ops::SmoothQuant(handle, getAtTensorPtr(hidden_states_), getAtTensorPtr(input_smooth),
getAtTensorPtr(token_count[start_expert_id]),
getAtTensorPtr(gather_expand_idx),
getAtTensorPtr(cumsum_token_count[start_expert_id]),
getAtTensorPtr(quant_input), getAtTensorPtr(input_scale), num_token,
hidden_size, expert_size, hidden_size, hidden_size, topk, data_dtype);
cnpxPop();
}
} else {
//========================================3:expand_input========================================
tmo::invokeMoeExpandInputKernel(
queue, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(hidden_states_),
(int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(cumsum_token_count),
num_token, hidden_size, topk, data_dtype, num_expert, start_expert_id, expert_size);
}
//========================================4:group_gemm1===========================================
GroupGemmDesc group_gemm_desc1(
expert_size, num_token /*the maximum number of token that may be processed by each expert*/,
gemm1_co, hidden_size, data_dtype, false, quant_mode);
group_gemm_desc1.setInputOutputTensor(
input_dtype, weight_dtype, data_dtype, CNNL_DTYPE_INT32,
per_token_sq ? getAtTensorPtr(quant_input) : getAtTensorPtr(expand_hidden_state),
getAtTensorPtr(w1), nullptr, getAtTensorPtr(gemm1_output), nullptr, hidden_size, hidden_size,
num_expand_token, false);
std::vector<int> w1_flag_vec;
if (per_token_sq) {
if (w1_quant_flag.has_value()) {
auto vec = w1_quant_flag.value().vec();
w1_flag_vec.resize(vec.size());
std::copy(vec.begin(), vec.end(), w1_flag_vec.begin());
}
group_gemm_desc1.setPerRowColScaleBiasAct(
getAtTensorPtr(input_scale), getAtTensorPtr(w1_scale), w1_flag_vec.data(), nullptr,
data_dtype, quant_grouped ? hidden_size : 0,
quant_grouped ? hidden_size / w1_scale.value().size(0) : 0, 0);
}
size_t group_gemm1_wsize =
tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc1, expert_size);
auto group_gemm1_workspace = at::empty({static_cast<int64_t>(group_gemm1_wsize)},
hidden_states.options().dtype(at::kByte));
std::vector<int> ldb_array;
if (quant_mode != QuantMode::W4W8) ldb_array.assign(expert_size, hidden_size);
GroupGemmTheory gg_obj1(num_expand_token, expert_size, hidden_size, gemm1_co, false /*has_res*/,
input_dtype, data_dtype);
cnpxPush(gg_obj1);
tmo::ops::GroupGemm(handle, group_gemm_desc1, getAtTensorPtr(token_count[start_expert_id]),
nullptr, nullptr, getAtTensorPtr(group_gemm1_workspace), group_gemm1_wsize,
expert_size, hidden_size, /*k*/
gemm1_co, /*n*/
hidden_size /*lda*/, ldb_array /*ldb*/);
cnpxPop();
//========================================5:activation============================================
cnnlActivationMode_t act_type = act_mode == "silu" ? CNNL_ACTIVATION_SWISH : CNNL_ACTIVATION_GELU;
GroupAddBiasActiveTheory add_bias_obj(expert_size, num_expand_token, inner_size, gated,
bias1.has_value(), data_dtype, act_mode);
cnpxPush(add_bias_obj);
tmo::invokeGroupAddBiasActivationKernel(
queue, getAtTensorPtr(gemm1_output), getAtTensorPtr(gemm1_output),
nullptr /*getAtTensorPtr(bias1)*/, (int *)getAtTensorPtr(cumsum_token_count), num_expert,
num_expand_token, inner_size, gemm1_co, data_dtype, gated, act_type, start_expert_id,
expert_size, 1.0f);
cnpxPop();
//========================================6:smooth_quant==========================================
if (per_token_sq) {
SmoothQuantTheory sm_obj2(num_expand_token, inner_size * expert_size, data_dtype, "per_token");
cnpxPush(sm_obj2);
tmo::ops::SmoothQuant(handle, getAtTensorPtr(gemm1_output), getAtTensorPtr(act_smooth),
getAtTensorPtr(token_count[start_expert_id]), nullptr /*gather_idx*/,
nullptr, getAtTensorPtr(gemm1_output), getAtTensorPtr(act_scale),
num_expand_token, inner_size, expert_size, gemm1_co /*input_stride*/,
2 * gemm1_co /*output_stride*/, 1 /*topk*/, data_dtype);
cnpxPop();
}
if (cncl_comm > 0) {
TORCH_CHECK(num_expert == expert_size, "expert_size must be num_expert when cncl_comm > 0");
c10::optional<at::Tensor> act_scale_opt(act_scale);
c10::optional<at::Tensor> w2_scale_opt(w2_scale);
std::string dtype_s = torchDtype2Str(hidden_states.scalar_type());
auto gg_input =
gemm1_output.as_strided({num_expand_token, inner_size}, {mid_dim * inner_size, 1});
auto output = group_gemm_combine_result_allreduce(
cncl_comm, gg_input, w2, token_count, gather_combine_idx, reduce_weight, c10::nullopt,
c10::nullopt, c10::nullopt, per_token_sq ? act_scale_opt : c10::nullopt,
per_token_sq ? w2_scale_opt : c10::nullopt, dtype_s, num_token, topk, block_n);
if (residual.has_value()) {
output.view(sizes_1) += residual.value();
}
return output.view(sizes_1);
} else {
//========================================8:group_gemm2===========================================
GroupGemmDesc group_gemm_desc2(expert_size, num_token, hidden_size, inner_size, data_dtype,
false, quant_mode);
std::vector<int64_t> b_offset(expert_size);
int w2_ldb = inner_size;
if (w2.dim() == 4) {
w2_ldb = w2_shape[2] * inner_size;
auto elem_size = w2.element_size();
for (int64_t i = 0; i < expert_size; i++) {
b_offset[i] = (i / w2_shape[2] * w2.stride(0) + i % w2_shape[2] * w2_shape[3]) * elem_size;
}
}
group_gemm_desc2.setInputOutputTensor(input_dtype, weight_dtype, data_dtype, CNNL_DTYPE_INT32,
getAtTensorPtr(gemm1_output), getAtTensorPtr(w2), nullptr,
getAtTensorPtr(gemm2_output), nullptr, inner_size,
gemm1_co * (per_token_sq + 1), num_expand_token, false,
w2.dim() == 4 ? b_offset.data() : nullptr);
std::vector<int> w2_flag_vec;
if (per_token_sq) {
if (w2_quant_flag.has_value()) {
auto vec = w2_quant_flag.value().vec();
w2_flag_vec.resize(vec.size());
std::copy(vec.begin(), vec.end(), w2_flag_vec.begin());
}
group_gemm_desc2.setPerRowColScaleBiasAct(
getAtTensorPtr(act_scale), getAtTensorPtr(w2_scale), w2_flag_vec.data(), nullptr,
data_dtype, quant_grouped ? inner_size : 0,
quant_grouped ? inner_size / w2_scale.value().size(0) : 0, 0);
}
size_t group_gemm2_wsize =
tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc2, expert_size);
auto group_gemm2_workspace = at::empty({static_cast<int64_t>(group_gemm2_wsize)},
hidden_states.options().dtype(at::kByte));
std::vector<int> w2_ldb_array;
if (quant_mode != QuantMode::W4W8) w2_ldb_array.assign(expert_size, w2_ldb);
GroupGemmTheory gg_obj2(num_expand_token, expert_size, inner_size, hidden_size,
false /*has_res*/, input_dtype, data_dtype);
cnpxPush(gg_obj2);
tmo::ops::GroupGemm(handle, group_gemm_desc2, getAtTensorPtr(token_count[start_expert_id]),
nullptr, nullptr, getAtTensorPtr(group_gemm2_workspace), group_gemm2_wsize,
expert_size, inner_size, /*k*/
hidden_size, /*n*/
gemm1_co * (per_token_sq + 1) /*lda*/, w2_ldb_array /*ldb*/);
cnpxPop();
//========================================9:combine_result=======================================
auto output = at::empty(hidden_states_.sizes(), hidden_states_.options());
MoeCombineResultTheory cr_obj(num_token, topk, hidden_size, expert_size, bias2.has_value(),
residual.has_value(), data_dtype);
cnpxPush(cr_obj);
tmo::invokeMoeCombineResultKernel(
queue, getAtTensorPtr(output), getAtTensorPtr(gemm2_output), nullptr,
getAtTensorPtr(residual), (float *)getAtTensorPtr(reduce_weight),
(int *)getAtTensorPtr(cumsum_token_count), (int *)getAtTensorPtr(gather_combine_idx),
num_token, topk, num_expert, hidden_size, start_expert_id, expert_size, data_dtype);
cnpxPop();
return output.view(sizes_1);
}
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,281 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/fused_rope.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
/**
 * @brief Validate the arguments of the fused RoPE op and launch invokeFusedRope.
 *
 * All requirements listed here are enforced by the checks in the body; anything that is only
 * forwarded to the kernel is marked as such.
 *
 * @param qkv                  Contiguous 4-D half/bfloat16 tensor
 *                             (batch_size, 1, head_num_q + head_num_kv * 2, head_size).
 *                             Only seq_len == 1, head_size <= 128 (even) and at most 32 q / kv
 *                             heads are accepted.
 * @param key_cache_hp         Contiguous 4-D cache: (max_bs, head_num_kv, max_decode_len_hp,
 *                             head_size) or (num_blocks, head_num_kv, block_size_hp, head_size).
 * @param value_cache_hp       Same layout requirements as key_cache_hp.
 * @param key_cache_lp         Optional 4-D low-precision key cache. "Mixed cache" mode is enabled
 *                             when both key_cache_lp and value_cache_lp are given.
 * @param value_cache_lp       Optional 4-D low-precision value cache.
 * @param sin_table            2-D (rotary_seqlen, head_size); same attributes as qkv.
 * @param cos_table            2-D (rotary_seqlen, head_size); must share sin_table's row stride.
 * @param position_ids         Contiguous tensor forwarded to the kernel.
 * @param gamma                Contiguous tensor of shape (head_size); same attributes as qkv.
 * @param beta                 Contiguous tensor of shape (head_size); same attributes as qkv.
 * @param key_scale_hp         Optional float32 (head_num_kv, head_size). Required in mixed cache.
 * @param value_scale_hp       Optional float32 (head_num_kv, head_size). Required in mixed cache.
 * @param key_scale_lp         float32 4-D tensor, required in mixed cache. Its last dimension is
 *                             the group count used to derive group_size.
 * @param value_scale_lp       float32 tensor, required in mixed cache.
 * @param cache_bs_id_hp       Optional int32 tensor of shape (batch_size).
 * @param cache_seq_offsets_hp int32 tensor of shape (batch_size); required when slot_mapping_hp
 *                             is absent.
 * @param cache_bs_id_lp       Optional int32 tensor of shape (batch_size); read in mixed cache.
 * @param cache_seq_offsets_lp int32 tensor of shape (batch_size); required in mixed cache when
 *                             slot_mapping_lp is absent.
 * @param slot_mapping_hp      Optional int32 tensor of shape (batch_size). Its presence selects
 *                             the paged layout for the hp cache.
 * @param slot_mapping_lp      Optional int32 tensor of shape (batch_size). Its presence selects
 *                             the paged layout for the lp cache.
 * @param eps                  Forwarded unchanged to the kernel.
 *
 * Any violated requirement is reported through TORCH_CHECK.
 */
void fused_rope(at::Tensor &qkv,
                at::Tensor &key_cache_hp,
                at::Tensor &value_cache_hp,
                const c10::optional<at::Tensor> &key_cache_lp,
                const c10::optional<at::Tensor> &value_cache_lp,
                const at::Tensor &sin_table,
                const at::Tensor &cos_table,
                const at::Tensor &position_ids,
                const at::Tensor &gamma,
                const at::Tensor &beta,
                const c10::optional<at::Tensor> &key_scale_hp,
                const c10::optional<at::Tensor> &value_scale_hp,
                const c10::optional<at::Tensor> &key_scale_lp,
                const c10::optional<at::Tensor> &value_scale_lp,
                const c10::optional<at::Tensor> &cache_bs_id_hp,
                const c10::optional<at::Tensor> &cache_seq_offsets_hp,
                const c10::optional<at::Tensor> &cache_bs_id_lp,
                const c10::optional<at::Tensor> &cache_seq_offsets_lp,
                const c10::optional<at::Tensor> &slot_mapping_hp,
                const c10::optional<at::Tensor> &slot_mapping_lp,
                const double eps) {
  const bool paged_cache_hp = slot_mapping_hp.has_value();
  const bool paged_cache_lp = slot_mapping_lp.has_value();
  const bool mixed_cache = key_cache_lp.has_value() && value_cache_lp.has_value();
  const bool has_kv_scale = key_scale_hp.has_value() && value_scale_hp.has_value();
  const int origin_device_id = qkv.get_device();

  checkTensorSameAttr<TensorAttr::ALL>(qkv, sin_table, cos_table, gamma, beta);

  // check contiguity
  TORCH_CHECK(qkv.is_contiguous(), "qkv tensor must be contiguous.");
  TORCH_CHECK(key_cache_hp.is_contiguous(), "key_cache_hp tensor must be contiguous");
  TORCH_CHECK(value_cache_hp.is_contiguous(), "value_cache_hp tensor must be contiguous");
  if (mixed_cache) {
    TORCH_CHECK(key_cache_lp.value().is_contiguous(), "key_cache_lp tensor must be contiguous");
    TORCH_CHECK(has_kv_scale,
                "key_scale_hp and value_scale_hp must not be null under mixed cache.");
    TORCH_CHECK(value_cache_lp.value().is_contiguous(),
                "value_cache_lp tensor must be contiguous");
  }
  TORCH_CHECK(position_ids.is_contiguous(), "position_ids tensor must be contiguous");
  TORCH_CHECK(gamma.is_contiguous(), "gamma tensor must be contiguous");
  TORCH_CHECK(beta.is_contiguous(), "beta tensor must be contiguous");
  if (cache_bs_id_hp.has_value()) {
    TORCH_CHECK(cache_bs_id_hp.value().is_contiguous(),
                "cache_bs_id_hp tensor must be contiguous.");
  }
  if (cache_seq_offsets_hp.has_value()) {
    TORCH_CHECK(cache_seq_offsets_hp.value().is_contiguous(),
                "cache_seq_offsets_hp tensor must be contiguous");
  }
  if (cache_bs_id_lp.has_value()) {
    TORCH_CHECK(cache_bs_id_lp.value().is_contiguous(),
                "cache_bs_id_lp tensor must be contiguous.");
  }
  if (cache_seq_offsets_lp.has_value()) {
    TORCH_CHECK(cache_seq_offsets_lp.value().is_contiguous(),
                "cache_seq_offsets_lp tensor must be contiguous");
  }
  if (paged_cache_hp) {
    TORCH_CHECK(slot_mapping_hp.value().is_contiguous(),
                "slot_mapping_hp tensor must be contiguous.");
  }
  if (paged_cache_lp) {
    TORCH_CHECK(slot_mapping_lp.value().is_contiguous(),
                "slot_mapping_lp tensor must be contiguous.");
  }

  // check qkv key_cache value_cache
  const auto dtype = qkv.dtype();
  TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
              "qkv dtype must be half or bfloat16.");
  TORCH_CHECK(
      qkv.dim() == 4,
      "qkv tensor dim must be 4: (batch_size, 1, head_num_q + head_num_kv * 2, head_size).");
  TORCH_CHECK(qkv.size(1) == 1, "only support seq_len == 1.");
  TORCH_CHECK(key_cache_hp.dim() == 4,
              "key_cache_hp dim must be 4: (max_bs, head_num_kv, max_decode_len_hp, head_size),"
              "or (num_blocks, head_num_kv, block_size_hp, head_size).");
  TORCH_CHECK(value_cache_hp.dim() == 4,
              "value_cache_hp dim must be 4: (max_bs, head_num_kv, max_decode_len_hp, head_size),"
              "or (num_blocks, head_num_kv, block_size_hp, head_size).");
  void *key_cache_lp_ptr = nullptr;
  void *value_cache_lp_ptr = nullptr;
  if (mixed_cache) {
    TORCH_CHECK(
        key_cache_lp.value().dim() == 4,
        "key_cache_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp, head_size / 2),"
        "or (num_blocks, head_num_kv, block_size_lp, head_size / 2).");
    TORCH_CHECK(
        value_cache_lp.value().dim() == 4,
        "value_cache_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp / 2, head_size),"
        "or (num_blocks, head_num_kv, block_size_lp / 2, head_size).");
    // The messages name the tensor that is actually being checked (the lp caches).
    TORCH_CHECK(key_cache_lp.value().get_device() == origin_device_id,
                "key_cache_lp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", key_cache_lp.value().get_device());
    TORCH_CHECK(value_cache_lp.value().get_device() == origin_device_id,
                "value_cache_lp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", value_cache_lp.value().get_device());
    key_cache_lp_ptr = key_cache_lp.value().data_ptr();
    value_cache_lp_ptr = value_cache_lp.value().data_ptr();
  }

  int batch_size = qkv.size(0);
  int head_num_qkv = qkv.size(-2);
  int head_size = qkv.size(-1);
  int head_num_kv = key_cache_hp.size(1);  // linear cache and paged cache same
  int max_bs_hp = key_cache_hp.size(0);    // for linear cache, paged cache no use
  int max_bs_lp = mixed_cache ? key_cache_lp.value().size(0) : 0;
  int max_decode_len_hp = key_cache_hp.size(2);  // for linear cache, paged cache no use
  int max_decode_len_lp = mixed_cache ? key_cache_lp.value().size(2) : 0;
  int num_blocks_hp = key_cache_hp.size(0);
  int num_blocks_lp = mixed_cache ? key_cache_lp.value().size(0) : 0;
  int block_size_hp = key_cache_hp.size(2);  // for paged cache, linear cache no use
  int block_size_lp = mixed_cache ? key_cache_lp.value().size(2) : 0;
  int head_num_q = head_num_qkv - 2 * head_num_kv;
  int group_size = head_size;
  TORCH_CHECK(head_size <= 128, "only support qkv head_size <= 128.");
  TORCH_CHECK(head_size % 2 == 0, "head_size must be divided by 2.");
  // A negative value means qkv holds fewer heads than the cache implies; it would otherwise
  // slip through the upper-bound check below and reach the kernel.
  TORCH_CHECK(head_num_q >= 0,
              "qkv head number must not be less than 2 * head_num_kv of key_cache_hp.");
  TORCH_CHECK(head_num_q <= 32, "only support head_num_q <= 32.");
  TORCH_CHECK(head_num_kv <= 32, "only support head_num_kv <= 32.");

  // Bytes per hp cache element: 1 when hp scales are supplied, otherwise 2.
  const int kv_out_size = has_kv_scale ? 1 : 2;
  const size_t hp_cache_total_bytes =
      paged_cache_hp ? static_cast<size_t>(num_blocks_hp) * head_num_kv * block_size_hp *
                           head_size * kv_out_size
                     : static_cast<size_t>(max_bs_hp) * max_decode_len_hp * head_num_kv *
                           head_size * kv_out_size;
  TORCH_CHECK(hp_cache_total_bytes <= INT32_MAX,
              "hp_cache memory bytes must be less than or equal to INT32_MAX.");
  if (mixed_cache) {
    const size_t lp_cache_total_bytes =
        paged_cache_lp
            ? static_cast<size_t>(num_blocks_lp) * head_num_kv * block_size_lp * head_size / 2
            : static_cast<size_t>(max_bs_lp) * max_decode_len_lp * head_num_kv * head_size / 2;
    TORCH_CHECK(lp_cache_total_bytes <= INT32_MAX,
                "lp_cache memory bytes must be less than or equal to INT32_MAX.");
  }

  // check sin_table cos_table gamma beta
  TORCH_CHECK(sin_table.dim() == 2 && cos_table.dim() == 2,
              "sin_table and cos_table tensor dim must be 2: (rotary_seqlen, head_size).");
  TORCH_CHECK(sin_table.size(1) == head_size && cos_table.size(1) == head_size,
              "rotary_dim must be same with head_size.");
  TORCH_CHECK(sin_table.stride(0) == cos_table.stride(0),
              "sin_table first stride must be equal to cos_table first_stride.");
  int rotary_stride = sin_table.stride(0);
  CHECK_SHAPE(gamma, head_size);
  CHECK_SHAPE(beta, head_size);

  // check key value scale
  const void *key_scale_hp_ptr = nullptr;
  const void *value_scale_hp_ptr = nullptr;
  if (has_kv_scale) {
    CHECK_SHAPE(key_scale_hp.value(), head_num_kv, head_size);
    CHECK_SHAPE(value_scale_hp.value(), head_num_kv, head_size);
    TORCH_CHECK(key_scale_hp.value().scalar_type() == torch::kFloat32,
                "key_scale_hp dtype must be float.");
    TORCH_CHECK(value_scale_hp.value().scalar_type() == torch::kFloat32,
                "value_scale_hp dtype must be float.");
    TORCH_CHECK(key_scale_hp.value().get_device() == origin_device_id,
                "key_scale_hp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", key_scale_hp.value().get_device());
    TORCH_CHECK(value_scale_hp.value().get_device() == origin_device_id,
                "value_scale_hp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", value_scale_hp.value().get_device());
    key_scale_hp_ptr = key_scale_hp.value().data_ptr();
    value_scale_hp_ptr = value_scale_hp.value().data_ptr();
  }
  void *key_scale_lp_ptr = nullptr;
  void *value_scale_lp_ptr = nullptr;
  if (mixed_cache) {
    // Both lp scales are dereferenced below, so reject a missing one with a readable error
    // instead of failing inside optional::value().
    TORCH_CHECK(key_scale_lp.has_value() && value_scale_lp.has_value(),
                "key_scale_lp and value_scale_lp must not be null under mixed cache.");
    TORCH_CHECK(key_scale_lp.value().dim() == 4,
                "key_scale_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp, group_num),"
                "or (num_blocks, head_num_kv, block_size_lp, group_num).");
    TORCH_CHECK(key_scale_lp.value().scalar_type() == torch::kFloat32,
                "key_scale_lp dtype must be float.");
    TORCH_CHECK(value_scale_lp.value().scalar_type() == torch::kFloat32,
                "value_scale_lp dtype must be float.");
    TORCH_CHECK(key_scale_lp.value().get_device() == origin_device_id,
                "key_scale_lp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", key_scale_lp.value().get_device());
    TORCH_CHECK(value_scale_lp.value().get_device() == origin_device_id,
                "value_scale_lp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", value_scale_lp.value().get_device());
    // group_num is the divisor of the group_size computation; guard against division by zero.
    TORCH_CHECK(key_scale_lp.value().size(-1) > 0,
                "key_scale_lp group_num must be greater than 0.");
    group_size = static_cast<int>(head_size / key_scale_lp.value().size(-1));
    key_scale_lp_ptr = key_scale_lp.value().data_ptr();
    value_scale_lp_ptr = value_scale_lp.value().data_ptr();
  }

  // check cache_bs_id cache_seq_offsets
  const void *cache_bs_id_hp_ptr = nullptr;
  const void *cache_bs_id_lp_ptr = nullptr;
  const bool has_cache_bs_id_hp = cache_bs_id_hp.has_value();
  const bool has_cache_bs_id_lp = cache_bs_id_lp.has_value();
  if (has_cache_bs_id_hp) {
    CHECK_SHAPE(cache_bs_id_hp.value(), batch_size);
    TORCH_CHECK(cache_bs_id_hp.value().scalar_type() == torch::kInt32,
                "cache_bs_id_hp dtype need be torch::kInt32");
    TORCH_CHECK(cache_bs_id_hp.value().get_device() == origin_device_id,
                "cache_bs_id_hp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", cache_bs_id_hp.value().get_device());
    cache_bs_id_hp_ptr = cache_bs_id_hp.value().data_ptr();
  }
  if (mixed_cache && has_cache_bs_id_lp) {
    CHECK_SHAPE(cache_bs_id_lp.value(), batch_size);
    TORCH_CHECK(cache_bs_id_lp.value().scalar_type() == torch::kInt32,
                "cache_bs_id_lp dtype need be torch::kInt32");
    TORCH_CHECK(cache_bs_id_lp.value().get_device() == origin_device_id,
                "cache_bs_id_lp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", cache_bs_id_lp.value().get_device());
    cache_bs_id_lp_ptr = cache_bs_id_lp.value().data_ptr();
  }

  // check slot_mapping (paged cache only)
  const void *slot_mapping_hp_ptr = nullptr;
  const void *slot_mapping_lp_ptr = nullptr;
  if (paged_cache_hp) {
    CHECK_SHAPE(slot_mapping_hp.value(), batch_size);
    TORCH_CHECK(slot_mapping_hp.value().scalar_type() == torch::kInt32,
                "slot_mapping_hp dtype need be torch::kInt32");
    TORCH_CHECK(slot_mapping_hp.value().get_device() == origin_device_id,
                "slot_mapping_hp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", slot_mapping_hp.value().get_device());
    slot_mapping_hp_ptr = slot_mapping_hp.value().data_ptr();
  }
  if (mixed_cache && paged_cache_lp) {
    CHECK_SHAPE(slot_mapping_lp.value(), batch_size);
    TORCH_CHECK(slot_mapping_lp.value().scalar_type() == torch::kInt32,
                "slot_mapping_lp dtype need be torch::kInt32");
    TORCH_CHECK(slot_mapping_lp.value().get_device() == origin_device_id,
                "slot_mapping_lp tensor device index is not same, original index: ",
                origin_device_id, "now index is: ", slot_mapping_lp.value().get_device());
    slot_mapping_lp_ptr = slot_mapping_lp.value().data_ptr();
  }

  // check cache_seq_offsets (linear cache only)
  const void *cache_seq_offsets_hp_ptr = nullptr;
  const void *cache_seq_offsets_lp_ptr = nullptr;
  if (!paged_cache_hp) {
    TORCH_CHECK(cache_seq_offsets_hp.has_value(),
                "cache_seq_offsets_hp must not be null when slot_mapping_hp is not given.");
    CHECK_SHAPE(cache_seq_offsets_hp.value(), batch_size);
    TORCH_CHECK(cache_seq_offsets_hp.value().dtype() == torch::kInt32,
                "cache_seq_offsets_hp dtype need be torch::kInt32");
    TORCH_CHECK(cache_seq_offsets_hp.value().get_device() == origin_device_id,
                "cache_seq_offsets_hp tensor device index is not same, original index:",
                origin_device_id, "now index is: ", cache_seq_offsets_hp.value().get_device());
    cache_seq_offsets_hp_ptr = cache_seq_offsets_hp.value().data_ptr();
  }
  if (mixed_cache && !paged_cache_lp) {
    TORCH_CHECK(cache_seq_offsets_lp.has_value(),
                "cache_seq_offsets_lp must not be null when slot_mapping_lp is not given.");
    CHECK_SHAPE(cache_seq_offsets_lp.value(), batch_size);
    TORCH_CHECK(cache_seq_offsets_lp.value().dtype() == torch::kInt32,
                "cache_seq_offsets_lp dtype need be torch::kInt32");
    TORCH_CHECK(cache_seq_offsets_lp.value().get_device() == origin_device_id,
                "cache_seq_offsets_lp tensor device index is not same, original index:",
                origin_device_id, "now index is: ", cache_seq_offsets_lp.value().get_device());
    cache_seq_offsets_lp_ptr = cache_seq_offsets_lp.value().data_ptr();
  }

  auto queue = torch_mlu::getCurrentMLUStream();
  auto data_type = getCnnlDataType(qkv.scalar_type());
  invokeFusedRope(queue, getAtTensorPtr(qkv), getAtTensorPtr(key_cache_hp),
                  getAtTensorPtr(value_cache_hp), key_cache_lp_ptr, value_cache_lp_ptr,
                  getAtTensorPtr(sin_table), getAtTensorPtr(cos_table),
                  getAtTensorPtr(position_ids), getAtTensorPtr(gamma), getAtTensorPtr(beta),
                  key_scale_hp_ptr, value_scale_hp_ptr, key_scale_lp_ptr, value_scale_lp_ptr,
                  cache_bs_id_hp_ptr, cache_seq_offsets_hp_ptr, cache_bs_id_lp_ptr,
                  cache_seq_offsets_lp_ptr, slot_mapping_hp_ptr, slot_mapping_lp_ptr,
                  rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
                  max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
                  data_type, eps);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,46 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "glue_ops.h"
#include <ATen/FunctionalTensorWrapper.h>
/**
 * Functionalization glue for the in-place copy_blocks op.
 *
 * The incoming cache lists are expected to be functional tensor wrappers. They are synced and
 * unwrapped, the out-of-place op "torch_mlu_ops::copy_blocks_out_of_place" is called on the
 * unwrapped tensors with functionalization skipped, and the returned tensors are written back
 * into the wrappers as the new values of k_caches / v_caches.
 *
 * @param k_caches       Functional key-cache tensors (asserted).
 * @param v_caches       Functional value-cache tensors (asserted).
 * @param block_mapping  Passed through unchanged to the out-of-place op.
 */
void copy_blocks__functionalization_glue(
    const std::vector<torch::Tensor> &k_caches,
    const std::vector<torch::Tensor> &v_caches,
    const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping) {
  namespace fimpl = at::functionalization::impl;
  using TensorVec = std::vector<at::Tensor>;
  using CachePair = std::tuple<TensorVec, TensorVec>;
  // C++ signature of the out-of-place op the dispatcher lookup is typed with.
  using OutOfPlaceSig = CachePair(const TensorVec &,
                                  const TensorVec &,
                                  const c10::Dict<int64_t, c10::List<int64_t>> &);

  // The asserted expressions stay fully qualified so the assertion text is unchanged.
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(k_caches));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(v_caches));

  // Flush pending updates, then peel off the functional wrappers.
  fimpl::sync(k_caches);
  fimpl::sync(v_caches);
  auto unwrapped_k = fimpl::from_functional_tensor(k_caches);
  auto unwrapped_v = fimpl::from_functional_tensor(v_caches);

  // Resolve the out-of-place op once (namespace::op_name, empty overload name).
  static auto op_handle =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torch_mlu_ops::copy_blocks_out_of_place", "")
          .typed<OutOfPlaceSig>();

  // Redispatch to the out-of-place op; the guard only lives for the duration of the call.
  CachePair updated = [&]() -> CachePair {
    at::AutoDispatchSkipFunctionalize guard;
    return op_handle.call(unwrapped_k, unwrapped_v, block_mapping);
  }();

  // Tell functionalization about the mutation of both cache lists.
  fimpl::replace_(k_caches, std::get<0>(updated));
  fimpl::replace_(v_caches, std::get<1>(updated));
  fimpl::commit_update(k_caches);
  fimpl::commit_update(v_caches);
}

View File

@@ -0,0 +1,22 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_TORCH_API_GLUE_OPS_H_
#define CSRC_TORCH_API_GLUE_OPS_H_
#include "torch/extension.h"
#include "utils.h"
// Functionalization glue for the in-place copy_blocks op: takes functional-tensor key/value
// cache lists, redispatches to "torch_mlu_ops::copy_blocks_out_of_place" on the unwrapped
// tensors, and writes the results back into k_caches / v_caches. block_mapping is forwarded
// unchanged to the out-of-place op.
void copy_blocks__functionalization_glue(
const std::vector<torch::Tensor> &k_caches,
const std::vector<torch::Tensor> &v_caches,
const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping);
#endif // CSRC_TORCH_API_GLUE_OPS_H_

View File

@@ -0,0 +1,231 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cstdint>
#include <vector>
#include "torch_ops_api.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
// Short aliases for the group GEMM descriptor and its nested QuantMode type
// (values used in this file: noQuant, W8, W4, W4W8).
using GroupGemmDesc = op_desc::GroupGemmDesc;
using QuantMode = GroupGemmDesc::QuantMode;
/**
 * @brief Validate the inputs of a grouped GEMM, allocate the output and launch ops::GroupGemm.
 *
 * The quantization mode is derived from the optional arguments:
 *   - quant_flag given                -> QuantMode::W4W8 (b_scale must be a 3-D tensor)
 *   - b_scale given, b_k == k         -> QuantMode::W8
 *   - b_scale given, b_k * 2 == k     -> QuantMode::W4 (b_tensor is treated as INT4X2)
 *   - otherwise                       -> QuantMode::noQuant
 * where k is the last dimension of a_tensor and b_k the last dimension of b_tensor.
 *
 * @param a_tensor   2-D (num_token, k) tensor on MLU whose last dimension is contiguous.
 * @param b_tensor   Tensor on MLU. (experts_num, n, k), or 4-D (b0, n, b2, k) with
 *                   b0 * b2 == experts_num for noQuant/W8; (experts_num, n, k / 2) for W4.
 *                   A 1-D tensor is accepted only together with quant_flag.
 * @param m_list     Contiguous 1-D int32 tensor on MLU; its length is the number of groups.
 * @param gather_idx Optional contiguous 1-D int32 tensor; when given its length is the number of
 *                   output rows (total_m), otherwise total_m == num_token.
 * @param c_tensor   Optional contiguous 2-D tensor forwarded to the descriptor.
 * @param alpha      Optional contiguous 1-D tensor of shape (experts_num).
 * @param beta       Optional contiguous 1-D tensor of shape (experts_num).
 * @param a_scale    Optional contiguous 1-D tensor of shape (total_m); requires data_type.
 * @param b_scale    Optional contiguous tensor: (experts_num, n), or (groups, experts_num, n)
 *                   for grouped quantization.
 * @param bias       Optional tensor forwarded to the descriptor.
 * @param data_type  Output dtype ("float", "half" or "bfloat16"); only read when a_scale is
 *                   given. Otherwise the output uses a_tensor's dtype.
 * @param quant_flag Optional per-group flags; selects W4W8 mode.
 * @param b_offset   Optional tensor whose data pointer is forwarded as int64_t*.
 * @param max_m      Forwarded to the GroupGemmDesc constructor.
 * @return Newly allocated (total_m, n) tensor holding the GEMM result.
 */
at::Tensor group_gemm(const at::Tensor &a_tensor,
                      const at::Tensor &b_tensor,
                      const at::Tensor &m_list,
                      const c10::optional<at::Tensor> &gather_idx,
                      const c10::optional<at::Tensor> &c_tensor,
                      const c10::optional<at::Tensor> &alpha,
                      const c10::optional<at::Tensor> &beta,
                      const c10::optional<at::Tensor> &a_scale,
                      const c10::optional<at::Tensor> &b_scale,
                      const c10::optional<at::Tensor> &bias,
                      const c10::optional<std::string> &data_type,
                      const c10::optional<at::List<int64_t>> &quant_flag,
                      const c10::optional<at::Tensor> &b_offset,
                      const int64_t max_m) {
  const bool has_dtype = data_type.has_value();
  const bool has_c_tensor = c_tensor.has_value();
  const bool has_gather_idx = gather_idx.has_value();
  const bool has_a_scale = a_scale.has_value();
  const bool has_b_scale = b_scale.has_value();
  const bool has_alpha = alpha.has_value();
  const bool has_beta = beta.has_value();
  const bool has_bias = bias.has_value();
  const bool has_quant_flag = quant_flag.has_value();
  QuantMode quant_mode = QuantMode::noQuant;
  bool quant_grouped = false;

  // check stride
  TORCH_CHECK(a_tensor.stride(-1) == 1, "a_tensor last dim must be contiguous");
  CHECK_TENSOR_CONTIGUOUS(m_list);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(c_tensor);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_idx);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(a_scale);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(b_scale);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(alpha);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(beta);

  // check dim
  TORCH_CHECK(m_list.dim() == 1, "the shape of m_list should be 1D");
  TORCH_CHECK(a_tensor.dim() == 2, "the shape of a_tensor should be 2D");
  TORCH_CHECK(
      b_tensor.dim() == 3 || b_tensor.dim() == 4 || (b_tensor.dim() == 1 && has_quant_flag),
      "the shape of b_tensor should be 3D or 4D (1D is accepted only when quant_flag is given)");
  if (has_c_tensor) {
    TORCH_CHECK(c_tensor.value().dim() == 2, "the shape of c_tensor should be 2D");
  }
  if (has_gather_idx) {
    TORCH_CHECK(gather_idx.value().dim() == 1, "the shape of gather_idx should be 1D");
  }
  if (has_a_scale) {
    TORCH_CHECK(a_scale.value().dim() == 1, "the shape of a_scale should be 1D");
  }
  if (has_b_scale) {
    TORCH_CHECK(b_scale.value().dim() == 2 || b_scale.value().dim() == 3,
                "the shape of b_scale should be 2D or 3D");
  }
  if (has_alpha) {
    TORCH_CHECK(alpha.value().dim() == 1, "the shape of alpha should be 1D");
  }
  if (has_beta) {
    TORCH_CHECK(beta.value().dim() == 1, "the shape of beta should be 1D");
  }
  if (has_quant_flag) {
    // n is read from b_scale.size(2) below, so b_scale must exist and be 3-D before that read.
    TORCH_CHECK(has_b_scale, "if quant_flag given, b_scale must exist");
    TORCH_CHECK(b_scale.value().dim() == 3,
                "the shape of b_scale should be 3D when quant_flag is given");
  }

  // check shape
  const int64_t experts_num = m_list.size(0);
  const int64_t num_token = a_tensor.size(0);
  const int64_t k = a_tensor.size(-1);
  const int64_t n = has_quant_flag ? b_scale.value().size(2) : b_tensor.size(1);
  const int64_t lda = a_tensor.stride(0);
  const int64_t total_m = has_gather_idx ? gather_idx.value().size(0) : num_token;
  TORCH_CHECK(experts_num > 0, "number of groups must be large than zero");
  CHECK_SHAPE(a_tensor, num_token, k);
  if (has_a_scale) {
    CHECK_SHAPE(a_scale.value(), total_m);
  }
  if (has_quant_flag) {
    quant_grouped = true;
    quant_mode = QuantMode::W4W8;
    CHECK_SHAPE(b_scale.value(), b_scale.value().size(0), experts_num, n);
  } else if (has_b_scale) {
    const int64_t b_k = b_tensor.size(-1);
    quant_grouped = b_scale.value().dim() == 3;
    if (quant_grouped) {
      CHECK_SHAPE(b_scale.value(), b_scale.value().size(0), experts_num, n);
    } else {
      CHECK_SHAPE(b_scale.value(), experts_num, n);
    }
    TORCH_CHECK(k == b_k || k == b_k * 2, "k == b_k || k == b_k * 2.");
    quant_mode = (b_k == k) ? QuantMode::W8 : QuantMode::W4;
  }
  if (quant_grouped) {
    // The group count is used as a divisor when the per-group size is computed below.
    TORCH_CHECK(b_scale.value().size(0) > 0,
                "the first dim of b_scale must be large than zero for grouped quantization");
  }
  if (has_alpha) {
    CHECK_SHAPE(alpha.value(), experts_num);
  }
  if (has_beta) {
    CHECK_SHAPE(beta.value(), experts_num);
  }
  auto b_shape = b_tensor.sizes();
  cnnlDataType_t a_dtype = getCnnlDataType(a_tensor.scalar_type());
  cnnlDataType_t b_dtype = getCnnlDataType(b_tensor.scalar_type());
  torch::Dtype dtype = a_tensor.scalar_type();
  if (quant_mode == QuantMode::W8 || quant_mode == QuantMode::noQuant) {
    if (b_tensor.dim() == 3) {
      CHECK_SHAPE(b_tensor, experts_num, n, k);
    } else {
      TORCH_CHECK(b_shape[0] * b_shape[2] == experts_num,
                  "b_shape[0] * b_shape[2] == experts_num");
      CHECK_SHAPE(b_tensor, b_shape[0], n, b_shape[2], k);
    }
  } else if (quant_mode == QuantMode::W4) {
    CHECK_SHAPE(b_tensor, experts_num, n, k / 2);
    b_dtype = CNNL_DTYPE_INT4X2;
  }

  // check dtype
  TORCH_CHECK(m_list.dtype() == torch::kInt32, "data type of m_list should be int32");
  if (has_gather_idx) {
    TORCH_CHECK(gather_idx.value().dtype() == torch::kInt32,
                "data type of gather_idx should be int32");
  }
  if (has_a_scale) {
    TORCH_CHECK(has_dtype, "data_type must be given when a_scale is given");
    TORCH_CHECK(data_type.value() == "float" || data_type.value() == "half" ||
                    data_type.value() == "bfloat16",
                "data_type must be 'float', 'half' or 'bfloat16'.");
    dtype = data_type.value() == "float"  ? torch::kFloat32
            : data_type.value() == "half" ? torch::kFloat16
                                          : torch::kBFloat16;
  }

  const torch_mlu::mlu::MLUGuard device_guard(a_tensor.device());
  at::Tensor d_tensor = at::empty({total_m, n}, a_tensor.options().dtype(dtype));
  cnnlDataType_t d_dtype = getCnnlDataType(d_tensor.scalar_type());

  // check device
  TORCH_CHECK(isMlu(m_list), "m_list must on mlu");
  TORCH_CHECK(isMlu(a_tensor), "a_tensor must on mlu");
  TORCH_CHECK(isMlu(b_tensor), "b_tensor must on mlu");
  TORCH_CHECK(isMlu(d_tensor), "d_tensor must on mlu");
  if (has_c_tensor) {
    TORCH_CHECK(isMlu(c_tensor.value()), "c_tensor must on mlu");
  }
  if (has_gather_idx) {
    TORCH_CHECK(isMlu(gather_idx.value()), "gather_idx must on mlu");
  }
  if (has_a_scale) {
    TORCH_CHECK(isMlu(a_scale.value()), "a_scale must on mlu");
  }
  if (has_b_scale) {
    TORCH_CHECK(isMlu(b_scale.value()), "b_scale must on mlu");
  }
  if (has_alpha) {
    TORCH_CHECK(isMlu(alpha.value()), "alpha must on mlu");
  }
  if (has_beta) {
    TORCH_CHECK(isMlu(beta.value()), "beta must on mlu");
  }

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  GroupGemmDesc group_gemm_desc(experts_num, max_m, n, k, d_dtype, has_gather_idx, quant_mode);
  group_gemm_desc.setInputOutputTensor(a_dtype, b_dtype, d_dtype, CNNL_DTYPE_INT32,
                                       getAtTensorPtr(a_tensor), getAtTensorPtr(b_tensor),
                                       getAtTensorPtr(c_tensor), getAtTensorPtr(d_tensor),
                                       getAtTensorPtr(gather_idx), k, lda, total_m, has_c_tensor,
                                       (int64_t *)getAtTensorPtr(b_offset));
  // Narrowed copy of quant_flag; it must stay alive until GroupGemm has been called because the
  // descriptor only receives its data pointer.
  std::vector<int> flag_vec;
  if (quant_mode != QuantMode::noQuant || has_bias) {
    if (has_quant_flag) {
      auto vec = quant_flag.value().vec();
      flag_vec.resize(vec.size());
      std::copy(vec.begin(), vec.end(), flag_vec.begin());
    }
    group_gemm_desc.setPerRowColScaleBiasAct(
        has_a_scale ? getAtTensorPtr(a_scale) : nullptr,
        has_b_scale ? getAtTensorPtr(b_scale) : nullptr, flag_vec.data(),
        has_bias ? getAtTensorPtr(bias) : nullptr, d_dtype, quant_grouped ? k : 0,
        quant_grouped ? k / b_scale.value().size(0) : 0, 0);
  }
  size_t group_gemm_wsize =
      tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc, experts_num);
  auto group_gemm_workspace =
      at::empty({static_cast<int64_t>(group_gemm_wsize)}, a_tensor.options().dtype(at::kByte));
  std::vector<int> ldb_array;
  // NOTE(review): disabled per-expert ldb computation for W4W8, kept for reference.
  // if (quant_mode == QuantMode::W4W8) {
  //   ldb_array.resize(experts_num);
  //   auto q_group = b_scale.value().size(0);
  //   for (auto i = 0; i < experts_num; ++i) {
  //     int sum = 0;
  //     for (auto j = 0; j < q_group; j++) {
  //       sum += flag_vec[i * q_group + j];
  //     }
  //     ldb_array[i] = (sum / 4) * (k / q_group / 2);
  //   }
  // }
  if (quant_mode != QuantMode::W4W8) {
    // ldb_array stays empty in W4W8 mode; otherwise every expert shares the same ldb.
    const int ldb = static_cast<int>(b_tensor.dim() == 3 ? k : b_shape[2] * k);
    ldb_array.assign(experts_num, ldb);
  }
  GroupGemmTheory obj(total_m, experts_num, k, n, has_c_tensor, a_dtype, getCnnlDataType(dtype));
  cnpxPush(obj);
  ops::GroupGemm(handle, group_gemm_desc, getAtTensorPtr(m_list), getAtTensorPtr(alpha),
                 getAtTensorPtr(beta), getAtTensorPtr(group_gemm_workspace), group_gemm_wsize,
                 experts_num, k, n, lda, ldb_array);
  cnpxPop();
  return d_tensor;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,129 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Runs a matrix multiplication of `a` and `b` through cnnlMatMulEx, with an optional bias, an
// optional residual `c` and the activation named by `act_mode`.
//
// Parameters:
//   a, b             - operands; both must share a dtype and have a unit-stride last dimension.
//   d                - optional output tensor; when absent an [m, n] tensor is allocated.
//   bias             - optional contiguous 1-D tensor, handed to the op descriptor as a
//                      [1, len] view.
//   c                - optional contiguous residual of shape [m, n]; its presence enables beta.
//   dtype            - optional output dtype name ("float", "half" or "bfloat16"), consulted
//                      only when `d` is absent; the dtype of `a` is used when it is not given.
//   act_mode, fast_act, approximate - forwarded unchanged to tmo::op_desc::MatMulDesc.
//   alpha, beta      - scalars handed to cnnlMatMulEx after narrowing to float.
//   a_scale, b_scale - quantization scales, only read for int8 inputs, where they must be > 0.
//   trans_a, trans_b - select which axis of `a` / `b` supplies m, k and n.
//
// Returns `d` when it was supplied, otherwise the newly allocated output tensor.
at::Tensor matmul(const at::Tensor &a,
                  const at::Tensor &b,
                  const c10::optional<at::Tensor> &d,
                  const c10::optional<at::Tensor> &bias,
                  const c10::optional<at::Tensor> &c,
                  const c10::optional<std::string> &dtype,
                  const std::string &act_mode,
                  double alpha,
                  double beta,
                  bool fast_act,
                  bool approximate,
                  double a_scale,
                  double b_scale,
                  bool trans_a,
                  bool trans_b) {
  bool has_bias = bias.has_value();
  bool has_res = c.has_value();
  bool use_beta = false;
  // Problem sizes, taking the transpose flags into account.
  int m = trans_a ? a.size(1) : a.size(0);
  int n = trans_b ? b.size(0) : b.size(1);
  int k = trans_a ? a.size(0) : a.size(1);

  // check device and dtype
  checkTensorSameAttr<TensorAttr::DEVICE>(a, b, c, bias);

  // check contiguous
  TORCH_CHECK(a.stride(-1) == 1, "a last dim must be contiguous.");
  TORCH_CHECK(b.stride(-1) == 1, "b last dim must be contiguous.");
  if (has_bias) {
    TORCH_CHECK(bias.value().is_contiguous(), "bias must be contiguous.");
  }
  if (has_res) {
    use_beta = true;
    TORCH_CHECK(c.value().is_contiguous(), "c must be contiguous.");
    CHECK_SHAPE(c.value(), m, n);
  }
  // The bias is described to the op as a 2-D tensor with a leading unit dimension.
  at::Tensor bias_view;
  if (has_bias) {
    TORCH_CHECK(bias.value().dim() == 1, "bias must be 1-D tensor.");
    bias_view = bias.value().unsqueeze(0);
  }

  // get cnnl data type and init output
  auto a_dtype = getCnnlDataType(a.scalar_type());
  auto b_dtype = getCnnlDataType(b.scalar_type());
  TORCH_CHECK(a_dtype == b_dtype, "a, b must be same dtype.");
  if (a_dtype == CNNL_DTYPE_INT8) {
    // An int8 input dtype cannot double as the output dtype, so one must be supplied.
    TORCH_CHECK(d.has_value() || dtype.has_value(),
                "d and d_dtype cant be none at the same time when input dtype is int8");
  }
  at::Tensor output;
  if (d.has_value()) {
    output = d.value();
  } else {
    torch::Dtype out_dtype = a.scalar_type();
    if (dtype.has_value()) {
      TORCH_CHECK(
          dtype.value() == "float" || dtype.value() == "half" || dtype.value() == "bfloat16",
          "data_type must be 'float', 'half' or 'bfloat16'.");
      out_dtype = dtype.value() == "float"  ? torch::kFloat32
                  : dtype.value() == "half" ? torch::kFloat16
                                            : torch::kBFloat16;
    }
    output = at::empty({m, n}, a.options().dtype(out_dtype));
  }
  auto cd_dtype = getCnnlDataType(output.scalar_type());
  // For bfloat16 inputs the on-chip dtype of c/d is set to float.
  if (a_dtype == CNNL_DTYPE_BFLOAT16) cd_dtype = CNNL_DTYPE_FLOAT;

  // create tensor desc: [0]=a, [1]=b, [2]=bias, [3]=c (possibly undefined), [4]=output
  auto descs = createTensorDescs({a, b, bias_view, c.value_or(at::Tensor()), output});
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), a_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), b_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[4].get(), cd_dtype));
  if (has_res) {
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[3].get(), cd_dtype));
  }
  if (a_dtype == CNNL_DTYPE_INT8) {
    // The derivation below divides by each scale and takes log2 of the quotient. A zero,
    // negative or NaN scale would push inf/NaN into the float -> int conversion, which is
    // undefined, so such scales are rejected up front (NaN fails the `> 0` comparison).
    TORCH_CHECK(a_scale > 0 && b_scale > 0,
                "a_scale and b_scale must be greater than 0 when input dtype is int8");
    float int_max = 127.0;
    int quant_bit = 8;
    float max_a = int_max / a_scale;
    float max_b = int_max / b_scale;
    int pos_a = std::floor(std::log2(max_a) - (quant_bit - 2));
    int pos_b = std::floor(std::log2(max_b) - (quant_bit - 2));
    float new_a_scale = std::pow(2.0f, pos_a) * a_scale;
    float new_b_scale = std::pow(2.0f, pos_b) * b_scale;
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[0].get(), pos_a, new_a_scale));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[1].get(), pos_b, new_b_scale));
    TORCH_CHECK(cd_dtype != CNNL_DTYPE_BFLOAT16,
                "output dtype cannot be bfloat16 when a/b is fixed-point");
  }

  // get && set Op desc
  int lda = a.stride(0);
  int ldb = b.stride(0);
  int ldc = output.stride(0);
  // Passed to MatMulDesc before being used to size the workspace below.
  size_t workspace_size = 0;
  auto matmul_desc =
      tmo::op_desc::MatMulDesc(descs[2].get(), getAtTensorPtr(bias_view), workspace_size, lda, ldb,
                               ldc, act_mode, use_beta, trans_a, trans_b, fast_act, approximate);
  const torch_mlu::mlu::MLUGuard device_guard(a.device());
  auto handle = torch_mlu::getCurrentHandle();
  auto workspace = at::empty({static_cast<int64_t>(workspace_size)}, a.options().dtype(at::kByte));

  // run forward
  float alpha_f = alpha;
  float beta_f = beta;
  MatmulTheory obj(m, k, n, has_res, a_dtype, cd_dtype);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlMatMulEx(handle, matmul_desc, &alpha_f, descs[0].get(), getAtTensorPtr(a),
                                descs[1].get(), getAtTensorPtr(b), &beta_f, descs[3].get(),
                                getAtTensorPtr(c), descs[4].get(), getAtTensorPtr(output), nullptr,
                                getAtTensorPtr(workspace), workspace_size));
  cnpxPop();
  return output;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,52 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <numeric>
#include "kernels/moe/cast_gating.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates `input` and `weight`, then launches tmo::invokeCastGating on them.
//
//   input  - contiguous float16/bfloat16 tensor shaped [..., hidden_size].
//   weight - contiguous float32 2-D tensor whose last dimension equals hidden_size;
//            its first dimension is taken as the expert count.
//
// Returns a tensor created with `weight`'s options and shaped like `input` with the last
// dimension replaced by the expert count.
at::Tensor moe_cast_gating(const at::Tensor &input, const at::Tensor &weight) {
  // Device, layout, rank and dtype requirements.
  checkTensorSameAttr<TensorAttr::DEVICE>(input, weight);
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
  TORCH_CHECK(weight.dim() == 2, "weight.dim() == 2");
  TORCH_CHECK(input.size(-1) == weight.size(-1), "input.size(-1) == weight.size(-1)");
  TORCH_CHECK(input.scalar_type() == torch::kFloat16 || input.scalar_type() == torch::kBFloat16,
              "input type need be torch::kFloat16 or torch::kBFloat16");
  TORCH_CHECK(weight.scalar_type() == torch::kFloat32, "weight type need be torch::kFloat32");

  // Size limits enforced before the kernel launch.
  auto hidden_size = input.size(-1);
  auto expert_num = weight.size(0);
  TORCH_CHECK(hidden_size > 0 && hidden_size <= 16384, "hidden_size > 0 && hidden_size <= 16384");
  TORCH_CHECK(expert_num > 0 && expert_num <= 128, "expert_num > 0 && expert_num <= 128");

  // Collapse every leading dimension of [..., hidden_size] into a single row count.
  const auto dims = input.sizes();
  int input_row = 1;
  for (auto dim_it = dims.begin(); dim_it != dims.end() - 1; ++dim_it) {
    input_row = input_row * static_cast<int>(*dim_it);
  }

  // Output keeps the leading dimensions and swaps hidden_size for expert_num.
  std::vector<int64_t> output_shape(dims.begin(), dims.end() - 1);
  output_shape.push_back(expert_num);

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  auto output = at::empty(output_shape, weight.options());
  // Fixed 16 MiB byte workspace handed to the kernel.
  const int64_t workspace_size = 16 * 1024 * 1024;
  auto workspace = at::empty({workspace_size}, weight.options().dtype(at::kByte));
  auto queue = torch_mlu::getCurMLUStream();
  auto input_dtype = getCnnlDataType(input.scalar_type());

  TMO_KERNEL_CHECK_FATAL(tmo::invokeCastGating(
      queue, input.data_ptr(), weight.data_ptr(), output.data_ptr(), input_row, expert_num,
      hidden_size, input_dtype, workspace.data_ptr(), workspace_size));
  return output;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,124 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/moe/combine_result.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches invokeMoeCombineResultKernel.
//
// Parameters:
//   input             - 2-D float32/float16/bfloat16 tensor [num_tokens, hidden_size].
//   reduce_weight     - 2-D float32 tensor [num_token, topk] with num_token * topk == num_tokens.
//   gather_ids        - 1-D int32 tensor with num_tokens entries.
//   residual          - optional [num_token, hidden_size] tensor with the dtype of `input`.
//   cusum_token_count - optional 1-D int32 tensor with num_expert + 1 entries.
//   start_expert_id   - first expert id handled here; must be >= 0.
//   expert_size       - number of experts handled here.
//   bias              - currently rejected; must be empty.
//
// Returns a [num_token, hidden_size] tensor created with the options of `input`.
at::Tensor moe_combine_result(const at::Tensor &input,
                              const at::Tensor &reduce_weight,
                              const at::Tensor &gather_ids,
                              const c10::optional<at::Tensor> &residual,
                              const c10::optional<at::Tensor> &cusum_token_count,
                              int64_t start_expert_id,
                              int64_t expert_size,
                              const c10::optional<at::Tensor> &bias) {
  // check device and dtype
  checkTensorSameAttr<TensorAttr::DEVICE>(input, reduce_weight, residual, bias, cusum_token_count,
                                          gather_ids);
  // check contiguous
  CHECK_TENSOR_CONTIGUOUS(input)
  CHECK_TENSOR_CONTIGUOUS(reduce_weight)
  CHECK_TENSOR_CONTIGUOUS(gather_ids)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(residual)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(bias)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(cusum_token_count)

  // do not support bias
  TORCH_CHECK(!bias.has_value(), "bias is not supported.");

  // check input
  TORCH_CHECK(input.dim() == 2, "input should be with shape [num_tokens, hidden_size].");
  TORCH_CHECK(input.dtype() == torch::kFloat32 || input.dtype() == torch::kFloat16 ||
                  input.dtype() == torch::kBFloat16,
              "input dtype should be float32/float16/bfloat16.");
  int num_tokens = input.size(0);
  int hidden_size = input.size(1);

  // check reduce weight
  TORCH_CHECK(reduce_weight.dim() == 2, "reduce_weight should be with shape [num_token, topk].");
  TORCH_CHECK(reduce_weight.dtype() == torch::kFloat32,
              "reduce_weight should be with dtype float32.");
  int topk = reduce_weight.size(1);
  TORCH_CHECK(reduce_weight.numel() == num_tokens,
              "reduce_weight.numel() should equal to input.size(0).");
  // A [n, 0] reduce_weight passes the numel check against an empty input; reject it here so
  // the division below can never be a division by zero.
  TORCH_CHECK(topk > 0, "reduce_weight.size(1) should be larger than 0.");
  int num_token = num_tokens / topk;

  // check gather_ids
  TORCH_CHECK(gather_ids.dim() == 1, "gather_ids should be with shape [num_tokens].");
  TORCH_CHECK(gather_ids.dtype() == torch::kInt32, "gather_ids should be with dtype int32.");
  TORCH_CHECK(gather_ids.size(0) == num_tokens,
              "gather_ids.numel() should equal to input.size(0).");

  if (residual.has_value()) {
    TORCH_CHECK(residual.value().dim() == 2,
                "residual should be with shape [num_token, hidden_size].");
    TORCH_CHECK(residual.value().dtype() == input.dtype(),
                "residual should have same dtype with input.");
    // The residual has one row per original token, i.e. input.size(0) / topk rows.
    TORCH_CHECK(residual.value().size(0) == num_token,
                "residual.size(0) should equal to input.size(0) / topk.");
    TORCH_CHECK(residual.value().size(1) == hidden_size,
                "residual.size(1) should equal to input.size(1).");
  }

  // check bias and cusum_token_count
  bool has_bias = bias.has_value();
  bool has_cusum_token_count = cusum_token_count.has_value();
  TORCH_CHECK(!has_bias || has_cusum_token_count,
              "if bias is not None, cusum_token_count should not be None.");
  // -1 means "not known yet"; it is filled from bias, cusum_token_count or expert_size below.
  int num_expert = -1;
  if (has_bias) {
    TORCH_CHECK(bias.value().dim() == 2, "bias should be with shape [num_expert, hidden_size].");
    TORCH_CHECK(bias.value().dtype() == input.dtype(), "bias should have same dtype with input.");
    TORCH_CHECK(bias.value().size(1) == hidden_size, "bias.size(1) should equal to input.size(1).");
    num_expert = bias.value().size(0);
  }
  if (has_cusum_token_count) {
    TORCH_CHECK(cusum_token_count.value().dim() == 1,
                "cusum_token_count should be with shape [num_expert + 1].");
    TORCH_CHECK(cusum_token_count.value().dtype() == torch::kInt32,
                "cusum_token_count should be with dtype int32.");
    if (num_expert > 0) {
      // num_expert came from bias: the prefix-sum tensor must agree with it in length.
      TORCH_CHECK(cusum_token_count.value().size(0) == num_expert + 1,
                  "cusum_token_count should be with shape [num_expert + 1].");
    } else {
      num_expert = cusum_token_count.value().size(0) - 1;
    }
  }

  // check expert
  TORCH_CHECK(start_expert_id >= 0, "start_expert_id shoule be larger or equal to 0.");
  TORCH_CHECK(num_expert == -1 || num_expert >= (start_expert_id + expert_size),
              "num_expert shape shoule be larger or equal to start_expert_id + expert_size.");
  if (num_expert == -1) {
    num_expert = expert_size;
  }

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  auto queue = torch_mlu::getCurMLUStream();
  auto output_shape = std::vector<int64_t>({num_token, hidden_size});
  auto output = at::empty(output_shape, input.options());
  auto dtype = getCnnlDataType(input.scalar_type());
  invokeMoeCombineResultKernel(queue, output.data_ptr(), input.data_ptr(), getAtTensorPtr(bias),
                               getAtTensorPtr(residual), (float *)reduce_weight.data_ptr(),
                               (int *)getAtTensorPtr(cusum_token_count),
                               (int *)getAtTensorPtr(gather_ids), num_token, topk, num_expert,
                               hidden_size, start_expert_id, expert_size, dtype);
  return output;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,60 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/moe/expand_input.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches tmo::invokeMoeExpandInputKernel.
//
// Parameters:
//   input             - 2-D MLU tensor [token_num, hidden_size].
//   gather_idx        - 1-D int32 MLU tensor; its length (expand_token_num) must be a multiple
//                       of token_num, the quotient being passed to the kernel as topk.
//   cusum_token_count - optional int32 MLU tensor with expert_num + 1 entries. When present,
//                       [start_expert_id, start_expert_id + expert_size) must lie inside
//                       [0, expert_num].
//   start_expert_id, expert_size - forwarded to the kernel.
//
// Returns an [expand_token_num, hidden_size] tensor created with the options of `input`.
// An empty `input` paired with an empty `gather_idx` yields that empty tensor without a launch.
at::Tensor moe_expand_input(const torch::Tensor &input,
                            const torch::Tensor &gather_idx,
                            const c10::optional<torch::Tensor> &cusum_token_count,
                            int64_t start_expert_id,
                            int64_t expert_size) {
  TORCH_CHECK(input.dim() == 2, "input dim must be equal to 2.");
  TORCH_CHECK(gather_idx.dim() == 1, "gather_idx dim must be equal to 1.");

  // check device and dtype
  TORCH_CHECK(isMlu(input), "input tensor must on mlu.");
  TORCH_CHECK(isMlu(gather_idx), "gather_idx must on mlu.");
  TORCH_CHECK(gather_idx.dtype() == torch::kInt32, "data type of gather_idx must be int32.");

  int64_t expand_token_num = gather_idx.size(0);
  int64_t token_num = input.size(0);
  int64_t hidden_size = input.size(1);

  // `x % 0` and `x / 0` are undefined, so the empty-input case is validated separately:
  // with no source tokens there must be nothing to gather either.
  int64_t topk = 0;
  if (token_num == 0) {
    TORCH_CHECK(expand_token_num == 0, "expand_token_num % token_num == 0.");
  } else {
    TORCH_CHECK(expand_token_num % token_num == 0, "expand_token_num % token_num == 0.");
    topk = expand_token_num / token_num;
  }

  // expert_num stays 0 when no cusum_token_count is supplied.
  int64_t expert_num = 0;
  if (cusum_token_count.has_value()) {
    TORCH_CHECK(isMlu(cusum_token_count.value()), "cusum_token_count must on mlu.");
    TORCH_CHECK(cusum_token_count.value().dtype() == torch::kInt32,
                "data type of cusum_token_count must be int32.");
    expert_num = cusum_token_count.value().size(0) - 1;
    TORCH_CHECK(start_expert_id >= 0 && start_expert_id < expert_num,
                "start_expert_id >=0 && start_expert_id < expert_num.");
    TORCH_CHECK((start_expert_id + expert_size) <= expert_num,
                "start_expert_id + expert_size <= expert_num.");
  }

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  std::vector<int64_t> output_shape = {expand_token_num, hidden_size};
  auto output = torch::empty(output_shape, input.options());
  if (token_num == 0) {
    // Nothing to expand: the output already has zero rows.
    return output;
  }

  auto dtype = getCnnlDataType(input.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();
  tmo::invokeMoeExpandInputKernel(queue, getAtTensorPtr(output), getAtTensorPtr(input),
                                  (int *)getAtTensorPtr(gather_idx),
                                  (int *)getAtTensorPtr(cusum_token_count), token_num, hidden_size,
                                  topk, dtype, expert_num, start_expert_id, expert_size);
  return output;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,44 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/moe/gen_idx.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Allocates the index/count buffers and launches tmo::invokeMoeGenIdxKernel.
//
//   expert_id  - 2-D int32 MLU tensor [token_num, topk].
//   expert_num - used to size token_count ([expert_num]) and cusum_token_count
//                ([expert_num + 1]).
//
// Returns {expand_idx, combine_idx, token_count, cusum_token_count}; the two index tensors
// each hold token_num * topk entries. Every buffer is created with the options of `expert_id`.
std::vector<at::Tensor> moe_gen_idx(const torch::Tensor &expert_id, int64_t expert_num) {
  TORCH_CHECK(expert_id.dim() == 2, "expert_id dim must be equal to 2.");
  const int64_t token_num = expert_id.size(0);
  const int64_t topk = expert_id.size(1);

  // check device and dtype
  TORCH_CHECK(isMlu(expert_id), "expert_id must on mlu.");
  TORCH_CHECK(expert_id.dtype() == torch::kInt32, "data type of expert_id must be int32.");

  const torch_mlu::mlu::MLUGuard device_guard(expert_id.device());
  const auto buffer_options = expert_id.options();
  // One slot per (token, selected expert) pair.
  const int64_t flat_len = token_num * topk;

  auto expand_idx = at::empty({flat_len}, buffer_options);
  auto combine_idx = at::empty({flat_len}, buffer_options);
  auto token_count = at::empty({expert_num}, buffer_options);
  auto cusum_token_count = at::empty({expert_num + 1}, buffer_options);
  // Scratch space handed to the kernel; not returned to the caller.
  auto gen_idx_workspace = at::empty({expert_num + 1 + flat_len}, buffer_options);

  auto queue = torch_mlu::getCurMLUStream();
  tmo::invokeMoeGenIdxKernel(
      queue, (int *)getAtTensorPtr(expand_idx), (int *)getAtTensorPtr(combine_idx),
      (int *)getAtTensorPtr(token_count), (int *)getAtTensorPtr(cusum_token_count),
      getAtTensorPtr(gen_idx_workspace), getAtTensorPtr(expert_id), token_num, expert_num, topk);

  return std::vector<at::Tensor>{expand_idx, combine_idx, token_count, cusum_token_count};
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,83 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/moe/softmax_topk.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches invokeMoeSoftmaxTopkKernel.
//
// Parameters:
//   input            - contiguous tensor with at least 2 dims; the last dim is num_expert.
//   topk             - number of experts kept per token; 0 < topk <= num_expert.
//   num_expert_group - when > 1, grouped selection is requested: num_expert must be divisible
//                      by it and no mask may be given.
//   topk_group       - number of groups kept; only validated when num_expert_group > 1.
//   normalize        - when true, `normed_by` selects the mode passed to the kernel
//                      ("topk_logit" -> 1, "softmax_logit" -> 2); otherwise mode 0 is passed.
//   mask             - optional contiguous tensor with the rank and dtype of `input`, matching
//                      its last two dims, and whose leading dims multiply to 1.
//   normed_by        - see `normalize`.
//
// Returns {reduce_weight (float32), expert_id (int32)}, both shaped like `input` with the last
// dimension replaced by topk.
std::vector<at::Tensor> moe_softmax_topk(const at::Tensor &input,  // (num_token, num_expert)
                                         int64_t topk,
                                         int64_t num_expert_group,
                                         int64_t topk_group,
                                         bool normalize,
                                         const c10::optional<at::Tensor> &mask,
                                         const std::string &normed_by) {
  CHECK_TENSOR_CONTIGUOUS(input)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(mask)
  int normalize_mode = 0;
  if (normalize) {
    TORCH_CHECK(normed_by == "topk_logit" || normed_by == "softmax_logit",
                "normed_by must be 'topk_logit' or 'softmax_logit'");
    if (normed_by == "topk_logit") {
      normalize_mode = 1;
    } else if (normed_by == "softmax_logit") {
      normalize_mode = 2;
    }
  }

  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2");
  int num_expert = input.size(-1);
  // This check implies num_expert >= 1, so it has to run before num_expert is used as a
  // divisor; otherwise an input with an empty last dimension divides by zero.
  TORCH_CHECK(topk > 0 && topk <= num_expert, "topk > 0 && topk <= num_expert");
  int num_token = (int)(input.numel() / num_expert);
  int num_mask = 1;
  auto input_shape = input.sizes();

  bool has_mask = mask.has_value();
  if (has_mask) {
    TORCH_CHECK(mask.value().dim() == input.dim(), "the dim of mask should be the same as input");
    TORCH_CHECK(mask.value().size(-1) == num_expert,
                "the last dim of mask should be the same as the last dim of input");
    TORCH_CHECK(mask.value().size(-2) == input.size(-2),
                "the penultimate dim of mask should be the same as the penultimate dim of input");
    // Product of every dimension except the last two. Multiplying the leading sizes directly
    // (instead of dividing numel by the trailing sizes) stays well defined when the
    // penultimate dimension is 0.
    int64_t high_dim = 1;
    for (int64_t axis = 0; axis + 2 < mask.value().dim(); ++axis) {
      high_dim *= mask.value().size(axis);
    }
    TORCH_CHECK(high_dim == 1, "the product of all but the lower two dimensions of mask is 1");
    num_mask = mask.value().numel() / num_expert;
    TORCH_CHECK(mask.value().dtype() == input.dtype(),
                "the dtype of mask should be the same as input");
  }

  if (num_expert_group > 1) {
    TORCH_CHECK(has_mask == false, "if num_expert_group > 1, mask should be None");
    TORCH_CHECK(num_expert % num_expert_group == 0, "num_expert % num_expert_group == 0");
    TORCH_CHECK(topk_group > 0 && topk_group <= num_expert_group,
                "topk_group > 0 && topk_group <= num_expert_group");
    TORCH_CHECK(topk <= (num_expert / num_expert_group) * topk_group,
                "topk <= (num_expert / num_expert_group) * topk_group");
  }

  // Outputs share the input shape except for the last dimension.
  std::vector<int64_t> out_shape(input_shape.begin(), input_shape.end());
  out_shape.back() = topk;

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  auto tensor_options = input.options();
  auto reduce_weight = at::empty(out_shape, tensor_options.dtype(torch::kFloat));
  auto expert_id = at::empty(out_shape, tensor_options.dtype(torch::kInt32));
  auto queue = torch_mlu::getCurMLUStream();
  auto input_dtype = getCnnlDataType(input.scalar_type());
  TMO_KERNEL_CHECK_FATAL(invokeMoeSoftmaxTopkKernel(
      queue, (float *)reduce_weight.data_ptr(), (int *)expert_id.data_ptr(), input.data_ptr(),
      getAtTensorPtr(mask), num_token, num_expert, num_mask, topk, num_expert_group, topk_group,
      input_dtype, normalize_mode));
  return {reduce_weight, expert_id};
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,177 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/offline_quant_to_linear_cache.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches invokeOfflineQuantToLinearCache, which receives the
// key (and optionally value) tensors together with the int8 caches and their float32 scales.
//
// Layout expectations enforced below:
//   packed == true : key/value are [total_seqlen, num_heads, head_size] and context_lengths
//                    has batch_size + 1 entries.
//   packed == false: key/value are [batch_size, seq, num_heads, head_size] and context_lengths
//                    has batch_size entries.
//   quant_mode == 0: scales are [num_heads, head_size]; otherwise [num_heads, cache_mem_len].
// value, value_cache and value_cache_scale must be given together or not at all.
// context_seq_offset, cache_bs_id and cache_seqlen_offset are optional int32 tensors with
// batch_size entries on the same device as key.
void offline_quant_to_linear_cache(
    const at::Tensor &key,  // [total_seqlen, num_heads, head_size]
                            // or [bs, max_context_len, num_heads, head_size]
    const c10::optional<at::Tensor> &value,
    at::Tensor &key_cache,  // [max_bs, num_heads, cache_memory_len, head_size]
    const c10::optional<at::Tensor> &value_cache,
    const at::Tensor &key_cache_scale,  // [num_heads, cache_memory_len] or [num_heads, head_size]
    const c10::optional<at::Tensor> &value_cache_scale,
    const at::Tensor &context_lengths,
    const int64_t max_context_len,
    const int64_t quant_mode,  // 0:per_channel, others:per_head
    const bool packed,
    const c10::optional<at::Tensor> &context_seq_offset,
    const c10::optional<at::Tensor> &cache_bs_id,
    const c10::optional<at::Tensor> &cache_seqlen_offset) {
  /****************************************check key***************************************/
  // check dtype
  TORCH_CHECK(key.scalar_type() == torch::kFloat32 || key.scalar_type() == torch::kFloat16 ||
                  key.scalar_type() == torch::kBFloat16,
              "key type need be torch::kFloat32, torch::kFloat16 or torch::kBFloat16");
  TORCH_CHECK(key_cache.scalar_type() == torch::kInt8, "key_cache type need be torch::kInt8");
  TORCH_CHECK(key_cache_scale.scalar_type() == torch::kFloat32,
              "key_cache_scale type need be torch::kFloat32");

  // check shape
  // In packed mode context_lengths carries one extra entry beyond the batch size.
  const int batch_size = packed ? context_lengths.size(0) - 1 : context_lengths.size(0);
  const int max_bs = key_cache.size(0);
  const int num_heads = key_cache.size(1);
  const int cache_mem_len = key_cache.size(2);
  const int head_size = key_cache.size(3);
  // Non-packed keys are [bs, seq, num_heads, head_size], so the sequence axis is dim 1
  // (dim -2 is num_heads). The value is only consumed by the packed-mode shape checks.
  const int total_seqlen = packed ? key.size(0) : batch_size * key.size(1);
  TORCH_CHECK(key_cache.size(0) >= batch_size,
              "first dim of key_cache must be great than or equal to batch_size");
  if (packed) {
    CHECK_SHAPE(key, total_seqlen, num_heads, head_size);
    CHECK_SHAPE(context_lengths, batch_size + 1);
  } else {
    CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size);
    CHECK_SHAPE(context_lengths, batch_size);
  }
  if (quant_mode == 0) {
    CHECK_SHAPE(key_cache_scale, num_heads, head_size);
  } else {
    CHECK_SHAPE(key_cache_scale, num_heads, cache_mem_len);
  }

  // check contiguous
  TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous.");
  TORCH_CHECK(key_cache_scale.is_contiguous(), "key_cache_scale tensor must be contiguous.");

  // check device
  const int64_t device_id = key.get_device();
  TORCH_CHECK(key_cache.get_device() == device_id,
              "key_cache tensor device index is not same, original index: ", device_id,
              "now index is: ", key_cache.get_device());
  TORCH_CHECK(key_cache_scale.get_device() == device_id,
              "key_cache_scale tensor device index is not same, original index: ", device_id,
              "now index is: ", key_cache_scale.get_device());

  /***************************************check value***************************************/
  // The three value-side tensors are all-or-nothing.
  if (value.has_value() || value_cache.has_value() || value_cache_scale.has_value()) {
    TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_scale.has_value(),
                "value_cache, value_cache_scale, value must all have value.");
  }
  if (value_cache.has_value()) {
    TORCH_CHECK(value_cache.value().scalar_type() == torch::kInt8,
                "value_cache type need be torch::kInt8");
    TORCH_CHECK(value_cache_scale.value().scalar_type() == torch::kFloat32,
                "value_cache_scale type need be torch::kFloat32");
    TORCH_CHECK(value.value().scalar_type() == key.scalar_type(),
                "value type need be same with key");
    if (packed) {
      CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size);
    } else {
      CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size);
    }
    if (quant_mode == 0) {
      CHECK_SHAPE(value_cache_scale.value(), num_heads, head_size);
    } else {
      CHECK_SHAPE(value_cache_scale.value(), num_heads, cache_mem_len);
    }
    CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size);
    // Only the key strides are passed to the kernel, so value must be laid out identically.
    for (int i = 0; i < key.dim(); i++) {
      TORCH_CHECK(value.value().stride(i) == key.stride(i),
                  "key and value must have same stride along axi ", i);
    }
    TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous.");
    TORCH_CHECK(value_cache_scale.value().is_contiguous(),
                "value_cache_scale tensor must be contiguous.");
    TORCH_CHECK(value_cache.value().get_device() == device_id,
                "value_cache tensor device index is not same, original index: ", device_id,
                "now index is: ", value_cache.value().get_device());
    TORCH_CHECK(value_cache_scale.value().get_device() == device_id,
                "value_cache_scale tensor device index is not same, original index: ", device_id,
                "now index is: ", value_cache_scale.value().get_device());
  }
  TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32,
              "context_lengths type need be torch::kInt32");

  /*********************************check optional tensor***********************************/
  // Each optional tensor is passed to the kernel as a raw pointer, nullptr when absent.
  const void *context_seq_offset_ptr = nullptr;
  if (context_seq_offset.has_value()) {
    CHECK_SHAPE(context_seq_offset.value(), batch_size);
    TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32,
                "context_seq_offset type need be torch::kInt32");
    TORCH_CHECK(context_seq_offset.value().get_device() == device_id,
                "context_seq_offset tensor device index is not same, original index: ", device_id,
                "now index is: ", context_seq_offset.value().get_device());
    context_seq_offset_ptr = context_seq_offset.value().data_ptr();
  }
  const void *cache_bs_id_ptr = nullptr;
  if (cache_bs_id.has_value()) {
    CHECK_SHAPE(cache_bs_id.value(), batch_size);
    TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32,
                "cache_bs_id type need be torch::kInt32");
    TORCH_CHECK(cache_bs_id.value().get_device() == device_id,
                "cache_bs_id tensor device index is not same, original index: ", device_id,
                "now index is: ", cache_bs_id.value().get_device());
    cache_bs_id_ptr = cache_bs_id.value().data_ptr();
  }
  const void *cache_seqlen_offset_ptr = nullptr;
  if (cache_seqlen_offset.has_value()) {
    CHECK_SHAPE(cache_seqlen_offset.value(), batch_size);
    TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32,
                "cache_seqlen_offset type need be torch::kInt32");
    TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id,
                "cache_seqlen_offset tensor device index is not same, original index: ", device_id,
                "now index is: ", cache_seqlen_offset.value().get_device());
    cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr();
  }

  // Strides of the context tensor; the axis order depends on the packed layout, and a packed
  // tensor has no batch axis, hence a batch stride of 0.
  const size_t context_bs_stride = packed ? 0 : key.stride(0);
  const size_t context_head_stride = packed ? key.stride(1) : key.stride(2);
  const size_t context_seq_stride = packed ? key.stride(0) : key.stride(1);
  const size_t cache_bs_stride = key_cache.stride(0);
  const size_t cache_head_stride = key_cache.stride(1);
  const size_t cache_seq_stride = key_cache.stride(2);
  const size_t cache_scale_head_stride = key_cache_scale.stride(0);

  const torch_mlu::mlu::MLUGuard device_guard(key.device());
  auto data_dtype = getCnnlDataType(key.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();

  // run forward
  TMO_KERNEL_CHECK_FATAL(invokeOfflineQuantToLinearCache(
      queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache),
      getAtTensorPtr(key_cache_scale), getAtTensorPtr(value_cache_scale), cache_bs_id_ptr,
      cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr,
      getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size,
      (int)max_context_len, cache_mem_len, context_bs_stride, context_head_stride,
      context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride,
      cache_scale_head_stride, packed, quant_mode));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,89 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/offline_quant_to_paged_cache.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches tmo::invokeOfflineQuantToPagedCache.
//
// k / v are [num_tokens, num_heads, head_size] with their last two dimensions densely packed,
// the scales are [num_heads, head_size], the caches are
// [num_blocks, num_heads, block_size, head_size] and slot_mapping is a contiguous int32
// tensor with num_tokens entries. v, v_cache and v_cache_scale must be supplied together or
// not at all.
void offline_quant_to_paged_cache(
    const torch::Tensor &k,                             // [num_tokens, num_heads, head_size]
    const c10::optional<torch::Tensor> &v,              // [num_tokens, num_heads, head_size]
    const torch::Tensor &k_cache_scale,                 // [num_heads, head_size]
    const c10::optional<torch::Tensor> &v_cache_scale,  // [num_heads, head_size]
    const torch::Tensor &slot_mapping,
    torch::Tensor &k_cache,  // [num_blocks, num_heads, block_size, head_size]
    const c10::optional<torch::Tensor>
        &v_cache) {  // [num_blocks, num_heads, block_size, head_size]
  const bool has_v = v.has_value();

  // Step 1: dtype of slot_mapping and device agreement of the inputs.
  TORCH_CHECK(slot_mapping.dtype() == torch::kInt32, "slot_mapping type need be torch::kInt32");
  checkTensorSameAttr<TensorAttr::DEVICE>(k, v, k_cache_scale, v_cache_scale, slot_mapping);

  // Step 2: ranks first, then exact shapes derived from k and k_cache.
  TORCH_CHECK(k.dim() == 3, "dim of k must be 3");
  TORCH_CHECK(k_cache.dim() == 4, "dim of k_cache must be 4");
  TORCH_CHECK(k_cache_scale.dim() == 2, "dim of k_cache_scale must be 2");
  TORCH_CHECK(
      has_v == v_cache.has_value() && has_v == v_cache_scale.has_value(),
      "v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().");
  if (has_v) {
    TORCH_CHECK(v.value().dim() == 3, "dim of v must be 3");
    TORCH_CHECK(v_cache.value().dim() == 4, "dim of v_cache must be 4");
    TORCH_CHECK(v_cache_scale.value().dim() == 2, "dim of v_cache_scale must be 2");
  }
  TORCH_CHECK(slot_mapping.dim() == 1, "dim of slot_mapping must be 1");

  const int num_tokens = k.size(0);
  const int num_heads = k.size(1);
  const int head_size = k.size(2);
  const int block_size = k_cache.size(2);
  const int num_blocks = k_cache.size(0);
  CHECK_SHAPE(k, num_tokens, num_heads, head_size);
  CHECK_SHAPE(k_cache_scale, num_heads, head_size);
  CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_size);
  if (has_v) {
    CHECK_SHAPE(v.value(), num_tokens, num_heads, head_size);
    CHECK_SHAPE(v_cache_scale.value(), num_heads, head_size);
    CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_size);
  }
  CHECK_SHAPE(slot_mapping, num_tokens);

  // Step 3: stride requirements. Only the token axis of k / v may be strided.
  TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous.");
  TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous.");
  TORCH_CHECK(k.stride(-2) == head_size, "k second dim must be contiguous.");
  if (has_v) {
    TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous.");
    TORCH_CHECK(v.value().stride(-2) == head_size, "v second dim must be contiguous.");
  }

  // The byte extent of the cache has to fit in 32 bits.
  const int64_t cache_bytes =
      static_cast<int64_t>(num_blocks) * num_heads * block_size * head_size *
      k_cache.element_size();
  TORCH_CHECK(cache_bytes <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G.");

  const torch_mlu::mlu::MLUGuard device_guard(k.device());
  auto queue = torch_mlu::getCurMLUStream();
  const cnnlDataType_t dtype = getCnnlDataType(k.scalar_type());

  // Token-axis strides; the value stride falls back to 1 when no value tensor is given.
  const int32_t key_token_stride = static_cast<int32_t>(k.stride(0));
  const int32_t value_token_stride = has_v ? static_cast<int32_t>(v.value().stride(0)) : 1;

  TMO_KERNEL_CHECK_FATAL(tmo::invokeOfflineQuantToPagedCache(
      queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache),
      getAtTensorPtr(v_cache), getAtTensorPtr(k_cache_scale), getAtTensorPtr(v_cache_scale),
      getAtTensorPtr(slot_mapping), key_token_stride, value_token_stride, num_tokens, num_heads,
      num_blocks, block_size, head_size));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,423 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_TORCH_API_OP_THEORY_H_
#define CSRC_TORCH_API_OP_THEORY_H_
#include "utils.h"
namespace tmo {
namespace torch_api {
// Base interface of the per-operator theoretical cost models. A concrete theory
// reports an estimated amount of computation and of memory traffic for a single
// operator call, plus the dtype used for the computation and the operator name.
struct OpTheory {
  // Polymorphic base: a virtual destructor keeps destruction through an
  // OpTheory pointer/reference well defined.
  virtual ~OpTheory() = default;
  // Estimated amount of computation; the counting unit is chosen by each derived theory.
  virtual size_t getTheoryCalc(void) const = 0;
  // Estimated memory traffic (element counts scaled by the dtype widths).
  virtual size_t getTheoryIO(void) const = 0;
  // Data type the operator computes in.
  virtual cnnlDataType_t getCalcDtype(void) const = 0;
  // Operator name owned by the derived object.
  virtual const std::string &getOpName(void) const = 0;
};
// Theoretical cost model of a plain matmul. The IO estimate counts an m*k and
// an n*k operand at calc_type width and an m*n result at data_type width (the
// result term is doubled when a residual is present).
struct MatmulTheory : public OpTheory {
  int m_, k_, n_;
  bool has_res_;
  cnnlDataType_t calc_type_, data_type_;
  std::string name_;
  MatmulTheory(int m,
               int k,
               int n,
               bool has_res,
               cnnlDataType_t calc_type,
               cnnlDataType_t data_type)
      : OpTheory(),
        m_(m),
        k_(k),
        n_(n),
        has_res_(has_res),
        calc_type_(calc_type),
        data_type_(data_type),
        name_("matmul") {}
  // lt: m * n * k, ct: add_bias + act
  size_t getTheoryCalc(void) const override {
    return static_cast<size_t>(m_) * n_ * k_;
  }
  size_t getTheoryIO(void) const override {
    // Zero-initialised so a failing size query cannot leave the widths indeterminate.
    size_t calc_dw = 0, data_dw = 0;
    cnnlGetSizeOfDataType(calc_type_, &calc_dw);
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    // Widen before multiplying: m * k and n * k can exceed the int range.
    const size_t m = static_cast<size_t>(m_);
    const size_t k = static_cast<size_t>(k_);
    const size_t n = static_cast<size_t>(n_);
    return calc_dw * (m * k + n * k) + data_dw * m * n * (1 + has_res_);
  }
  cnnlDataType_t getCalcDtype(void) const override { return calc_type_; }
  const std::string &getOpName(void) const override { return name_; }
};
// Theoretical cost model of a quantized matmul.
//  - "weight_only": the m*k input and the m*n output use data_type, only the
//    n*k weight is stored with quant_bit_size bits per element.
//  - any other algorithm: both operands use quant_bit_size bits per element and
//    the m*n output uses data_type; the computation is reported as INT8.
struct QuantMatmulTheory : public OpTheory {
  int m_, k_, n_, quant_bit_size_;
  bool has_res_;
  cnnlDataType_t data_type_;
  std::string quant_algo_, name_;
  QuantMatmulTheory(int m,
                    int k,
                    int n,
                    int quant_bit_size,
                    bool has_res,
                    cnnlDataType_t data_type,
                    std::string quant_algo)
      : OpTheory(),
        m_(m),
        k_(k),
        n_(n),
        quant_bit_size_(quant_bit_size),
        has_res_(has_res),
        data_type_(data_type),
        quant_algo_(quant_algo),
        name_("quant_matmul") {}
  size_t getTheoryCalc(void) const override {
    return static_cast<size_t>(m_) * n_ * k_;
  }
  size_t getTheoryIO(void) const override {
    size_t data_dw = 0;
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    // Widen before multiplying so the element counts cannot overflow int.
    const size_t m = static_cast<size_t>(m_);
    const size_t k = static_cast<size_t>(k_);
    const size_t n = static_cast<size_t>(n_);
    const size_t quant_bits = static_cast<size_t>(quant_bit_size_);
    // Output elements; counted twice when a residual is present.
    const size_t out_elems = m * n * (1 + has_res_);
    if (quant_algo_ == "weight_only") {
      // The output is stored in data_type just like the input, so data_dw must
      // scale both terms (previously only the m*k term was scaled).
      return data_dw * (m * k + out_elems) + n * k * quant_bits / 8;
    }
    return data_dw * out_elems + (m * k + n * k) * quant_bits / 8;
  }
  cnnlDataType_t getCalcDtype(void) const override {
    if (quant_algo_ == "weight_only")
      return data_type_;
    else
      return CNNL_DTYPE_INT8;
  }
  const std::string &getOpName(void) const override { return name_; }
};
// Theoretical cost model of a grouped GEMM over `experts_num` weight matrices.
// The IO estimate assumes that all `experts_num` weights are loaded.
struct GroupGemmTheory : public OpTheory {
  int total_m_, experts_num_, k_, n_, quant_bit_size_;
  bool has_res_;
  cnnlDataType_t calc_type_, data_type_;
  std::string name_;
  GroupGemmTheory(int total_m,
                  int experts_num,  // The max number of processed experts.
                  int k,
                  int n,
                  bool has_res,
                  cnnlDataType_t calc_type,
                  cnnlDataType_t data_type)
      : OpTheory(),
        total_m_(total_m),
        experts_num_(experts_num),
        k_(k),
        n_(n),
        // Not supplied by the constructor and not read by this struct; given a
        // definite value so copies of the object never carry an indeterminate int.
        quant_bit_size_(0),
        has_res_(has_res),
        calc_type_(calc_type),
        data_type_(data_type),
        name_("group_gemm") {}  // smooth_quant or float-point
  size_t getTheoryCalc(void) const override {
    return static_cast<size_t>(total_m_) * n_ * k_;
  }
  size_t getTheoryIO(void) const override {
    // Zero-initialised so a failing size query cannot leave the widths indeterminate.
    size_t calc_dw = 0, data_dw = 0;
    cnnlGetSizeOfDataType(calc_type_, &calc_dw);
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    // assume load max_num weights
    const size_t input_elems = static_cast<size_t>(total_m_) * k_;
    const size_t weight_elems = static_cast<size_t>(experts_num_) * n_ * k_;
    return calc_dw * (input_elems + weight_elems) + data_dw * total_m_ * n_ * (1 + has_res_);
  }
  cnnlDataType_t getCalcDtype(void) const override { return calc_type_; }
  const std::string &getOpName(void) const override { return name_; }
};
// Theoretical cost model of flash attention. Per-sequence lengths are derived
// as total_q / batch and total_k / batch; with causal masking the effective kv
// length is reduced by half of the query length.
struct FlashAttnTheory : public OpTheory {
  int batch_, total_q_, total_k_, head_num_q_, head_num_kv_, head_size_qk_, head_size_v_;
  bool is_causal_;
  cnnlDataType_t data_type_;
  std::string name_;
  FlashAttnTheory(int batch,
                  int total_q,
                  int total_k,
                  int head_num_q,
                  int head_num_kv,
                  int head_size_qk,
                  int head_size_v,
                  bool is_causal,
                  cnnlDataType_t data_type)
      : OpTheory(),
        batch_(batch),
        total_q_(total_q),
        total_k_(total_k),
        head_num_q_(head_num_q),
        head_num_kv_(head_num_kv),
        head_size_qk_(head_size_qk),
        head_size_v_(head_size_v),
        is_causal_(is_causal),
        data_type_(data_type),
        name_("flash_attention") {}
  size_t getTheoryCalc(void) const override {
    // q * k + qk * v
    // An empty batch has no work; this also avoids dividing by zero below.
    if (batch_ <= 0) return 0;
    size_t seq_q = total_q_ / batch_;
    size_t seq_kv = total_k_ / batch_;
    if (is_causal_) {
      // Clamp at zero: converting a negative double to size_t is undefined.
      const double causal_kv = seq_kv - 0.5 * seq_q;
      seq_kv = causal_kv > 0 ? static_cast<size_t>(causal_kv) : 0;
    }
    return static_cast<size_t>(batch_) * head_num_q_ * seq_q * seq_kv *
           (head_size_qk_ + head_size_v_);
  }
  size_t getTheoryIO(void) const override {
    size_t data_dw = 0;
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    // Widen before multiplying so heads * tokens cannot overflow int.
    const size_t q_rows = static_cast<size_t>(head_num_q_) * total_q_;
    const size_t kv_rows = static_cast<size_t>(head_num_kv_) * total_k_;
    return data_dw * (q_rows + kv_rows) * (head_size_qk_ + head_size_v_);
  }
  cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
  const std::string &getOpName(void) const override { return name_; }
};
// Theoretical cost model of single-query attention over a cached KV store.
// kv_cache_quant_bit_size == -1 is treated as "cache stored in data_type with no
// scale tensors"; any other value is the number of bits per cached element, and
// float scale traffic is added according to cache_quant_layout.
struct SingleQueryCachedKvAttnTheory : public OpTheory {
  int batch_, seq_q_, head_num_q_, head_num_kv_, num_block_, block_size_, head_size_qk_,
      head_size_v_, kv_cache_quant_bit_size_;
  cnnlQuantizeLayout_t cache_quant_layout_;
  cnnlDataType_t data_type_;
  std::string name_;
  SingleQueryCachedKvAttnTheory(int batch,
                                int seq_q,
                                int head_num_q,
                                int head_num_kv,
                                int num_block,
                                int block_size,
                                int head_size_qk,
                                int head_size_v,
                                int kv_cache_quant_bit_size,
                                cnnlQuantizeLayout_t cache_quant_layout,
                                cnnlDataType_t data_type)
      : OpTheory(),
        batch_(batch),
        seq_q_(seq_q),
        head_num_q_(head_num_q),
        head_num_kv_(head_num_kv),
        num_block_(num_block),
        block_size_(block_size),
        head_size_qk_(head_size_qk),
        head_size_v_(head_size_v),
        kv_cache_quant_bit_size_(kv_cache_quant_bit_size),
        cache_quant_layout_(cache_quant_layout),
        data_type_(data_type),
        name_("single_query_cached_kv_attn") {}
  size_t getTheoryCalc(void) const override {
    // q * k + qk * v
    return static_cast<size_t>(seq_q_) * head_num_q_ * num_block_ * block_size_ *
           (head_size_qk_ + head_size_v_);
  }
  size_t getTheoryIO(void) const override {
    size_t data_dw = 0, dw_float = 0, size_scale = 0;
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float);
    size_t cache_bits = 0;  // bits per cached element
    if (kv_cache_quant_bit_size_ == -1) {
      cache_bits = data_dw * 8;
    } else {
      cache_bits = static_cast<size_t>(kv_cache_quant_bit_size_);
      if (cache_quant_layout_ == CNNL_QUANTIZE_PER_CHANNEL) {
        size_scale = dw_float * head_num_kv_ * (head_size_qk_ + head_size_v_);
      } else {
        size_scale = dw_float * num_block_ * head_num_kv_ * block_size_ * 2;
      }
    }
    const size_t query_elems =
        static_cast<size_t>(batch_) * seq_q_ * head_num_q_ * (head_size_qk_ + head_size_v_);
    const size_t cache_elems = static_cast<size_t>(head_num_kv_) * num_block_ * block_size_ *
                               (head_size_qk_ + head_size_v_);
    // Multiply by the bit width before dividing by 8: dividing first truncates
    // sub-byte widths (e.g. 4 bit) to zero and drops the cache traffic entirely.
    return size_scale + data_dw * query_elems + cache_bits * cache_elems / 8;
  }
  cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
  const std::string &getOpName(void) const override { return name_; }
};
// Theoretical cost model of a fused normalization ("layernorm" or "rmsnorm")
// with optional residual add, bias add, quantized output and an extra store of
// the pre-normalization tensor.
struct FusedNormTheory : public OpTheory {
  int num_token_, hidden_size_;
  bool has_res_, has_bias_, quant_out_, dynamic_quant_, store_output_before_norm_;
  cnnlDataType_t data_type_;
  std::string norm_mode_, name_;
  FusedNormTheory(int num_token,
                  int hidden_size,
                  bool has_res,
                  bool has_bias,
                  bool quant_out,
                  bool dynamic_quant,
                  bool store_output_before_norm,
                  cnnlDataType_t data_type,
                  std::string norm_mode)
      : OpTheory(),
        num_token_(num_token),
        hidden_size_(hidden_size),
        has_res_(has_res),
        has_bias_(has_bias),
        quant_out_(quant_out),
        dynamic_quant_(dynamic_quant),
        store_output_before_norm_(store_output_before_norm),
        data_type_(data_type),
        norm_mode_(norm_mode),
        name_("fused_norm") {}
  size_t getTheoryCalc(void) const override {
    size_t size_ct = 0;  // vector operations per element
    if (norm_mode_ == "layernorm") {
      size_ct = 7;  // avg + sub + square + sum + cycle_mul + mul gamma + add beta,
    } else if (norm_mode_ == "rmsnorm") {
      size_ct = 4;  // square + avg + cycle_mul + mul gamma
    }
    size_ct += (has_res_ + has_bias_);
    if (dynamic_quant_)
      size_ct += 4;
    else if (quant_out_)
      size_ct += 1;
    return size_ct * num_token_ * hidden_size_;
  }
  size_t getTheoryIO(void) const override {
    // Zero-initialised so a failing size query cannot leave the widths indeterminate.
    size_t data_dw = 0, dw_int8 = 0, dw_float = 0;
    cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float);
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    cnnlGetSizeOfDataType(CNNL_DTYPE_INT8, &dw_int8);
    const size_t tokens = static_cast<size_t>(num_token_);
    const size_t hidden = static_cast<size_t>(hidden_size_);
    size_t size_io = data_dw * tokens * hidden;  // input
    if (has_res_) size_io += data_dw * tokens * hidden;
    if (has_bias_) size_io += data_dw * hidden;
    size_io += hidden * data_dw * (1 + (norm_mode_ == "layernorm"));  // gamma&beta
    if (quant_out_) {
      size_io += dw_int8 * tokens * hidden;  // output
      // The conditional needs its own parentheses: `a + b ? c : d` parses as
      // `(a + b) ? c : d`, which discarded the hidden_size term.
      size_io += dw_float * (hidden + (dynamic_quant_ ? tokens : 0));  // scale
    } else {
      size_io += data_dw * tokens * hidden;
    }
    if (store_output_before_norm_) size_io += data_dw * tokens * hidden;
    return size_io;
  }
  cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
  const std::string &getOpName(void) const override { return name_; }
};
// Theoretical cost model of online smooth quantization of an [m, ci] input,
// either with a fixed scale or with a dynamically computed per-row scale.
struct SmoothQuantTheory : public OpTheory {
  int m_, ci_;
  cnnlDataType_t data_type_;
  bool dynamic_quant_;
  std::string name_;
  SmoothQuantTheory(int m, int ci, cnnlDataType_t data_type, bool dynamic_quant)
      : OpTheory(),
        m_(m),
        ci_(ci),
        data_type_(data_type),
        dynamic_quant_(dynamic_quant),
        name_("smooth_quant_online") {}
  size_t getTheoryCalc(void) const override {
    if (!dynamic_quant_) {
      return 3 * (size_t)m_ * ci_;  // convert + mul scale + convert
    }
    size_t calc_scale = m_;  // 128 / maxvalue
    size_t size_calc = (2 /*max + min*/ + 2 /*mul scale*/ + 2 /*convert*/) * (size_t)m_ * ci_;
    return calc_scale + size_calc;
  }
  size_t getTheoryIO(void) const override {
    // Zero-initialised so a failing size query cannot leave the widths indeterminate.
    size_t data_dw = 0, dw_int8 = 0, dw_float = 0;
    cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float);
    cnnlGetSizeOfDataType(CNNL_DTYPE_INT8, &dw_int8);
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    const size_t rows = static_cast<size_t>(m_);
    const size_t cols = static_cast<size_t>(ci_);
    // The conditional needs its own parentheses: `a + b ? c : d` parses as
    // `(a + b) ? c : d`, which discarded the ci term.
    const size_t scale_size = dw_float * (cols + (dynamic_quant_ ? rows : 0));
    // rows * cols is evaluated in size_t so large inputs cannot overflow int.
    return scale_size + (data_dw + dw_int8) * (rows * cols);
  }
  cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
  const std::string &getOpName(void) const override { return name_; }
};
// Theoretical cost model of the grouped "add bias + activation" operator with
// an optional gate: the input carries inner_size * (1 + has_gate) channels per
// token and the output carries inner_size channels.
struct GroupAddBiasActiveTheory : public OpTheory {
  int expert_size_, num_expand_token_, ci_, inner_size_;
  bool has_gate_, is_addbias_;
  cnnlDataType_t data_type_;
  std::string act_mode_, name_;
  GroupAddBiasActiveTheory(int expert_size,
                           int num_expand_token,
                           int inner_size,
                           bool has_gate,
                           bool is_addbias,
                           cnnlDataType_t data_type,
                           std::string act_mode)
      : OpTheory(),
        expert_size_(expert_size),
        num_expand_token_(num_expand_token),
        // No separate channel count is passed in. ci_ used to be left
        // uninitialised although getTheoryIO reads it; it mirrors inner_size so
        // that ci_ * (1 + has_gate_) matches the per-token width used by the
        // add-bias term of getTheoryCalc.
        ci_(inner_size),
        inner_size_(inner_size),
        has_gate_(has_gate),
        is_addbias_(is_addbias),
        data_type_(data_type),
        act_mode_(act_mode),
        name_("group_add_bias_active") {}
  size_t getTheoryCalc(void) const override {
    const size_t elems = static_cast<size_t>(num_expand_token_) * inner_size_;
    size_t size_vec = 0, act_size = 0;
    if (is_addbias_) size_vec += elems * (1 + has_gate_);
    if (act_mode_ == "silu") {  // silu = sigmoid + mul
      act_size += 2 * elems;
    } else {  // gelu = 1 vector_compution
      act_size += elems;
    }
    if (has_gate_) size_vec += elems;
    return size_vec + act_size;
  }
  size_t getTheoryIO(void) const override {
    size_t data_dw = 0;
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    // All products are evaluated in size_t so large shapes cannot overflow int.
    const size_t bias_io =
        is_addbias_ ? static_cast<size_t>(expert_size_) * ci_ * (1 + has_gate_) : 0;
    const size_t token_io = static_cast<size_t>(num_expand_token_) * inner_size_ * (2 + has_gate_);
    return data_dw * (bias_io + token_io);
  }
  cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
  const std::string &getOpName(void) const override { return name_; }
};
// Theoretical cost model of the MoE combine step: each token reduces its `topk`
// expert outputs of `ci` channels, optionally adding a per-expert bias and a
// residual.
struct MoeCombineResultTheory : public OpTheory {
  int num_token_, topk_, ci_, expert_size_;
  bool has_bias_, has_res_;
  cnnlDataType_t data_type_;
  std::string name_;
  MoeCombineResultTheory(int num_token,
                         int topk,
                         int ci,
                         int expert_size,
                         bool has_bias,
                         bool has_res,
                         cnnlDataType_t data_type)
      : OpTheory(),
        num_token_(num_token),
        topk_(topk),
        ci_(ci),
        expert_size_(expert_size),
        has_bias_(has_bias),
        has_res_(has_res),
        data_type_(data_type),
        name_("moe_combine_result") {}
  size_t getTheoryCalc(void) const override {
    size_t size_calc =
        (size_t)((has_bias_ + 3 /*cvt+fma+cvt*/) * topk_ + has_res_) * num_token_ * ci_;
    return size_calc;
  }
  size_t getTheoryIO(void) const override {
    size_t data_dw = 0;
    cnnlGetSizeOfDataType(data_type_, &data_dw);
    // Widen once: products such as num_token * topk * ci exceed the int range
    // for realistic shapes (e.g. 32768 * 8 * 8192).
    const size_t tokens = static_cast<size_t>(num_token_);
    const size_t topk = static_cast<size_t>(topk_);
    const size_t ci = static_cast<size_t>(ci_);
    const size_t bias_io = has_bias_ ? static_cast<size_t>(expert_size_) * ci : 0;
    const size_t res_io = has_res_ ? tokens * ci : 0;
    return data_dw * (bias_io + res_io + tokens * topk * ci /*input*/ + tokens * ci /*output*/ +
                      tokens * topk /*reduce_weight*/);
  }
  cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
  const std::string &getOpName(void) const override { return name_; }
};
} // namespace torch_api
} // namespace tmo
#endif // CSRC_TORCH_API_OP_THEORY_H_

View File

@@ -0,0 +1,25 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/preload.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Launches the preload kernel for `weight` on the current MLU stream of the
// tensor's device. The kernel receives the tensor's data pointer, the tensor's
// total byte count and the caller-supplied `size`.
void preload(const torch::Tensor &weight, const int64_t size) {
  // Switch to the device that owns the weight before fetching the stream.
  const torch_mlu::mlu::MLUGuard mlu_guard(weight.device());
  auto stream = torch_mlu::getCurMLUStream();
  const auto weight_bytes = weight.element_size() * weight.numel();
  invokePreload(stream, weight.data_ptr(), weight_bytes, size);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,115 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Runs a quantized matmul through cnnlLLMQuantMatmul and returns the result tensor.
//
// The result tensor is, in priority order:
//   1. `d` when provided (used as-is; its shape is not validated here),
//   2. a new [a.size(0), b.size(0)] tensor with the dtype named by `data_type`,
//   3. a new [a.size(0), b.size(0)] tensor with the options of `a_tensor`.
// Optional tensors that are absent are passed on as undefined tensors when the
// descriptors are built and through getAtTensorPtr for the data pointers.
// Only quant_algo "weight_only"/"smooth_quant", quant_bit_size 4/8 and
// trans_a == false with trans_b == true are accepted.
// NOTE(review): a_tensor and b_tensor are indexed as 2-D ([m, k] and [n, k]) without
// a rank check - confirm that callers guarantee this.
at::Tensor quant_matmul(const at::Tensor &a_tensor,
const c10::optional<at::Tensor> &a_scale,
const c10::optional<at::Tensor> &a_zero,
const at::Tensor &b_tensor,
const c10::optional<at::Tensor> &b_scale,
const c10::optional<at::Tensor> &b_zero,
const c10::optional<at::Tensor> &bias,
const c10::optional<at::Tensor> &c_tensor,
const c10::optional<at::Tensor> &c_scale,
const c10::optional<at::Tensor> &c_zero,
const c10::optional<at::Tensor> &gemm_output_scale,
const c10::optional<at::Tensor> &gemm_output_zero,
const c10::optional<std::string> &data_type,
const c10::optional<at::Tensor> &d,
const std::string &quant_algo,
const std::string &a_quant_layout,
const std::string &b_quant_layout,
int64_t quant_bit_size,
const std::string &act_mode,
bool use_hp_active,
double act_coef,
double alpha,
double beta,
bool trans_a,
bool trans_b) {
// check device and dtype
// only support weight only or smooth quant
TORCH_CHECK(quant_algo == "weight_only" || quant_algo == "smooth_quant", "illegal quant mode");
TORCH_CHECK(quant_bit_size == 4 || quant_bit_size == 8, "illegal quant bit size");
TORCH_CHECK(trans_a == false && trans_b == true, "illegal trans mode");
// create output tensor
auto tensor_options = a_tensor.options();
std::vector<int64_t> output_shape = {a_tensor.sizes()[0], b_tensor.sizes()[0]};
at::Tensor output_tensor;
if (d.has_value()) {
// caller-provided output buffer takes precedence over data_type
output_tensor = d.value();
} else if (data_type.has_value()) {
auto dtype = data_type.value();
TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16",
"data type must be 'float', 'half' or 'bfloat16'")
auto torch_dtype = str2TorchDtype(dtype);
output_tensor = at::empty(output_shape, tensor_options.dtype(torch_dtype));
} else {
output_tensor = at::empty(output_shape, tensor_options);
}
// create tensor desc
// Descriptor index map (the indices are used positionally below):
//   0: a          1: a_scale   2: a_zero
//   3: b          4: b_scale   5: b_zero
//   6: bias
//   7: c          8: c_scale   9: c_zero
//  10: output    11: gemm_output_scale   12: gemm_output_zero
auto descs = createTensorDescs(
{a_tensor, a_scale.value_or(torch::Tensor()), a_zero.value_or(torch::Tensor()), b_tensor,
b_scale.value_or(torch::Tensor()), b_zero.value_or(torch::Tensor()),
bias.value_or(torch::Tensor()), c_tensor.value_or(torch::Tensor()),
c_scale.value_or(torch::Tensor()), c_zero.value_or(torch::Tensor()), output_tensor,
gemm_output_scale.value_or(torch::Tensor()), gemm_output_zero.value_or(torch::Tensor())});
// get && set Op desc
// the compute dtype follows the dtype of the output tensor chosen above
auto compute_dtype = getCnnlDataType(output_tensor.scalar_type());
auto op_desc = tmo::op_desc::QuantMatmulDesc(quant_algo, a_quant_layout, b_quant_layout,
quant_bit_size, compute_dtype, act_mode,
use_hp_active, act_coef, trans_a, trans_b);
auto handle = torch_mlu::getCurrentHandle();
// get workspace size
// queried from the a, b, c and output descriptors
size_t workspace_size = 0;
CNNL_CHECK_FATAL(cnnlGetLLMQuantMatmulWorkspaceSize(handle, op_desc, descs[0].get(),
descs[3].get(), descs[7].get(),
descs[10].get(), &workspace_size));
// byte workspace allocated with the same options (device) as a_tensor
auto workspace =
at::empty({static_cast<int64_t>(workspace_size)}, a_tensor.options().dtype(at::kByte));
// run forward
// each operand is passed as a {tensor, scale, zero} triple; the output triple
// carries only the tensor itself
const cnnlTensorDescriptor_t a_descs[] = {descs[0].get(), descs[1].get(), descs[2].get()};
const cnnlTensorDescriptor_t b_descs[] = {descs[3].get(), descs[4].get(), descs[5].get()};
const cnnlTensorDescriptor_t c_descs[] = {descs[7].get(), descs[8].get(), descs[9].get()};
const cnnlTensorDescriptor_t d_descs[] = {descs[10].get(), nullptr, nullptr};
const void *a_tensors[] = {getAtTensorPtr(a_tensor), getAtTensorPtr(a_scale),
getAtTensorPtr(a_zero)};
const void *b_tensors[] = {getAtTensorPtr(b_tensor), getAtTensorPtr(b_scale),
getAtTensorPtr(b_zero)};
const void *c_tensors[] = {getAtTensorPtr(c_tensor), getAtTensorPtr(c_scale),
getAtTensorPtr(c_zero)};
void *d_tensors[] = {getAtTensorPtr(output_tensor), nullptr, nullptr};
// alpha and beta are narrowed from double to float before being handed to CNNL
float alpha_arr[1] = {(float)alpha};
float beta_arr[1] = {(float)beta};
bool has_res = c_tensor.has_value();
// theoretical cost description (m = a.size(0), k = a.size(1), n = b.size(0))
// pushed around the kernel call
QuantMatmulTheory obj(a_tensor.size(0), a_tensor.size(1), b_tensor.size(0), quant_bit_size,
has_res, compute_dtype, quant_algo);
cnpxPush(obj);
CNNL_CHECK_FATAL(cnnlLLMQuantMatmul(
handle, op_desc, nullptr, alpha_arr, a_descs, a_tensors, b_descs, b_tensors, nullptr,
beta_arr, c_descs, c_tensors, descs[6].get(), getAtTensorPtr(bias), descs[11].get(),
getAtTensorPtr(gemm_output_scale), descs[12].get(), getAtTensorPtr(gemm_output_zero),
getAtTensorPtr(workspace), workspace_size, d_descs, d_tensors));
cnpxPop();
return output_tensor;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,208 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/quant_to_linear_cache.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the inputs and launches invokeQuantToLinearCache, which quantizes
// `key` (and optionally `value`) into the int8 linear caches and fills the
// float32 cache scales.
//
// Layout summary derived from the checks below:
//  - packed:     key/value are [total_seqlen, num_heads, head_size] and
//                context_lengths has batch_size + 1 entries.
//  - not packed: key/value are [batch_size, seq, num_heads, head_size] and
//                context_lengths has batch_size entries.
//  - key_cache is 4-D with last dim head_size (8 bit) or head_size / 2 (4 bit).
//  - key_cache_scale is 3-D (one scale per cached position, group_size ==
//    head_size) or 4-D (head_size / group_size scales per cached position).
//  - value, value_cache and value_cache_scale must be given together or not at all.
// All TORCH_CHECK / CHECK_SHAPE failures are reported before the kernel is launched.
void quant_to_linear_cache(
    const at::Tensor &key,                   // [total_seqlen, num_heads, head_size]
    const c10::optional<at::Tensor> &value,  // or [bs, max_context_len, num_heads, head_size]
    at::Tensor &key_cache,                   // [max_bs, num_heads, cache_memory_len, head_size]
    const c10::optional<at::Tensor>
        &value_cache,             // [max_bs, num_heads, cache_memory_len, head_size]
    at::Tensor &key_cache_scale,  // [max_bs, num_heads, cache_memory_len]
    const c10::optional<at::Tensor> &value_cache_scale,
    const at::Tensor &context_lengths,
    const int64_t max_context_len,
    bool packed,
    const c10::optional<at::Tensor> &context_seq_offset,
    const c10::optional<at::Tensor> &cache_bs_id,
    const c10::optional<at::Tensor> &cache_seqlen_offset,
    const int64_t quant_bit) {
  // check device and dtype
  TORCH_CHECK(key.scalar_type() == torch::kFloat32 || key.scalar_type() == torch::kFloat16 ||
                  key.scalar_type() == torch::kBFloat16,
              "key type need be torch::kFloat32, torch::kFloat16 or torch::kBFloat16");
  TORCH_CHECK(key_cache.scalar_type() == torch::kInt8, "key_cache type need be torch::kInt8");
  TORCH_CHECK(key_cache_scale.scalar_type() == torch::kFloat32,
              "key_cache_scale type need be torch::kFloat32");
  if (value.has_value() || value_cache.has_value() || value_cache_scale.has_value()) {
    TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_scale.has_value(),
                "value_cache, value_cache_scale, value must all have value.");
  }
  if (value_cache.has_value()) {
    TORCH_CHECK(value_cache.value().scalar_type() == torch::kInt8,
                "value_cache type need be torch::kInt8");
    TORCH_CHECK(value_cache_scale.value().scalar_type() == torch::kFloat32,
                "value_cache_scale type need be torch::kFloat32");
  }
  if (value.has_value()) {
    TORCH_CHECK(value.value().scalar_type() == key.scalar_type(),
                "value type need be same with key");
  }
  TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32,
              "context_lengths type need be torch::kInt32");
  // Ranks are validated before any size()/stride() lookup that depends on them,
  // so a malformed tensor produces a descriptive error instead of an index error.
  TORCH_CHECK(key_cache.dim() == 4, "dim of key_cache must be 4.");
  TORCH_CHECK(key_cache_scale.dim() == 3 || key_cache_scale.dim() == 4,
              "dim of key_cache_scale must be 3 or 4.");
  const int batch_size = packed ? context_lengths.size(0) - 1 : context_lengths.size(0);
  const int max_bs = key_cache.size(0);
  const int num_heads = key_cache.size(1);
  const int cache_mem_len = key_cache.size(2);
  const int head_size = key.size(-1);
  // check quantize param
  TORCH_CHECK(quant_bit == 8 || quant_bit == 4, "quant_bit must be 8 or 4");
  if (key_cache_scale.dim() == 4) {
    TORCH_CHECK(key_cache_scale.size(-1) > 0 && head_size % key_cache_scale.size(-1) == 0,
                "key_cache_scale.size(-1) must be legal");
  }
  // A 3-D scale holds one scale per cached position; a 4-D scale holds
  // head_size / group_size scales per cached position (divisor checked above).
  int group_size = key_cache_scale.dim() == 3 ? head_size : head_size / key_cache_scale.size(-1);
  const int total_seqlen = packed ? key.size(0) : batch_size * key.size(-2);
  const size_t context_bs_stride = packed ? 0 : key.stride(0);
  const size_t context_head_stride = packed ? key.stride(1) : key.stride(2);
  const size_t context_seq_stride = packed ? key.stride(0) : key.stride(1);
  const size_t cache_bs_stride = key_cache.stride(0);
  const size_t cache_head_stride = key_cache.stride(1);
  const size_t key_cache_seq_stride = key_cache.stride(2);
  size_t value_cache_seq_stride = head_size;
  const size_t cache_scale_bs_stride = key_cache_scale.stride(0);
  const size_t cache_scale_head_stride = key_cache_scale.stride(1);
  if (key_cache_scale.dim() == 3) {
    CHECK_SHAPE(key_cache_scale, max_bs, num_heads, cache_mem_len);
  } else {
    CHECK_SHAPE(key_cache_scale, max_bs, num_heads, cache_mem_len, head_size / group_size);
  }
  if (quant_bit == 8) {
    TORCH_CHECK(key_cache.size(-1) == head_size, "last dim of key_cache should be head_size");
  } else {
    // 4-bit: the last dim of key_cache is half of head_size; the message now
    // states the size that is actually required.
    TORCH_CHECK(key_cache.size(-1) == head_size / 2,
                "last dim of key_cache should be head_size / 2");
  }
  if (value_cache.has_value()) {
    if (quant_bit == 8) {
      CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size);
    } else if (quant_bit == 4) {
      // 4-bit value cache is halved along the cache length dim, unlike key_cache
      CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len / 2, head_size);
    }
    if (key_cache_scale.dim() == 3) {
      CHECK_SHAPE(value_cache_scale.value(), max_bs, num_heads, cache_mem_len);
    } else {
      CHECK_SHAPE(value_cache_scale.value(), max_bs, num_heads, cache_mem_len,
                  head_size / group_size);
    }
  }
  const int64_t device_id = key.get_device();
  TORCH_CHECK(key_cache.get_device() == device_id,
              "key_cache tensor device index is not same, original index: ", device_id,
              "now index is: ", key_cache.get_device());
  TORCH_CHECK(key_cache_scale.get_device() == device_id,
              "key_cache_scale tensor device index is not same, original index: ", device_id,
              "now index is: ", key_cache_scale.get_device());
  // Optional per-batch int32 index tensors: each must have batch_size entries
  // and live on the same device as key; a null pointer is passed when absent.
  bool has_context_seq_offset = context_seq_offset.has_value();
  const void *context_seq_offset_ptr = nullptr;
  if (has_context_seq_offset) {
    CHECK_SHAPE(context_seq_offset.value(), batch_size);
    TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32,
                "context_seq_offset type need be torch::kInt32");
    TORCH_CHECK(context_seq_offset.value().get_device() == device_id,
                "context_seq_offset tensor device index is not same, original index: ", device_id,
                "now index is: ", context_seq_offset.value().get_device());
    context_seq_offset_ptr = context_seq_offset.value().data_ptr();
  }
  bool has_cache_bs_id = cache_bs_id.has_value();
  const void *cache_bs_id_ptr = nullptr;
  if (has_cache_bs_id) {
    CHECK_SHAPE(cache_bs_id.value(), batch_size);
    TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32,
                "cache_bs_id type need be torch::kInt32");
    TORCH_CHECK(cache_bs_id.value().get_device() == device_id,
                "cache_bs_id tensor device index is not same, original index: ", device_id,
                "now index is: ", cache_bs_id.value().get_device());
    cache_bs_id_ptr = cache_bs_id.value().data_ptr();
  }
  bool has_cache_seqlen_offset = cache_seqlen_offset.has_value();
  const void *cache_seqlen_offset_ptr = nullptr;
  if (has_cache_seqlen_offset) {
    CHECK_SHAPE(cache_seqlen_offset.value(), batch_size);
    TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32,
                "cache_seqlen_offset type need be torch::kInt32");
    TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id,
                "cache_seqlen_offset tensor device index is not same, original index: ", device_id,
                "now index is: ", cache_seqlen_offset.value().get_device());
    cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr();
  }
  // check contiguous
  TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous.");
  TORCH_CHECK(key_cache_scale.is_contiguous(), "key_cache_scale tensor must be contiguous.");
  if (value_cache.has_value()) {
    TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous.");
    TORCH_CHECK(value_cache.value().get_device() == device_id,
                "value_cache tensor device index is not same, original index: ", device_id,
                "now index is: ", value_cache.value().get_device());
    TORCH_CHECK(value_cache_scale.value().is_contiguous(),
                "value_cache_scale tensor must be contiguous.");
    TORCH_CHECK(value_cache_scale.value().get_device() == device_id,
                "value_cache_scale tensor device index is not same, original index: ", device_id,
                "now index is: ", value_cache_scale.value().get_device());
  }
  // check shape
  if (packed) {
    CHECK_SHAPE(key, total_seqlen, num_heads, head_size);
    if (value.has_value()) {
      CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size);
    }
    CHECK_SHAPE(context_lengths, batch_size + 1);
  } else {
    CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size);
    if (value.has_value()) {
      CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size);
    }
    CHECK_SHAPE(context_lengths, batch_size);
  }
  if (value.has_value()) {
    // only the key strides are passed to the kernel, so value must share them
    for (int i = 0; i < key.dim(); i++) {
      TORCH_CHECK(value.value().stride(i) == key.stride(i),
                  "key and value must have same stride along axi ", i);
    }
  }
  TORCH_CHECK(key_cache.size(0) >= batch_size,
              "first dim of key_cache must be great than or equal to batch_size");
  const torch_mlu::mlu::MLUGuard device_guard(key.device());
  auto data_dtype = getCnnlDataType(key.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();
  // run forward
  invokeQuantToLinearCache(
      queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache),
      getAtTensorPtr(key_cache_scale), getAtTensorPtr(value_cache_scale), cache_bs_id_ptr,
      cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr,
      getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size,
      (int)max_context_len, cache_mem_len, context_bs_stride, context_head_stride,
      context_seq_stride, cache_bs_stride, cache_head_stride, key_cache_seq_stride,
      value_cache_seq_stride, cache_scale_bs_stride, cache_scale_head_stride, packed,
      (int)quant_bit, group_size);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,81 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/quant_to_paged_cache.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the inputs and launches invokeQuantToPagedCache, which writes the
// tokens of `k` (and optionally `v`) into the paged caches at the positions
// given by `slot_mapping` and fills the matching cache scales.
//
// v, v_cache and v_cache_scale must be given together or not at all. k and v
// must be contiguous in their last two dims; their dim-0 strides are forwarded
// to the kernel. The cache must span no more than UINT32_MAX bytes.
void quant_to_paged_cache(
    const torch::Tensor &k,                       // [num_tokens, num_heads, head_dim]
    const c10::optional<torch::Tensor> &v,        // [num_tokens, num_heads, head_dim]
    torch::Tensor &k_cache,                       // [num_blocks, num_heads, block_size, head_dim]
    const c10::optional<torch::Tensor> &v_cache,  // [num_blocks, num_heads, block_size, head_dim]
    torch::Tensor &k_cache_scale,                 // [num_blocks, num_heads, block_size]
    const c10::optional<torch::Tensor> &v_cache_scale,  // [num_blocks, num_heads, block_size]
    const torch::Tensor &slot_mapping) {
  // 1. check device and tensor type
  TORCH_CHECK(slot_mapping.dtype() == torch::kInt32, "slot_mapping type need be torch::kInt32");
  TORCH_CHECK(slot_mapping.get_device() == k.get_device(),
              "Tensor device index is not same, original index: ", k.get_device(),
              "now index is: ", slot_mapping.get_device());
  // 2. check dim and shape
  // Ranks are validated before any size() lookup so that a tensor of the wrong
  // rank yields a descriptive message instead of an index error.
  TORCH_CHECK(k.dim() == 3, "dim of k must be 3");
  TORCH_CHECK(k_cache.dim() == 4, "dim of k_cache must be 4");
  TORCH_CHECK(slot_mapping.dim() == 1, "dim of slot_mapping must be 1");
  const int num_tokens = k.size(0);
  const int num_heads = k.size(1);
  const int head_dim = k.size(2);
  const int block_size = k_cache.size(2);
  // equals head_dim once the k_cache shape check below has passed
  const int head_size = k_cache.size(3);
  const int num_blocks = k_cache.size(0);
  CHECK_SHAPE(k, num_tokens, num_heads, head_dim);
  CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_dim);
  CHECK_SHAPE(k_cache_scale, num_blocks, num_heads, block_size);
  TORCH_CHECK(v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value(),
              "v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().");
  if (v.has_value()) {
    CHECK_SHAPE(v.value(), num_tokens, num_heads, head_dim);
    CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_dim);
    CHECK_SHAPE(v_cache_scale.value(), num_blocks, num_heads, block_size);
  }
  CHECK_SHAPE(slot_mapping, num_tokens);
  // 3. check strides
  TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous.");
  TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous.");
  TORCH_CHECK(k.stride(-2) == head_dim, "k second dim must be contiguous.");
  if (v.has_value()) {
    TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous.");
    TORCH_CHECK(v.value().stride(-2) == head_dim, "v second dim must be contiguous.");
  }
  // total byte size of k_cache, computed in 64 bit and limited to UINT32_MAX
  int64_t kv_cache_range =
      (int64_t)num_blocks * num_heads * block_size * head_dim * k_cache.element_size();
  TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G.");
  const torch_mlu::mlu::MLUGuard device_guard(k.device());
  auto queue = torch_mlu::getCurMLUStream();
  cnnlDataType_t dtype = getCnnlDataType(k.scalar_type());
  int32_t key_stride0 = static_cast<int32_t>(k.stride(0));
  // placeholder stride used when no value tensor is given
  int32_t value_stride0 = 1;
  if (v.has_value()) {
    value_stride0 = static_cast<int32_t>(v.value().stride(0));
  }
  TMO_KERNEL_CHECK_FATAL(tmo::invokeQuantToPagedCache(
      queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache),
      getAtTensorPtr(v_cache), getAtTensorPtr(k_cache_scale), getAtTensorPtr(v_cache_scale),
      getAtTensorPtr(slot_mapping), key_stride0, value_stride0, num_tokens, num_heads, num_blocks,
      block_size, head_size));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,136 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/reshape_linear_cache.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches invokeReshapeLinearCache on the current MLU stream of
// key's device.
//
// Constraints enforced here:
//   - key_cache (and value_cache): [max_bs, num_heads, cache_mem_len, head_size], contiguous
//   - packed == true : key/value are [total_seqlen, num_heads, head_size] and context_lengths
//                      has batch_size + 1 elements
//   - packed == false: key/value are [batch_size, seq, num_heads, head_size] and context_lengths
//                      has batch_size elements
//   - value (if given) has exactly the same strides as key
//   - context_lengths and the optional context_seq_offset / cache_bs_id / cache_seqlen_offset
//     tensors are int32; the optional ones have batch_size elements and live on key's device
//   - max_bs >= batch_size
void reshape_linear_cache(const at::Tensor &key,
                          const c10::optional<at::Tensor> &value,
                          at::Tensor &key_cache,
                          const c10::optional<at::Tensor> &value_cache,
                          const at::Tensor &context_lengths,
                          const int64_t max_context_len,
                          bool packed,
                          const c10::optional<at::Tensor> &context_seq_offset,
                          const c10::optional<at::Tensor> &cache_bs_id,
                          const c10::optional<at::Tensor> &cache_seqlen_offset) {
  // check device and dtype
  checkTensorSameAttr<TensorAttr::ALL>(key, key_cache, value, value_cache);

  // Check ranks before any size()/stride() lookup that depends on them.
  TORCH_CHECK(key_cache.dim() == 4, "dim of key_cache must be 4.");
  TORCH_CHECK(key.dim() == (packed ? 3 : 4), "dim of key must be 3 if packed, otherwise 4.");
  TORCH_CHECK(context_lengths.dim() == 1, "context_lengths must be a 1-D tensor.");
  // Without this an empty packed context_lengths would yield batch_size == -1.
  TORCH_CHECK(!packed || context_lengths.size(0) >= 1,
              "context_lengths must not be empty if packed is true.");
  TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32,
              "context_lengths type need be torch::kInt32");

  const int batch_size = static_cast<int>(packed ? context_lengths.size(0) - 1
                                                 : context_lengths.size(0));
  const int max_bs = static_cast<int>(key_cache.size(0));
  const int num_heads = static_cast<int>(key_cache.size(1));
  const int cache_mem_len = static_cast<int>(key_cache.size(2));
  const int head_size = static_cast<int>(key_cache.size(3));
  // In packed layout there is no batch axis, hence a zero batch stride.
  const int context_bs_stride = packed ? 0 : static_cast<int>(key.stride(0));
  const int context_head_stride = static_cast<int>(packed ? key.stride(1) : key.stride(2));
  const int context_seq_stride = static_cast<int>(packed ? key.stride(0) : key.stride(1));
  const int cache_bs_stride = static_cast<int>(key_cache.stride(0));
  const int cache_head_stride = static_cast<int>(key_cache.stride(1));
  const int cache_seq_stride = static_cast<int>(key_cache.stride(2));

  if (value_cache.has_value()) {
    CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size);
  }

  const int64_t device_id = key.get_device();

  const void *context_seq_offset_ptr = nullptr;
  if (context_seq_offset.has_value()) {
    CHECK_SHAPE(context_seq_offset.value(), batch_size);
    TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32,
                "context_seq_offset type need be torch::kInt32");
    TORCH_CHECK(context_seq_offset.value().get_device() == device_id,
                "context_seq_offset tensor device index is not same, original index: ", device_id,
                ", now index is: ", context_seq_offset.value().get_device());
    context_seq_offset_ptr = context_seq_offset.value().data_ptr();
  }

  const void *cache_bs_id_ptr = nullptr;
  if (cache_bs_id.has_value()) {
    CHECK_SHAPE(cache_bs_id.value(), batch_size);
    TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32,
                "cache_bs_id type need be torch::kInt32");
    TORCH_CHECK(cache_bs_id.value().get_device() == device_id,
                "cache_bs_id tensor device index is not same, original index: ", device_id,
                ", now index is: ", cache_bs_id.value().get_device());
    cache_bs_id_ptr = cache_bs_id.value().data_ptr();
  }

  const void *cache_seqlen_offset_ptr = nullptr;
  if (cache_seqlen_offset.has_value()) {
    CHECK_SHAPE(cache_seqlen_offset.value(), batch_size);
    TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32,
                "cache_seqlen_offset type need be torch::kInt32");
    TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id,
                "cache_seqlen_offset tensor device index is not same, original index: ", device_id,
                ", now index is: ", cache_seqlen_offset.value().get_device());
    cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr();
  }

  // check contiguous
  TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous.");
  if (value_cache.has_value()) {
    TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous.");
  }
  // Only batch/head/seq strides of key are forwarded to the kernel, so the innermost dimension
  // has to be dense. value is covered by the stride-equality check further down.
  TORCH_CHECK(key.stride(-1) == 1, "key last dim must be contiguous.");

  // check shape
  if (packed) {
    const int total_seqlen = static_cast<int>(key.size(0));
    CHECK_SHAPE(key, total_seqlen, num_heads, head_size);
    if (value.has_value()) {
      CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size);
    }
    CHECK_SHAPE(context_lengths, batch_size + 1);
  } else {
    CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size);
    if (value.has_value()) {
      CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size);
    }
    CHECK_SHAPE(context_lengths, batch_size);
  }

  // A single set of context strides is passed to the kernel for both key and value.
  if (value.has_value()) {
    for (int64_t axis = 0; axis < key.dim(); ++axis) {
      TORCH_CHECK(value.value().stride(axis) == key.stride(axis),
                  "key and value must have same stride along axis ", axis);
    }
  }
  TORCH_CHECK(key_cache.size(0) >= batch_size,
              "first dim of key_cache must be greater than or equal to batch_size");

  const torch_mlu::mlu::MLUGuard device_guard(key.device());
  auto data_dtype = getCnnlDataType(key.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();

  // run forward
  TMO_KERNEL_CHECK_FATAL(invokeReshapeLinearCache(
      queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache), cache_bs_id_ptr,
      cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr,
      getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size,
      static_cast<int>(max_context_len), cache_mem_len, context_bs_stride, context_head_stride,
      context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride, packed));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,80 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/reshape_paged_cache.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches tmo::invokeReshapePagedCache on the current MLU stream of
// k's device.
//
// Constraints enforced here:
//   - k, v, k_cache, v_cache pass checkTensorSameAttr<TensorAttr::ALL>
//   - k (and v):             [num_tokens, num_heads, head_dim], last two dims dense
//   - k_cache (and v_cache): [num_blocks, num_heads, block_size, head_dim], contiguous
//   - slot_mapping:          [num_tokens], int32 or int64, contiguous, same device index as k
//   - v and v_cache are either both provided or both omitted
//   - the byte size of one cache must not exceed UINT32_MAX
void reshape_paged_cache(
    const torch::Tensor &k,                       // [num_tokens, num_heads, head_dim]
    const c10::optional<torch::Tensor> &v,        // [num_tokens, num_heads, head_dim]
    torch::Tensor &k_cache,                       // [num_blocks, num_heads, block_size, head_dim]
    const c10::optional<torch::Tensor> &v_cache,  // [num_blocks, num_heads, block_size, head_dim]
    const torch::Tensor &slot_mapping) {
  // 1. check device and tensor type
  checkTensorSameAttr<TensorAttr::ALL>(k, v, k_cache, v_cache);
  // NOTE(review): both int32 and int64 are accepted, but the kernel call below receives only a
  // raw pointer and no dtype for slot_mapping — confirm the kernel really handles both widths.
  TORCH_CHECK(slot_mapping.dtype() == torch::kInt32 || slot_mapping.dtype() == torch::kLong,
              "slot_mapping type need be torch::kInt32 or torch::kLong");
  TORCH_CHECK(slot_mapping.get_device() == k.get_device(),
              "Tensor device index is not same, original index: ", k.get_device(),
              ", now index is: ", slot_mapping.get_device());

  // 2. check shape
  // Ranks are verified first so that a wrongly shaped tensor is reported with a clear message
  // instead of an index error from the size() lookups below.
  TORCH_CHECK(k.dim() == 3, "k must be a 3-D tensor.");
  TORCH_CHECK(k_cache.dim() == 4, "k_cache must be a 4-D tensor.");
  const int num_tokens = static_cast<int>(k.size(0));
  const int num_heads = static_cast<int>(k.size(1));
  const int head_dim = static_cast<int>(k.size(2));
  const int num_blocks = static_cast<int>(k_cache.size(0));
  const int block_size = static_cast<int>(k_cache.size(2));
  CHECK_SHAPE(k, num_tokens, num_heads, head_dim);
  CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_dim);
  TORCH_CHECK(v.has_value() == v_cache.has_value(), "v.has_value() == v_cache.has_value().");
  if (v.has_value()) {
    CHECK_SHAPE(v.value(), num_tokens, num_heads, head_dim);
    CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_dim);
  }
  CHECK_SHAPE(slot_mapping, num_tokens);

  // 3. check strides
  TORCH_CHECK(k_cache.is_contiguous(), "k_cache need be contiguous.");
  TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous.");
  TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous.");
  TORCH_CHECK(k.stride(-2) == head_dim, "k second dim must be contiguous.");
  if (v.has_value()) {
    TORCH_CHECK(v_cache.value().is_contiguous(), "v_cache need be contiguous.");
    TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous.");
    TORCH_CHECK(v.value().stride(-2) == head_dim, "v second dim must be contiguous.");
  }

  // check large tensor: the whole cache must be addressable with 32-bit byte offsets.
  const int64_t kv_cache_range = static_cast<int64_t>(num_blocks) * num_heads * block_size *
                                 head_dim * k_cache.element_size();
  TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G.");

  const torch_mlu::mlu::MLUGuard device_guard(k.device());
  auto queue = torch_mlu::getCurMLUStream();
  const cnnlDataType_t dtype = getCnnlDataType(k.scalar_type());
  const int32_t key_stride0 = static_cast<int32_t>(k.stride(0));
  // The value stride is a placeholder (1) when no value tensor is supplied.
  const int32_t value_stride0 = v.has_value() ? static_cast<int32_t>(v.value().stride(0)) : 1;

  // head_dim equals k_cache.size(3) here; CHECK_SHAPE(k_cache, ...) above enforces it.
  TMO_KERNEL_CHECK_FATAL(tmo::invokeReshapePagedCache(
      queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache),
      getAtTensorPtr(v_cache), getAtTensorPtr(slot_mapping), key_stride0, value_stride0, num_tokens,
      num_heads, num_blocks, block_size, head_dim));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,165 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
// Validates the arguments, builds CNNL tensor descriptors and calls
// cnnlSingleQueryCachedKVAttn_v3 on the current handle of q_ori's device.
//
// Shapes enforced below:
//   q_ori:        [batch, seq_q, head_num, qk_head_size], last two dims dense
//   k_cache:      [num_blocks, k_head_num, block_size, qk_head_size]
//                 (last dim is qk_head_size / 2 when kv_cache_quant_bit_size == 4)
//   v_cache:      [num_blocks, k_head_num, block_size, v_head_size]
//                 (third dim is PAD_UP_DIV(block_size, 2) when kv_cache_quant_bit_size == 4)
//   block_tables: [batch, max_num_block_per_seq], int32 or int64
//   context_lens: int32, contiguous
//   k/v_cache_quant_scale (same rank for both):
//       2-D [k_head_num, head_size]                   -> CNNL_QUANTIZE_PER_CHANNEL
//       3-D [num_blocks, k_head_num, block_size]      -> CNNL_QUANTIZE_PER_TOKEN
//       4-D [num_blocks, k_head_num, block_size, 1]   -> CNNL_QUANTIZE_PER_TOKEN
//   alibi_slopes: [batch, head_num]
//   output_lse:   [batch, head_num, seq_q]; required when return_lse is true (seq_q must be 1)
// Other constraints: windows_size_right < 0, head_num % k_head_num == 0,
// kv_cache_quant_bit_size is 4, 8 or -1.
void single_query_cached_kv_attn(
    const torch::Tensor &q_ori,
    const torch::Tensor &k_cache,
    const torch::Tensor &v_cache,
    const torch::Tensor &output,
    const torch::Tensor &block_tables,
    const torch::Tensor &context_lens,  // [batch]
    const c10::optional<torch::Tensor> &output_lse,
    const c10::optional<torch::Tensor> &k_cache_quant_scale,
    const c10::optional<torch::Tensor> &v_cache_quant_scale,
    const c10::optional<torch::Tensor> &alibi_slopes,  // validated as [batch, head_num]
    int64_t max_context_len,
    int64_t windows_size_left,
    int64_t windows_size_right,
    double softmax_scale,
    bool return_lse,
    int64_t kv_cache_quant_bit_size) {
  cnnlQuantizeLayout_t cache_quant_layout = CNNL_QUANTIZE_NONE;
  const bool is_kv_quant = k_cache_quant_scale.has_value();
  const bool has_alibi = alibi_slopes.has_value();

  // Check ranks before the size() lookups so a wrong rank is reported with a clear message.
  TORCH_CHECK(q_ori.dim() == 4, "q_ori must be a 4-D tensor.");
  TORCH_CHECK(k_cache.dim() == 4, "k_cache must be a 4-D tensor.");
  TORCH_CHECK(block_tables.dim() == 2, "block_tables must be a 2-D tensor.");

  // Check Tensor Shape
  const int batch = static_cast<int>(q_ori.size(0));
  const int seq_q = static_cast<int>(q_ori.size(1));
  const int head_num = static_cast<int>(q_ori.size(2));
  const int qk_head_size = static_cast<int>(q_ori.size(3));
  const int k_head_num = static_cast<int>(k_cache.size(1));
  const int v_head_size = static_cast<int>(v_cache.size(-1));
  const int max_num_block_per_seq = static_cast<int>(block_tables.size(1));
  const int num_blocks = static_cast<int>(k_cache.size(0));
  const int block_size = static_cast<int>(k_cache.size(2));

  TORCH_CHECK(windows_size_right < 0, "only support windows_size_right < 0 currently.");
  TORCH_CHECK(head_num % k_head_num == 0, "num_heads need be multiple of num_kv_heads.");
  TORCH_CHECK(
      kv_cache_quant_bit_size == 4 || kv_cache_quant_bit_size == 8 || kv_cache_quant_bit_size == -1,
      "illegal quant bit size, only support 4, 8 or -1.");
  CHECK_SHAPE(block_tables, batch, max_num_block_per_seq);
  if (kv_cache_quant_bit_size == 4) {
    // 4-bit caches store two values per element: K is halved along head_size, V along block_size.
    const int v_cache_len = PAD_UP_DIV(block_size, 2);
    CHECK_SHAPE(k_cache, num_blocks, k_head_num, block_size, qk_head_size / 2);
    CHECK_SHAPE(v_cache, num_blocks, k_head_num, v_cache_len, v_head_size);
  } else {
    CHECK_SHAPE(k_cache, num_blocks, k_head_num, block_size, qk_head_size);
    CHECK_SHAPE(v_cache, num_blocks, k_head_num, block_size, v_head_size);
  }
  TORCH_CHECK(q_ori.stride(-1) == 1 && q_ori.stride(-2) == qk_head_size,
              "q last two dim need be contiguous.");
  TORCH_CHECK(k_cache.is_contiguous() && v_cache.is_contiguous(),
              "k_cache and v_cache need be contiguous.");

  // Check tensor device.
  TORCH_CHECK(isMlu(q_ori), "q_ori need be mlu tensor.");
  checkTensorSameAttr<TensorAttr::DEVICE>(q_ori, k_cache, v_cache, output, block_tables,
                                          context_lens, k_cache_quant_scale, v_cache_quant_scale,
                                          output_lse, alibi_slopes);

  if (is_kv_quant) {
    // Both scales are dereferenced below; fail with a readable message if only one was given.
    TORCH_CHECK(v_cache_quant_scale.has_value(),
                "v_cache_quant_scale must be provided if k_cache_quant_scale is provided.");
    const torch::Tensor &k_scale = k_cache_quant_scale.value();
    const torch::Tensor &v_scale = v_cache_quant_scale.value();
    const int64_t scale_rank = k_scale.dim();
    TORCH_CHECK(scale_rank == 2 || scale_rank == 3 || scale_rank == 4,
                "k_cache_quant_scale must be 2d or 3d or 4d.");
    TORCH_CHECK(scale_rank == v_scale.dim(),
                "the dim of k_cache_quant_scale and v_cache_quant_scale must be equal.");
    if (scale_rank == 2) {
      CHECK_SHAPE(k_scale, k_head_num, qk_head_size);
      CHECK_SHAPE(v_scale, k_head_num, v_head_size);
      cache_quant_layout = CNNL_QUANTIZE_PER_CHANNEL;
    } else if (scale_rank == 3) {
      CHECK_SHAPE(k_scale, num_blocks, k_head_num, block_size);
      CHECK_SHAPE(v_scale, num_blocks, k_head_num, block_size);
      cache_quant_layout = CNNL_QUANTIZE_PER_TOKEN;
    } else {
      CHECK_SHAPE(k_scale, num_blocks, k_head_num, block_size, 1);
      CHECK_SHAPE(v_scale, num_blocks, k_head_num, block_size, 1);
      cache_quant_layout = CNNL_QUANTIZE_PER_TOKEN;
    }
  }

  TORCH_CHECK(
      block_tables.scalar_type() == torch::kInt32 || block_tables.scalar_type() == torch::kLong,
      "block_tables type need be torch::kInt32 or torch::kLong.");

  // Check context_lens
  TORCH_CHECK(context_lens.dtype() == torch::kInt32, "context_lens type need be torch::kInt32.");
  TORCH_CHECK(context_lens.is_contiguous(), "context_lens need be contiguous.");

  if (has_alibi) {
    CHECK_SHAPE(alibi_slopes.value(), batch, head_num);
  }
  if (return_lse) {
    TORCH_CHECK(seq_q == 1, "return lse only support seq_q = 1 currently.");
    TORCH_CHECK(output_lse.has_value(), "output_lse must be provided if return_lse is true.");
    CHECK_SHAPE(output_lse.value(), batch, head_num, seq_q);
  }
  if (windows_size_left >= 0) {
    windows_size_left = windows_size_left + 1;  // would be removed next version
  }

  // Convert torch tensor to tensor descs. Index layout:
  //   0 q, 1 k_cache, 2 v_cache, 3 k_scale, 4 v_scale, 5 context_lens, 6 block_tables,
  //   7 alibi_slopes, 8 output, 9 output_lse. Absent optionals become undefined tensors.
  auto descs = createTensorDescs(
      {q_ori, k_cache, v_cache, k_cache_quant_scale.value_or(at::Tensor()),
       v_cache_quant_scale.value_or(at::Tensor()), context_lens, block_tables,
       alibi_slopes.value_or(at::Tensor()), output, output_lse.value_or(at::Tensor())});
  if (kv_cache_quant_bit_size == 4) {
    // Re-describe both caches as packed int4 pairs; the dims stay those of the torch tensors.
    const cnnlDataType_t data_type = CNNL_DTYPE_INT4X2;
    // k_cache
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[1].get(), CNNL_LAYOUT_ARRAY, data_type,
                                                k_cache.sizes().size(), k_cache.sizes().data()));
    // v_cache
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[2].get(), CNNL_LAYOUT_ARRAY, data_type,
                                                v_cache.sizes().size(), v_cache.sizes().data()));
  }

  // Get current handle.
  const torch_mlu::mlu::MLUGuard device_guard(q_ori.device());
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();

  // Get workspace size and malloc workspace.
  size_t workspace_size = 0;
  // No attention descriptor is ever created in this function, so pass an explicit null handle
  // rather than an indeterminate value.
  cnnlSingleQueryCachedKVAttnDescriptor_t att_desc = nullptr;
  CNNL_CHECK_FATAL(cnnlGetSingleQueryCachedKVAttnWorkspaceSize_v2(
      handle, att_desc, descs[0].get(), descs[1].get(), descs[2].get(),
      static_cast<int>(max_context_len), &workspace_size));
  auto workspace =
      at::empty({static_cast<int64_t>(workspace_size)}, q_ori.options().dtype(at::kByte));

  const cnnlDataType_t data_dtype = getCnnlDataType(q_ori.scalar_type());
  SingleQueryCachedKvAttnTheory obj(batch, seq_q, head_num, k_head_num, num_blocks, block_size,
                                    qk_head_size, v_head_size, kv_cache_quant_bit_size,
                                    cache_quant_layout, data_dtype);
  cnpxPush(obj);
  // call cnnl extra op.
  CNNL_CHECK_FATAL(cnnlSingleQueryCachedKVAttn_v3(
      handle, att_desc, descs[0].get(), getAtTensorPtr(q_ori), descs[1].get(),
      getAtTensorPtr(k_cache), descs[2].get(), getAtTensorPtr(v_cache), descs[3].get(),
      getAtTensorPtr(k_cache_quant_scale), descs[4].get(), getAtTensorPtr(v_cache_quant_scale),
      descs[5].get(), getAtTensorPtr(context_lens), descs[6].get(), getAtTensorPtr(block_tables),
      descs[7].get(), getAtTensorPtr(alibi_slopes), cache_quant_layout,
      static_cast<int>(max_context_len), windows_size_left, windows_size_right, softmax_scale,
      return_lse, getAtTensorPtr(workspace), workspace_size, descs[8].get(),
      getAtTensorPtr(output), descs[9].get(), getAtTensorPtr(output_lse)));
  cnpxPop();
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,112 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
// Validates the arguments, flattens input/output to 2-D views and calls
// cnnlSmoothQuantOnline_v4 on the current handle of input's device.
//
// Constraints enforced here:
//   - input_scale, output_scale (when it has at least one dim) and all optional tensors are
//     contiguous; input is an MLU tensor and shares its device with the other inputs
//   - quant_mode is "per_token" or "per_tensor"; dynamic_quant requires "per_token"
//   - dynamic_quant == true  -> output_scale must be non-empty
//     dynamic_quant == false -> output_scale must be empty
//   - input and output have at least 2 dims (exactly 2 for input when gather_index or
//     token_count is given)
//   - without gather_index: output has input's shape and a non-empty output_scale has input's
//     shape minus the last dim
//   - gather_index_start_position requires a 1-D gather_index
//   - flattening the leading dims of input/output must not copy (same data pointer)
void smooth_quant(const at::Tensor &input,
                  const at::Tensor &input_scale,
                  const at::Tensor &output,
                  const at::Tensor &output_scale,
                  const c10::optional<at::Tensor> &input_zero,
                  const c10::optional<at::Tensor> &token_count,
                  const c10::optional<at::Tensor> &gather_index,
                  const c10::optional<at::Tensor> &gather_index_start_position,
                  const std::string &quant_mode,
                  const bool &dynamic_quant) {
  const bool is_pertoken = quant_mode == "per_token";

  CHECK_TENSOR_CONTIGUOUS(input_scale)
  if (output_scale.dim() > 0) {
    CHECK_TENSOR_CONTIGUOUS(output_scale)
  }
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(input_zero)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(token_count)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_index)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_index_start_position)

  TORCH_CHECK(isMlu(input), "input must be mlu tensor.");
  checkTensorSameAttr<TensorAttr::DEVICE>(input, input_scale, input_zero, token_count, gather_index,
                                          gather_index_start_position);
  TORCH_CHECK(quant_mode == "per_token" || quant_mode == "per_tensor",
              "quant_mode must be 'per_token' or 'per_tensor'.");
  if (dynamic_quant) {
    TORCH_CHECK(is_pertoken, "only support per_token if dynamic_quant == true");
    TORCH_CHECK(output_scale.numel() > 0, "invalid output_scale if dynamic_quant = true");
  } else {
    TORCH_CHECK(output_scale.numel() <= 0, "not support output_scale if dynamic_quant = false");
  }

  const bool has_gather_index = gather_index.has_value();
  if (has_gather_index || token_count.has_value()) {
    TORCH_CHECK(input.dim() == 2, "input.dim() == 2 if has gather_index or token_count");
  }
  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2");
  TORCH_CHECK(output.dim() >= 2, "output.dim() >= 2");

  const auto input_shape = input.sizes();
  if (!has_gather_index) {
    TORCH_CHECK(output.sizes() == input_shape, "input and output must have the same shape");
    if (output_scale.numel() > 0) {
      TORCH_CHECK(std::vector<int64_t>(output_scale.sizes().begin(), output_scale.sizes().end()) ==
                      std::vector<int64_t>(input_shape.begin(), input_shape.end() - 1),
                  "output_scale_shape must be equal to input_shape[0:-1]");
    }
  }
  if (gather_index_start_position.has_value()) {
    TORCH_CHECK(has_gather_index,
                "gather_index must exist if gather_index_start_position has value");
    TORCH_CHECK(gather_index.value().dim() == 1, "gather_index.dim() == 1");
  }

  const torch_mlu::mlu::MLUGuard device_guard(input.device());

  // flatten input; an unchanged data pointer shows that flatten() returned a view, not a copy.
  at::Tensor input_flat = input.flatten(0, input.dim() - 2);
  TORCH_CHECK(input_flat.data_ptr() == input.data_ptr(), "check the input strides");
  // flatten output
  at::Tensor output_flat = output.flatten(0, output.dim() - 2);
  TORCH_CHECK(output_flat.data_ptr() == output.data_ptr(), "check the output stride");
  // flatten output_scale; it stays an undefined tensor when dynamic_quant is false.
  at::Tensor output_scale_flat;
  if (dynamic_quant) {
    output_scale_flat = output_scale.flatten(0, -1);
  }

  // Convert torch tensor to tensor descriptor. Index layout:
  //   0 input, 1 input_scale, 2 input_zero, 3 token_count, 4 gather_index,
  //   5 gather_index_start_position, 6 output, 7 output_scale.
  auto descs = createTensorDescs(
      {input_flat, input_scale, input_zero.value_or(at::Tensor()),
       token_count.value_or(at::Tensor()), gather_index.value_or(at::Tensor()),
       gather_index_start_position.value_or(at::Tensor()), output_flat, output_scale_flat});

  const int64_t in_channel = input_shape.back();
  // Product of every dim except the last. The accumulator must be 64-bit: with an `int` init
  // value and std::multiplies<int> the product is truncated to 32 bits for large inputs.
  const int64_t m = std::accumulate(input_shape.begin(), input_shape.end() - 1, int64_t{1},
                                    std::multiplies<int64_t>());
  SmoothQuantTheory obj(m, in_channel, getCnnlDataType(input.scalar_type()), dynamic_quant);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4(
      torch_mlu::getCurrentHandle(), nullptr /*smooth_quant_online_desc*/, descs[0].get(),
      getAtTensorPtr(input_flat),                                  /*input*/
      descs[1].get(), getAtTensorPtr(input_scale),                 /*input_scale*/
      descs[2].get(), getAtTensorPtr(input_zero),                  /*input_zero*/
      descs[3].get(), getAtTensorPtr(token_count),                 /*token_count*/
      descs[4].get(), getAtTensorPtr(gather_index),                /*gather_index*/
      descs[5].get(), getAtTensorPtr(gather_index_start_position), /*gather_index_start_position*/
      descs[6].get(), getAtTensorPtr(output_flat),                 /*output*/
      descs[7].get(), getAtTensorPtr(output_scale_flat),           /*output_scale*/
      nullptr, nullptr,                                            /*output_zero*/
      nullptr, 0 /*workspace*/));
  cnpxPop();
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,65 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/swap_blocks.mluh"
#include "torch_ops_api.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
// Copies whole blocks (slices along dim 0) between src and dst according to block_mapping_dict
// by calling invokeSwapBlocksKernel.
//
// Supported placements: MLU -> MLU (same device index), MLU -> CPU and CPU -> MLU. src and dst
// must share dtype and per-block element count, and both must be contiguous because the kernel
// is given only the two base pointers and the size of one block in bytes.
// NOTE(review): which side of each dict entry is the source block index and which the destination
// is decided inside invokeSwapBlocksKernel — confirm there before adding upper-bound checks.
void swap_blocks(torch::Tensor &dst,
                 const torch::Tensor &src,
                 const c10::Dict<int64_t, int64_t> &block_mapping_dict) {
  std::map<int64_t, int64_t> block_mapping;
  for (const auto &item : block_mapping_dict) {
    // A negative block index would address memory in front of the tensor.
    TORCH_CHECK(item.key() >= 0 && item.value() >= 0,
                "block_mapping must not contain negative block indices.");
    block_mapping[item.key()] = item.value();
  }

  // check shape
  const int64_t block_numel = src[0].numel();
  TORCH_CHECK(block_numel == dst[0].numel(), "the block_size of src and dst are not the same.");
  // check layout: blocks are located as base pointer + index * block size in bytes.
  TORCH_CHECK(src.is_contiguous() && dst.is_contiguous(), "src and dst must be contiguous.");
  // check dtype
  TORCH_CHECK(src.dtype() == dst.dtype(), "the data type of src and dst are not the same.");
  TORCH_CHECK(src.dtype() == torch::kInt8 || src.dtype() == torch::kUInt8 ||
                  src.dtype() == torch::kInt16 || src.dtype() == torch::kInt32 ||
                  src.dtype() == torch::kLong || src.dtype() == torch::kFloat16 ||
                  src.dtype() == torch::kFloat32 || src.dtype() == torch::kBFloat16,
              "data type only supports torch::kInt8, torch::kUInt8, torch::kInt16, torch::kInt32, "
              "torch::kLong, torch::kFloat16, torch::kFloat32 and torch::kBFloat16");

  const torch::Device src_device = src.device();
  const torch::Device dst_device = dst.device();
  // Device used for the guard below: whichever side is the MLU (src when both are).
  torch::Device mlu_device = src_device;
  cnrtMemTransDir_t memcpy_type;
  const bool src_is_mlu = isMlu(src);
  const bool dst_is_mlu = isMlu(dst);
  if (src_is_mlu && dst_is_mlu) {
    TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same MLU");
    memcpy_type = cnrtMemcpyDevToDev;
  } else if (src_is_mlu && dst_device.is_cpu()) {
    memcpy_type = cnrtMemcpyDevToHost;
  } else if (src_device.is_cpu() && dst_is_mlu) {
    memcpy_type = cnrtMemcpyHostToDev;
    mlu_device = dst_device;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  const torch_mlu::mlu::MLUGuard device_guard(mlu_device);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  const int64_t block_size_in_bytes = src.element_size() * block_numel;
  TMO_KERNEL_CHECK_FATAL(invokeSwapBlocksKernel(handle, dst.data_ptr(), src.data_ptr(),
                                                block_size_in_bytes, memcpy_type, block_mapping));
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,472 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_TORCH_API_TORCH_OPS_API_H_
#define CSRC_TORCH_API_TORCH_OPS_API_H_
#include <cstdint>
#include <map>
#include <optional>
#include <vector>
#include "op_theory.h"
#include "ops/kernel_api.h"
#include "torch/extension.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
// Declaration only; the definition is not part of this header.
// Returns nothing. NOTE(review): presumably the result is written into `out` (and
// `residual_out` when given) — confirm against the definition, together with the accepted
// `norm_mode` strings.
void fused_layernorm(const at::Tensor &input,
const at::Tensor &out,
const c10::optional<at::Tensor> &residual,
const c10::optional<at::Tensor> &gamma,
const c10::optional<at::Tensor> &beta,
const c10::optional<at::Tensor> &bias,
const c10::optional<at::Tensor> &quant_scale,
const c10::optional<at::Tensor> &residual_out,
const c10::optional<at::Tensor> &smooth_quant_scale,
const std::string &norm_mode,
double eps,
bool store_output_before_norm,
bool dynamic_quant);
// Declaration only; the definition is not part of this header.
// Returns a vector of tensors. NOTE(review): the number and order of the returned tensors and
// the accepted `out_layout` strings are defined by the implementation — confirm there.
std::vector<at::Tensor> attention_project(const at::Tensor &input,
const at::Tensor &q_weight,
const c10::optional<at::Tensor> &q_bias,
const c10::optional<at::Tensor> &k_weight,
const c10::optional<at::Tensor> &k_bias,
const c10::optional<at::Tensor> &v_weight,
const c10::optional<at::Tensor> &v_bias,
const c10::optional<at::Tensor> &norm_weight,
const c10::optional<at::Tensor> &norm_bias,
const c10::optional<at::Tensor> &residual,
const std::string &out_layout,
int64_t head_size,
double eps,
double alpha,
double beta,
bool norm_out);
// Declaration only; the definition is not part of this header.
// Returns a tensor. NOTE(review): the accepted `act_mode` / `residual_is` strings and the role
// of `alpha` / `beta` are defined by the implementation — confirm there.
at::Tensor ffn(const at::Tensor &input,
const at::Tensor &up_fc_weight,
const c10::optional<at::Tensor> &up_fc_bias,
const at::Tensor &down_proj_weight,
const c10::optional<at::Tensor> &down_proj_bias,
const c10::optional<at::Tensor> &gate_up_proj_weight,
const c10::optional<at::Tensor> &gate_up_proj_bias,
const c10::optional<at::Tensor> &layernorm_weight,
const c10::optional<at::Tensor> &layernorm_bias,
const std::string &act_mode,
const std::string &residual_is,
double eps,
double alpha,
double beta);
// Declaration only; the definition is not part of this header.
// Returns nothing. NOTE(review): presumably the result is written into `out` (and
// `output_lse` when `return_lse` is set) — confirm against the definition, together with the
// accepted `compute_dtype` strings.
void flash_attention(const at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
const at::Tensor &out,
const c10::optional<at::Tensor> &output_lse,
const c10::optional<at::Tensor> &cu_seq_lens_q,
const c10::optional<at::Tensor> &cu_seq_lens_kv,
const c10::optional<at::Tensor> &alibi_slope,
const c10::optional<at::Tensor> &attn_bias,
const c10::optional<at::Tensor> &k_cache_quant_scale,
const c10::optional<at::Tensor> &v_cache_quant_scale,
const c10::optional<at::Tensor> &block_tables,
const int64_t max_seq_len_q,
const int64_t max_seq_len_kv,
const double softmax_scale,
const bool is_causal,
const int64_t window_size_left,
const int64_t window_size_right,
const std::string &compute_dtype,
bool return_lse);
// Attention over cached K/V; the definition validates the arguments and calls
// cnnlSingleQueryCachedKVAttn_v3.
// Constraints checked by the definition: windows_size_right < 0; head count of q_ori is a
// multiple of k_cache.size(1); kv_cache_quant_bit_size is 4, 8 or -1; block_tables is int32 or
// int64; context_lens is int32 and contiguous; return_lse requires q_ori.size(1) == 1.
void single_query_cached_kv_attn(
const torch::Tensor &q_ori,
const torch::Tensor &k_cache,
const torch::Tensor &v_cache,
const torch::Tensor &output,
const torch::Tensor &block_tables,
const torch::Tensor &context_lens, // [batch]
const c10::optional<torch::Tensor> &output_lse,
const c10::optional<torch::Tensor> &k_cache_quant_scale,
const c10::optional<torch::Tensor> &v_cache_quant_scale,
// NOTE(review): the definition only accepts [batch, head_num] for alibi_slopes.
const c10::optional<torch::Tensor> &alibi_slopes, // [bs, head_num] or [head_num]
int64_t max_context_len,
int64_t windows_size_left,
int64_t windows_size_right,
double softmax_scale,
bool return_lse,
int64_t kv_cache_quant_bit_size);
// The definition validates the arguments and calls invokeReshapeLinearCache.
// Constraints checked by the definition: key_cache is 4-D and contiguous; context_lengths is
// int32; when `packed` is true key is [total_seqlen, num_heads, head_size] and context_lengths
// has batch_size + 1 elements, otherwise key is [batch_size, seq, num_heads, head_size] and
// context_lengths has batch_size elements; the three optional index tensors are int32 with
// batch_size elements.
void reshape_linear_cache(const at::Tensor &key,
const c10::optional<at::Tensor> &value,
at::Tensor &key_cache,
const c10::optional<at::Tensor> &value_cache,
const at::Tensor &context_lengths,
const int64_t max_context_len,
bool packed,
const c10::optional<at::Tensor> &context_seq_offset,
const c10::optional<at::Tensor> &cache_bs_id,
const c10::optional<at::Tensor> &cache_seqlen_offset);
// The definition validates the arguments and calls tmo::invokeReshapePagedCache.
// Constraints checked by the definition: k/v are [num_tokens, num_heads, head_dim]; the caches
// are contiguous [num_blocks, num_heads, block_size, head_dim]; slot_mapping is a contiguous
// int32 or int64 tensor of num_tokens elements; v and v_cache are both given or both omitted.
void reshape_paged_cache(const torch::Tensor &k,
const c10::optional<torch::Tensor> &v,
torch::Tensor &k_cache,
const c10::optional<torch::Tensor> &v_cache,
const torch::Tensor &slot_mapping);
// The definition validates the arguments and calls tmo::invokeQuantToPagedCache.
// Constraints checked by the definition: k/v are [num_tokens, num_heads, head_dim]; the caches
// are [num_blocks, num_heads, block_size, head_dim] and the scales
// [num_blocks, num_heads, block_size]; slot_mapping is a contiguous int32 tensor of num_tokens
// elements; v, v_cache and v_cache_scale are all given or all omitted.
void quant_to_paged_cache(const torch::Tensor &k,
const c10::optional<torch::Tensor> &v,
torch::Tensor &k_cache,
const c10::optional<torch::Tensor> &v_cache,
torch::Tensor &k_cache_scale,
const c10::optional<torch::Tensor> &v_cache_scale,
const torch::Tensor &slot_mapping);
// Declaration only; the definition is not part of this header.
// NOTE(review): unlike quant_to_paged_cache the scales are const inputs here and the caches come
// last; presumably the scales are precomputed — confirm against the definition.
void offline_quant_to_paged_cache(const torch::Tensor &k,
const c10::optional<torch::Tensor> &v,
const torch::Tensor &k_cache_scale,
const c10::optional<torch::Tensor> &v_cache_scale,
const torch::Tensor &slot_mapping,
torch::Tensor &k_cache,
const c10::optional<torch::Tensor> &v_cache);
// Declaration only; the definition is not part of this header.
// The parameter list is that of reshape_linear_cache plus the two cache-scale tensors and
// `quant_bit`. NOTE(review): accepted `quant_bit` values are defined by the implementation —
// confirm there.
void quant_to_linear_cache(const at::Tensor &key,
const c10::optional<at::Tensor> &value,
at::Tensor &key_cache,
const c10::optional<at::Tensor> &value_cache,
at::Tensor &key_cache_scale,
const c10::optional<at::Tensor> &value_cache_scale,
const at::Tensor &context_lengths,
const int64_t max_context_len,
bool packed,
const c10::optional<at::Tensor> &context_seq_offset,
const c10::optional<at::Tensor> &cache_bs_id,
const c10::optional<at::Tensor> &cache_seqlen_offset,
const int64_t quant_bit);
// Declaration only; the definition is not part of this header.
// Here key_cache_scale is a const input, unlike in quant_to_linear_cache.
// NOTE(review): the meaning of the integer `quant_mode` is defined by the implementation —
// confirm there.
void offline_quant_to_linear_cache(const at::Tensor &key,
const c10::optional<at::Tensor> &value,
at::Tensor &key_cache,
const c10::optional<at::Tensor> &value_cache,
const at::Tensor &key_cache_scale,
const c10::optional<at::Tensor> &value_cache_scale,
const at::Tensor &context_lengths,
const int64_t max_context_len,
const int64_t quant_mode,
const bool packed,
const c10::optional<at::Tensor> &context_seq_offset,
const c10::optional<at::Tensor> &cache_bs_id,
const c10::optional<at::Tensor> &cache_seqlen_offset);
// Declaration only; the definition is not part of this header.
// Returns nothing. NOTE(review): presumably `input` is modified in place since no output tensor
// is passed — confirm against the definition, together with the meaning of the three flags.
void apply_rotary(const torch::Tensor &input,
const torch::Tensor &sin_cache,
const torch::Tensor &cos_cache,
const c10::optional<torch::Tensor> &position_ids,
const c10::optional<torch::Tensor> &cu_seqlens,
bool interleaved,
bool discrete,
bool dynamic_ntk,
int64_t max_seqlen);
// Registered as torch_mlu_ops::swap_blocks. The registered schema marks dst as mutated in
// place and declares block_mapping as Dict(int, int); the op returns nothing.
void swap_blocks(torch::Tensor &dst,
const torch::Tensor &src,
const c10::Dict<int64_t, int64_t> &block_mapping);
// Registered as torch_mlu_ops::copy_blocks. The registered schema marks every tensor in
// k_caches and v_caches as mutated in place and declares block_mapping as Dict(int, int[]);
// the op returns nothing. A separate Functionalize-key implementation is also registered for
// this op (copy_blocks__functionalization_glue).
void copy_blocks(const std::vector<torch::Tensor> &k_caches,
const std::vector<torch::Tensor> &v_caches,
const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping);
// Registered as torch_mlu_ops::copy_blocks_out_of_place. Unlike copy_blocks, the registered
// schema marks no argument as mutated and returns a (Tensor[], Tensor[]) pair.
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> copy_blocks_out_of_place(
const std::vector<torch::Tensor> &k_caches,
const std::vector<torch::Tensor> &v_caches,
const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping);
// Registered as torch_mlu_ops::fused_moe; returns a Tensor.
// NOTE(review): block_n and cncl_comm default to 0 here, but the registered schema declares
// them as plain `int block_n, int cncl_comm` with no defaults, so the defaults only apply to
// direct C++ callers; calls routed through the dispatcher must pass both explicitly.
at::Tensor fused_moe(const at::Tensor &hidden_states,
const at::Tensor &gating_output,
const at::Tensor &w1,
const at::Tensor &w2,
const c10::optional<at::Tensor> &bias1,
const c10::optional<at::Tensor> &bias2,
const c10::optional<at::Tensor> &residual,
const c10::optional<at::Tensor> &input_smooth,
const c10::optional<at::Tensor> &act_smooth,
const c10::optional<at::Tensor> &w1_scale,
const c10::optional<at::Tensor> &w2_scale,
const c10::optional<at::List<int64_t>> &w1_quant_flag,
const c10::optional<at::List<int64_t>> &w2_quant_flag,
const int64_t topk,
const bool renormalize,
const bool gated,
const std::string &act_mode,
const int64_t start_expert_id,
const int64_t block_n = 0,
const int64_t cncl_comm = 0);
// d = (a * b + bias) * alpha + c * beta
// d = (((a * a_scale * b * b_scale) * gemm_output_scale + bias) * alpha + c * beta) * d_scale
at::Tensor quant_matmul(
const at::Tensor &a_tensor, // input
const c10::optional<at::Tensor> &a_scale, // input scale, smooth quant
const c10::optional<at::Tensor> &a_zero, // input zero
const at::Tensor &b_tensor, // weight
const c10::optional<at::Tensor> &b_scale, // weight scale, smooth quant
const c10::optional<at::Tensor> &b_zero, // weight zero
const c10::optional<at::Tensor> &bias, // bias
const c10::optional<at::Tensor> &c_tensor, // residual
const c10::optional<at::Tensor> &c_scale, // residual scale
const c10::optional<at::Tensor> &c_zero, // residual zero
const c10::optional<at::Tensor> &gemm_output_scale, // for dequant, weight only
const c10::optional<at::Tensor> &gemm_output_zero, // for dequant, preserve
const c10::optional<std::string> &data_type, // output data type
const c10::optional<at::Tensor> &d, // output
const std::string &quant_algo, // quant type
const std::string &a_quant_layout, // input quant mode
const std::string &b_quant_layout, // weight quant mode
int64_t quant_bit_size = 8, // weight quant bit size
const std::string &act_mode = "none", // act mode
bool use_hp_active = false, // for active, whether uses high precision mode
double act_coef = 1.f, // active_coef
double alpha = 1.f, // for quant matmul
double beta = 1.f, // for residual
bool trans_a = false, // whether transpose input
bool trans_b = true); // whether transpose weight
at::Tensor quant_matmul_allreduce(const int64_t cncl_comm,
const at::Tensor &a_tensor,
const c10::optional<at::Tensor> &a_scale,
const c10::optional<at::Tensor> &a_zero,
const at::Tensor &b_tensor,
const c10::optional<at::Tensor> &b_scale,
const c10::optional<at::Tensor> &b_zero,
const c10::optional<at::Tensor> &bias,
const c10::optional<at::Tensor> &c_tensor,
const c10::optional<at::Tensor> &c_scale,
const c10::optional<at::Tensor> &c_zero,
const c10::optional<at::Tensor> &gemm_output_scale,
const c10::optional<at::Tensor> &gemm_output_zero,
const c10::optional<std::string> &data_type,
const c10::optional<at::Tensor> &d,
const std::string &quant_algo,
const std::string &a_quant_layout,
const std::string &b_quant_layout,
int64_t quant_bit_size,
double alpha,
double beta,
bool trans_a,
bool trans_b,
const int64_t block_m);
// Registered as torch_mlu_ops::active; the registered schema marks output as mutated in place
// and the op returns nothing.
// The definition enforces:
//   - act_mode is one of "silu", "gelu", "quick_gelu" or "swish";
//   - input has at least 2 dimensions and a last dimension > 0, which must be even when
//     is_gated is true.
// active_coef is overridden to 1.702 for "quick_gelu" and to 1.0 for "silu", regardless of the
// value passed in.
void active(const torch::Tensor &input,
const torch::Tensor &output,
const c10::optional<torch::Tensor> &bias,
const c10::optional<torch::Tensor> &cusum_token_count,
const std::string &act_mode,
bool is_gated,
int64_t start_expert_id,
int64_t expert_size,
double active_coef = 1.0);
void smooth_quant(const at::Tensor &input,
const at::Tensor &input_scale,
const at::Tensor &output,
const at::Tensor &output_scale,
const c10::optional<at::Tensor> &input_zero,
const c10::optional<at::Tensor> &token_count,
const c10::optional<at::Tensor> &gather_index,
const c10::optional<at::Tensor> &gather_index_start_position,
const std::string &quant_mode,
const bool &dynamic_quant);
at::Tensor matmul(const at::Tensor &a,
const at::Tensor &b,
const c10::optional<at::Tensor> &d,
const c10::optional<at::Tensor> &bias,
const c10::optional<at::Tensor> &c,
const c10::optional<std::string> &dtype,
const std::string &act_mode,
double alpha,
double beta,
bool fast_act,
bool approximate,
double a_scale,
double b_scale,
bool trans_a,
bool trans_b);
at::Tensor matmul_allreduce(const int64_t cncl_comm,
const at::Tensor &a,
const at::Tensor &b,
const c10::optional<at::Tensor> &bias,
const c10::optional<at::Tensor> &c,
const c10::optional<at::Tensor> &d,
const double alpha,
const double beta,
const int64_t block_m);
at::Tensor group_gemm(const at::Tensor &a_tensor,
const at::Tensor &b_tensor,
const at::Tensor &m_list,
const c10::optional<at::Tensor> &gather_idx,
const c10::optional<at::Tensor> &c_tensor,
const c10::optional<at::Tensor> &alpha,
const c10::optional<at::Tensor> &beta,
const c10::optional<at::Tensor> &a_scale,
const c10::optional<at::Tensor> &b_scale,
const c10::optional<at::Tensor> &bias,
const c10::optional<std::string> &data_type,
const c10::optional<at::List<int64_t>> &quant_flag,
const c10::optional<at::Tensor> &b_offset,
const int64_t max_m /*max m in m_list*/);
// Registered as torch_mlu_ops::preload; returns nothing. The registered schema marks weight
// as mutated (Tensor(a!)) even though it is taken by const reference here.
void preload(const torch::Tensor &weight, const int64_t size);
at::Tensor group_gemm_combine_result_allreduce(const int64_t cncl_comm,
const at::Tensor &a_tensor,
const at::Tensor &b_tensor,
const at::Tensor &m_list,
const at::Tensor &combine_idx,
const at::Tensor &combine_weight,
const c10::optional<at::Tensor> &c_tensor,
const c10::optional<at::Tensor> &alpha,
const c10::optional<at::Tensor> &beta,
const c10::optional<at::Tensor> &a_scale,
const c10::optional<at::Tensor> &b_scale,
const c10::optional<std::string> &data_type,
const int64_t num_token,
const int64_t topk,
const int64_t block_n);
// Registered as torch_mlu_ops::flash_attn_sq_mm_allreduce; returns a Tensor.
// NOTE(review): the registered schema names the sixth argument cu_seq_lens_k, whereas this
// declaration calls it cu_seq_lens_kv. Keyword-argument callers going through the dispatcher
// must use the schema name; consider aligning the two.
at::Tensor flash_attn_sq_mm_allreduce(const int64_t cncl_comm,
const at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
const c10::optional<at::Tensor> &cu_seq_lens_q,
const c10::optional<at::Tensor> &cu_seq_lens_kv,
const c10::optional<at::Tensor> &alibi_slope,
const c10::optional<at::Tensor> &attn_bias,
const at::Tensor &smooth,
const at::Tensor &weight,
const at::Tensor &weight_scale,
const c10::optional<at::Tensor> &bias,
const int64_t max_seq_len_q,
const int64_t max_seq_len_kv,
const double softmax_scale,
const bool is_causal,
const int64_t window_size_left,
const int64_t window_size_right,
const std::string &compute_dtype,
const int64_t block_seq);
std::vector<at::Tensor> moe_softmax_topk(const at::Tensor &input,
int64_t topk,
int64_t num_expert_group,
int64_t topk_group,
bool normalize,
const c10::optional<at::Tensor> &mask,
const std::string &normed_by);
at::Tensor moe_expand_input(const torch::Tensor &input,
const torch::Tensor &gather_idx,
const c10::optional<torch::Tensor> &cusum_token_count,
int64_t start_expert_id,
int64_t expert_size);
std::vector<at::Tensor> moe_gen_idx(const torch::Tensor &expert_id, int64_t expert_num);
at::Tensor moe_combine_result(const at::Tensor &input,
const at::Tensor &reduce_weight,
const at::Tensor &gather_ids,
const c10::optional<at::Tensor> &residual,
const c10::optional<at::Tensor> &cusum_token_count,
int64_t start_expert_id,
int64_t expert_size,
const c10::optional<at::Tensor> &bias);
void fused_rope(at::Tensor &qkv,
at::Tensor &key_cache_hp,
at::Tensor &value_cache_hp,
const c10::optional<at::Tensor> &key_cache_lp,
const c10::optional<at::Tensor> &value_cache_lp,
const at::Tensor &sin_table,
const at::Tensor &cos_table,
const at::Tensor &position_ids,
const at::Tensor &gamma,
const at::Tensor &beta,
const c10::optional<at::Tensor> &key_scale_hp,
const c10::optional<at::Tensor> &value_scale_hp,
const c10::optional<at::Tensor> &key_scale_lp,
const c10::optional<at::Tensor> &value_scale_lp,
const c10::optional<at::Tensor> &cache_bs_id_hp,
const c10::optional<at::Tensor> &cache_seq_offsets_hp,
const c10::optional<at::Tensor> &cache_bs_id_lp,
const c10::optional<at::Tensor> &cache_seq_offsets_lp,
const c10::optional<at::Tensor> &slot_mapping_hp,
const c10::optional<at::Tensor> &slot_mapping_lp,
const double eps);
// Registered as torch_mlu_ops::batch_matmul. The registered schema marks c as mutated in
// place (even though it is taken by const reference here); the op returns nothing.
void batch_matmul(const at::Tensor &a,
const at::Tensor &b,
const at::Tensor &c,
double alpha,
double beta,
double a_scale,
double b_scale,
bool trans_a,
bool trans_b);
at::Tensor moe_cast_gating(const at::Tensor &input, const at::Tensor &weight);
// Registered as torch_mlu_ops::update_out_and_lse; the registered schema marks out and lse as
// mutated in place and the op returns nothing.
// Constraints enforced by the definition:
//   - out and block_out are 3-D (treated as packed) or 4-D; lse and block_lse are 3-D;
//   - batch, head_num and max_seq_len are taken from lse's three dimensions, block_seq_len
//     from block_lse's last dimension, and block_seq_len must not exceed max_seq_len;
//   - out is float, bfloat16 or half, with the same dtype as block_out; head_size
//     (out's last dimension) is at most 1024;
//   - seq_offsets, when given, is a contiguous 1-D int32 tensor of size batch;
//   - cu_seqs and block_cu_seqs, when given, are contiguous 1-D tensors of size batch + 1.
// lse and block_lse are handed to the kernel as float buffers, and the three optional index
// tensors as int32 buffers.
void update_out_and_lse(at::Tensor &out,
at::Tensor &lse,
const at::Tensor &block_out,
const at::Tensor &block_lse,
const c10::optional<at::Tensor> &seq_offsets,
const c10::optional<at::Tensor> &cu_seqs,
const c10::optional<at::Tensor> &block_cu_seqs);
void cnpxPush(const OpTheory &op);
void cnpxPop(void);
void dequant_from_linear_cache(at::Tensor &key,
const c10::optional<at::Tensor> &value,
const at::Tensor &key_cache,
const c10::optional<at::Tensor> &value_cache,
const at::Tensor &key_quant_scale,
const c10::optional<at::Tensor> &value_quant_scale,
const at::Tensor &context_lengths,
const int64_t max_context_len,
const c10::optional<at::Tensor> &context_seq_offset,
const c10::optional<at::Tensor> &cache_bs_id,
const c10::optional<at::Tensor> &cache_seq_offset,
const int64_t quant_mode,
const int64_t quant_bit);
void dequant_from_paged_cache(at::Tensor &key,
const c10::optional<at::Tensor> &value,
const at::Tensor &key_cache,
const c10::optional<at::Tensor> &value_cache,
const at::Tensor &key_cache_quant_scale,
const c10::optional<at::Tensor> &value_cache_quant_scale,
const at::Tensor &context_lengths,
int64_t max_context_len,
const c10::optional<at::Tensor> &context_seq_offset,
const at::Tensor &block_tables,
int64_t quant_mode,
int64_t quant_bit);
} // namespace torch_api
} // namespace tmo
#endif // CSRC_TORCH_API_TORCH_OPS_API_H_

View File

@@ -0,0 +1,263 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <torch/library.h>
#include "glue_ops.h"
#include "torch_ops_api.h"
// Declares the schema of every operator in the torch_mlu_ops namespace. Only the signatures
// are defined here; the kernels are bound per dispatch key in the TORCH_LIBRARY_IMPL blocks.
// In a schema, `Tensor(x!)` marks an argument that the op writes in place.
TORCH_LIBRARY_FRAGMENT(torch_mlu_ops, m) {
  // impl_abstract_pystub is only available from torch 2.3 onwards (torch 2.1.0 does not
  // support it). Compare (major, minor) as a pair: a plain `MAJOR >= 2 && MINOR >= 3` would
  // wrongly evaluate to false for x.0 - x.2 releases of any later major version.
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 3)
  m.impl_abstract_pystub("torch_mlu_ops.abstract");
#endif
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::attention_project(Tensor input, Tensor q_weight, Tensor? q_bias, Tensor? "
      "k_weight, "
      "Tensor? k_bias, Tensor? v_weight, Tensor? v_bias, Tensor? norm_weight, Tensor? norm_bias, "
      "Tensor? residual, str out_layout, int head_size, float eps, float alpha, float beta, bool "
      "norm_out) -> (Tensor[])"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::ffn(Tensor input, Tensor up_fc_weight, Tensor? up_fc_bias, Tensor "
      "down_proj_weight, "
      "Tensor? down_proj_bias, Tensor? gate_up_proj_weight, Tensor? gate_up_proj_bias, Tensor? "
      "layernorm_weight, Tensor? layernorm_bias, str act_mode, str residual_is, float eps, float "
      "alpha, float beta) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::flash_attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out, Tensor(b!)? "
      "out_lse, Tensor? cu_seq_lens_q, Tensor? cu_seq_lens_kv, "
      "Tensor? alibi_slope, Tensor? attn_bias, Tensor? k_quant_scale, Tensor? v_quant_scale, "
      "Tensor? block_tables, int max_seq_len_q, int "
      "max_seq_len_kv, float softmax_scale, bool is_causal, int window_size_left, int "
      "window_size_right, str compute_dtype, bool return_lse) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::single_query_cached_kv_attn(Tensor q_ori, Tensor k_cache, Tensor "
      "v_cache, Tensor(a!) output, Tensor block_tables, Tensor context_lens, "
      "Tensor(b!)? out_lse, Tensor? k_cache_quant_scale, Tensor? v_cache_quant_scale, "
      "Tensor? alibi_slopes, int max_contxt_len, int windows_size_left, int "
      "windows_size_right, float softmax_scale, bool return_lse, int kv_cache_quant_bit_size) -> "
      "()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::apply_rotary(Tensor(a!) input, Tensor sin_cache, Tensor cos_cache, "
      "Tensor? position_ids, Tensor? cu_seqlens, bool interleaved, bool "
      "discrete, bool dynamic_ntk, int max_seqlen) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::reshape_linear_cache(Tensor key, Tensor? value, Tensor(a!) key_cache, "
      "Tensor(b!)? "
      "value_cache, Tensor context_lengths, int max_context_len, bool packed, Tensor? "
      "context_seq_offset, Tensor? cache_bs_id, Tensor? cache_seqlen_offset) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::reshape_paged_cache(Tensor k, Tensor? v, Tensor(a!) k_cache, "
      "Tensor(b!)? v_cache, Tensor slot_mapping) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::quant_to_paged_cache(Tensor k, Tensor? v, Tensor(a!) k_cache, Tensor(b!)? "
      "v_cache, "
      "Tensor(c!) k_cache_scale, Tensor(d!)? v_cache_scale, Tensor slot_mapping) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::offline_quant_to_paged_cache(Tensor k, Tensor? v, Tensor k_cache_scale,"
      "Tensor? v_cache_scale, Tensor slot_mapping, Tensor(a!) k_cache, Tensor(b!)? v_cache)-> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::quant_to_linear_cache(Tensor key, Tensor? value, Tensor(a!) key_cache, "
      "Tensor(b!)? "
      "value_cache, Tensor(c!) key_cache_scale, Tensor(d!)? value_cache_scale, Tensor "
      "context_lengths, "
      "int max_context_len, bool packed, Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? "
      "cache_seqlen_offset, int quant_bit=8) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::offline_quant_to_linear_cache(Tensor key, Tensor? value, Tensor(a!) "
      "key_cache,"
      " Tensor(b!)? value_cache, Tensor(c!) key_cache_scale, Tensor(d!)? value_cache_scale, "
      "Tensor context_lengths, int max_context_len, int quant_mode, bool packed, "
      "Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? "
      "cache_seqlen_offset) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::quant_matmul(Tensor a_tensor, Tensor? a_scale, Tensor? a_zero, Tensor "
      "b_tensor, "
      "Tensor? b_scale, Tensor? b_zero, Tensor? bias, Tensor? c_tensor, Tensor? c_scale, Tensor? "
      "c_zero, Tensor? gemm_output_scale, Tensor? gemm_output_zero, str? data_type, Tensor? d, "
      "str "
      "quant_algo, str a_quant_layout, str b_quant_layout, int quant_bit_size=8, str "
      "act_mode='none', bool use_hp_active=False, float act_coef=1.0, float alpha=1.0, float "
      "beta=1.0, bool trans_a=False, bool trans_b=True) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::quant_matmul_allreduce(int cncl_comm, Tensor a_tensor, Tensor? a_scale, "
      "Tensor? a_zero, Tensor b_tensor, Tensor? b_scale, Tensor? b_zero, "
      "Tensor? bias, Tensor? c_tensor, Tensor? c_scale, Tensor? "
      "c_zero, Tensor? gemm_output_scale, Tensor? gemm_output_zero, str? data_type, Tensor? d, "
      "str quant_algo, str a_quant_layout, str b_quant_layout, int quant_bit_size=8, "
      "float alpha=1.0, float beta=1.0, bool trans_a=False, bool trans_b=True, int block_m=0) -> "
      "Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::active(Tensor input, Tensor(a!) output, Tensor? bias, "
      "Tensor? cusum_token_count, str act_mode, bool is_gated, int start_expert_id, "
      "int expert_size, float active_coef=1.0) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::smooth_quant(Tensor input, Tensor input_scale, Tensor(a!) output, "
      "Tensor(b!) output_scale, "
      "Tensor? input_zero, Tensor? token_count, Tensor? gather_index, "
      "Tensor? gather_index_start_position, str quant_mode, bool dynamic_quant) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::fused_layernorm(Tensor input, Tensor(a!) out, "
      "Tensor? residual, Tensor? gamma, Tensor? beta, Tensor? bias, Tensor? "
      "quant_scale, Tensor(b!)? residual_out, Tensor? smooth_quant_scale, str norm_mode, "
      "float eps, bool store_output_before_norm, bool dynamic_quant) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::fused_moe(Tensor hidden_states, Tensor gating_output, Tensor "
      "w1, Tensor w2, Tensor? bias1, Tensor? bias2, Tensor? residual, Tensor? "
      "input_smooth, Tensor? act_smooth, Tensor? w1_scale, Tensor? w2_scale, "
      "int[]? w1_quant_flag, int[]? w2_quant_flag, "
      "int topk, bool renormalize, bool gated, str act_mode, int start_expert_id, "
      "int block_n, int cncl_comm) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::matmul(Tensor a, Tensor b, Tensor? d, Tensor? bias, Tensor? c, str? dtype,"
      " str act_mode, float alpha, float beta, bool fast_act, bool approximate, float a_scale, "
      "float "
      "b_scale, bool trans_a, bool trans_b) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::batch_matmul(Tensor a, Tensor b, Tensor(a!) c,"
      "float alpha, float beta, float a_scale, float b_scale, bool trans_a, bool trans_b) -> ()"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::matmul_allreduce(int cncl_comm, Tensor a, Tensor b, "
                             "Tensor? bias, Tensor? c, Tensor? d, "
                             "float alpha, float beta, int block_m) -> Tensor"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::swap_blocks(Tensor(a!) dst, Tensor src, Dict(int, "
                             "int) block_mapping) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::copy_blocks(Tensor(a!)[] k_caches, Tensor(b!)[] v_caches, "
      "Dict(int, int[]) block_mapping) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::copy_blocks_out_of_place(Tensor[] k_caches, Tensor[] v_caches, "
      "Dict(int, int[]) block_mapping) -> (Tensor[], Tensor[])"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::group_gemm(Tensor a, Tensor b, Tensor m_list, Tensor? idx, Tensor? c, "
      "Tensor? alpha, Tensor? beta, Tensor? a_scale, Tensor? b_scale, Tensor? bias,"
      "str? data_type, int[]? quant_flag, Tensor? b_offset, int max_m) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::preload(Tensor(a!) weight, int size) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::group_gemm_combine_result_allreduce(int cncl_comm, Tensor a, Tensor b, "
      "Tensor m_list, Tensor combine_idx, Tensor combine_weight, Tensor? c, "
      "Tensor? alpha, Tensor? beta, Tensor? a_scale, Tensor? b_scale, "
      "str? data_type, int num_token, int topk, int block_n) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::flash_attn_sq_mm_allreduce(int cncl_comm, Tensor q, Tensor k, "
      "Tensor v, Tensor? cu_seq_lens_q, Tensor? cu_seq_lens_k, Tensor? alibi_slope, "
      "Tensor? attn_bias, Tensor smooth, Tensor weight, Tensor weight_scale, "
      "Tensor? bias, int max_seq_len_q, int max_seq_len_kv, float softmax_scale, bool is_causal, "
      "int window_size_left, int window_size_right, "
      "str compute_dtype, int block_seq) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::moe_softmax_topk(Tensor input, int topk, int num_expert_group, "
      "int topk_group, bool normalize, Tensor? mask, str normed_by) -> Tensor[]"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::moe_expand_input(Tensor input, Tensor gather_idx, Tensor ? "
      "cusum_token_count, int start_expert_id, int expert_size) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::moe_gen_idx(Tensor expert_id, int expert_num) -> Tensor[]"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::moe_combine_result(Tensor input, Tensor reduce_weight,"
      "Tensor gather_ids, Tensor? residual,"
      "Tensor? cusum_token_count, int start_expert_id, int expert_size, Tensor? bias) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::fused_rope(Tensor(a!) input, Tensor(b!) k_cache_hp, Tensor(c!) v_cache_hp, "
      "Tensor(d!)? k_cache_lp, Tensor(e!)? v_cache_lp, Tensor sin_table, Tensor cos_table, "
      "Tensor position_ids, Tensor gamma, Tensor beta, Tensor? k_scale_hp, Tensor? v_scale_hp, "
      "Tensor(f!)? k_scale_lp, Tensor(g!)? v_scale_lp, Tensor? cache_bs_id_hp, "
      "Tensor? cache_seq_offsets_hp, Tensor? cache_bs_id_lp, Tensor? cache_seq_offsets_lp, "
      "Tensor? slot_mapping_hp, Tensor? slot_mapping_lp, float eps) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::moe_cast_gating(Tensor input, Tensor weight) -> Tensor"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::update_out_and_lse(Tensor(a!) out, Tensor(b!) lse, "
                             "Tensor block_out, Tensor block_lse,"
                             "Tensor? seq_offsets, Tensor? cu_seqs, Tensor? block_cu_seqs) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::dequant_from_linear_cache(Tensor(a!) key, Tensor(b!)? value, "
      "Tensor key_cache, Tensor? value_cache, Tensor key_cache_quant_scale, "
      "Tensor? value_cache_quant_scale, Tensor context_lengths, int max_context_len, "
      "Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? cache_seq_offset, "
      "int quant_mode=0, int quant_bit=8) -> ()"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torch_mlu_ops::dequant_from_paged_cache(Tensor(a!) key, Tensor(b!)? value, "
      "Tensor key_cache, Tensor? value_cache, Tensor key_cache_quant_scale, "
      "Tensor? value_cache_quant_scale, Tensor context_lengths, int max_context_len, "
      "Tensor? context_seq_offset, Tensor block_tables, int quant_mode=0, int quant_bit=8) -> ()"));
}
// Binds each torch_mlu_ops operator to its kernel in tmo::torch_api for the PrivateUse1
// dispatch key. Every op is exposed under the same name as the C++ function implementing it,
// so the helper macro derives both the qualified op name and the kernel symbol from a single
// token. Adjacent string literals are concatenated by the compiler, so the resulting name is
// the same "torch_mlu_ops::<op>" literal as when written out by hand.
#define TMO_BIND_MLU_KERNEL(op) \
  m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::" #op), TORCH_FN(tmo::torch_api::op))
TORCH_LIBRARY_IMPL(torch_mlu_ops, PrivateUse1, m) {
  TMO_BIND_MLU_KERNEL(attention_project);
  TMO_BIND_MLU_KERNEL(ffn);
  TMO_BIND_MLU_KERNEL(flash_attention);
  TMO_BIND_MLU_KERNEL(single_query_cached_kv_attn);
  TMO_BIND_MLU_KERNEL(apply_rotary);
  TMO_BIND_MLU_KERNEL(reshape_linear_cache);
  TMO_BIND_MLU_KERNEL(reshape_paged_cache);
  TMO_BIND_MLU_KERNEL(quant_to_paged_cache);
  TMO_BIND_MLU_KERNEL(offline_quant_to_paged_cache);
  TMO_BIND_MLU_KERNEL(quant_to_linear_cache);
  TMO_BIND_MLU_KERNEL(offline_quant_to_linear_cache);
  TMO_BIND_MLU_KERNEL(quant_matmul);
  TMO_BIND_MLU_KERNEL(quant_matmul_allreduce);
  TMO_BIND_MLU_KERNEL(active);
  TMO_BIND_MLU_KERNEL(smooth_quant);
  TMO_BIND_MLU_KERNEL(fused_layernorm);
  TMO_BIND_MLU_KERNEL(fused_moe);
  TMO_BIND_MLU_KERNEL(matmul);
  TMO_BIND_MLU_KERNEL(batch_matmul);
  TMO_BIND_MLU_KERNEL(matmul_allreduce);
  TMO_BIND_MLU_KERNEL(swap_blocks);
  TMO_BIND_MLU_KERNEL(copy_blocks);
  TMO_BIND_MLU_KERNEL(copy_blocks_out_of_place);
  TMO_BIND_MLU_KERNEL(group_gemm);
  TMO_BIND_MLU_KERNEL(preload);
  TMO_BIND_MLU_KERNEL(group_gemm_combine_result_allreduce);
  TMO_BIND_MLU_KERNEL(flash_attn_sq_mm_allreduce);
  TMO_BIND_MLU_KERNEL(moe_softmax_topk);
  TMO_BIND_MLU_KERNEL(moe_expand_input);
  TMO_BIND_MLU_KERNEL(moe_gen_idx);
  TMO_BIND_MLU_KERNEL(moe_combine_result);
  TMO_BIND_MLU_KERNEL(moe_cast_gating);
  TMO_BIND_MLU_KERNEL(fused_rope);
  TMO_BIND_MLU_KERNEL(update_out_and_lse);
  TMO_BIND_MLU_KERNEL(dequant_from_linear_cache);
  TMO_BIND_MLU_KERNEL(dequant_from_paged_cache);
}
#undef TMO_BIND_MLU_KERNEL
// Registers copy_blocks__functionalization_glue as the Functionalize-key implementation of
// torch_mlu_ops::copy_blocks (the op whose schema mutates lists of tensors in place).
// NOTE: per the original author, this glue function is no longer needed as of torch 2.5.
TORCH_LIBRARY_IMPL(torch_mlu_ops, Functionalize, m) {
m.impl("torch_mlu_ops::copy_blocks", copy_blocks__functionalization_glue);
}

View File

@@ -0,0 +1,115 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/update_out_and_lse.mluh"
#include "torch_ops_api.h"
#include "utils.h"
namespace tmo {
namespace torch_api {
// Validates the arguments and launches invokeUpdateOutAndLse on the current MLU stream of
// out's device. out and lse are the tensors handed to the kernel as outputs.
//
// Layout:
//   - out / block_out are either 3-D (treated as "packed": [tokens, head_num, head_size]) or
//     4-D ([batch, seq, head_num, head_size]);
//   - lse is [batch, head_num, max_seq_len]; block_lse is 3-D with block_seq_len as its last
//     dimension.
//
// Raises (via TORCH_CHECK / CHECK_SHAPE) when a rank, shape, dtype or contiguity requirement
// below is violated.
void update_out_and_lse(at::Tensor &out,
                        at::Tensor &lse,
                        const at::Tensor &block_out,
                        const at::Tensor &block_lse,
                        const c10::optional<at::Tensor> &seq_offsets,
                        const c10::optional<at::Tensor> &cu_seqs,
                        const c10::optional<at::Tensor> &block_cu_seqs) {
  const torch_mlu::mlu::MLUGuard device_guard(out.device());
  auto queue = torch_mlu::getCurMLUStream();

  checkTensorSameAttr<TensorAttr::DEVICE>(out, lse, block_out, block_lse, seq_offsets, cu_seqs,
                                          block_cu_seqs);
  checkTensorSameAttr<TensorAttr::DTYPE>(seq_offsets, cu_seqs, block_cu_seqs);
  checkTensorSameAttr<TensorAttr::DTYPE>(out, block_out);
  checkTensorSameAttr<TensorAttr::DTYPE>(lse, block_lse);

  TORCH_CHECK(out.dim() == 3 || out.dim() == 4, "the dim of out must be 3 or 4");
  TORCH_CHECK(block_out.dim() == 3 || block_out.dim() == 4, "the dim of block_out must be 3 or 4");
  TORCH_CHECK(lse.dim() == 3, "the dim of lse must be 3");
  TORCH_CHECK(block_lse.dim() == 3, "the dim of block_lse must be 3");
  // lse and block_lse are passed to the kernel as float buffers; block_lse is already
  // required to share lse's dtype, so checking lse covers both.
  TORCH_CHECK(lse.dtype() == torch::kFloat32, "the dtype of lse and block_lse must be float32");

  // A 3-D out has no separate batch/sequence dimensions.
  bool packed = out.dim() == 3;
  int32_t batch{0}, head_num{0}, head_size{0}, max_seq_len{0}, block_seq_len{0};
  int64_t bs_stride{0}, seq_stride{0}, head_stride{0}, block_bs_stride{0}, block_seq_stride{0},
      block_head_stride{0};
  auto data_type = getCnnlDataType(out.scalar_type());

  batch = lse.size(0);
  head_num = lse.size(1);
  max_seq_len = lse.size(2);
  block_seq_len = block_lse.size(2);
  head_size = out.size(-1);

  TORCH_CHECK(block_seq_len <= max_seq_len,
              "the max block_seq_len of block_lse can not be greater than the max_seq_len of lse. "
              "block_seq_len: ",
              block_seq_len, " max_seq_len: ", max_seq_len);
  TORCH_CHECK(data_type == CNNL_DTYPE_FLOAT || data_type == CNNL_DTYPE_BFLOAT16 ||
                  data_type == CNNL_DTYPE_HALF,
              "the data_type of out only support float, bfloat16 and half");

  if (packed) {
    seq_stride = out.stride(0);
    head_stride = out.stride(1);
    block_seq_stride = block_out.stride(0);
    block_head_stride = block_out.stride(1);
    // With exactly one token per sequence, the token stride doubles as the batch stride;
    // otherwise the batch strides stay 0 in packed mode.
    if (max_seq_len == block_seq_len && block_seq_len == 1) {
      bs_stride = seq_stride;
      block_bs_stride = block_seq_stride;
    }
    CHECK_SHAPE(out, out.size(0), head_num, head_size);
    CHECK_SHAPE(block_out, block_out.size(0), head_num, head_size);
  } else {
    bs_stride = out.stride(0);
    seq_stride = out.stride(1);
    head_stride = out.stride(2);
    block_bs_stride = block_out.stride(0);
    block_seq_stride = block_out.stride(1);
    block_head_stride = block_out.stride(2);
    CHECK_SHAPE(out, batch, max_seq_len, head_num, head_size);
    CHECK_SHAPE(block_out, batch, block_seq_len, head_num, head_size);
  }
  TORCH_CHECK(head_size <= 1024, "the head_size can not be greater than 1024.");

  if (seq_offsets.has_value()) {
    TORCH_CHECK(seq_offsets.value().is_contiguous(), "seq_offsets must be contiguous");
    TORCH_CHECK(seq_offsets.value().dtype() == torch::kInt32,
                "the dtype of seq_offsets must be int32");
    TORCH_CHECK(seq_offsets.value().dim() == 1, "the dim of seq_offsets must be 1");
    CHECK_SHAPE(seq_offsets.value(), batch);
  }
  // cu_seqs and block_cu_seqs are read by the kernel as int32 buffers. The same-dtype check
  // above only ties them to seq_offsets when that tensor is supplied, so each one is
  // verified explicitly here.
  if (cu_seqs.has_value()) {
    TORCH_CHECK(cu_seqs.value().is_contiguous(), "cu_seqs must be contiguous");
    TORCH_CHECK(cu_seqs.value().dtype() == torch::kInt32, "the dtype of cu_seqs must be int32");
    TORCH_CHECK(cu_seqs.value().dim() == 1, "the dim of cu_seqs must be 1");
    CHECK_SHAPE(cu_seqs.value(), batch + 1);
  }
  if (block_cu_seqs.has_value()) {
    TORCH_CHECK(block_cu_seqs.value().is_contiguous(), "block_cu_seqs must be contiguous");
    TORCH_CHECK(block_cu_seqs.value().dtype() == torch::kInt32,
                "the dtype of block_cu_seqs must be int32");
    TORCH_CHECK(block_cu_seqs.value().dim() == 1, "the dim of block_cu_seqs must be 1");
    CHECK_SHAPE(block_cu_seqs.value(), batch + 1);
  }

  invokeUpdateOutAndLse(queue, getAtTensorPtr(out), (float *)getAtTensorPtr(lse),
                        getAtTensorPtr(block_out), (float *)getAtTensorPtr(block_lse),
                        (int32_t *)getAtTensorPtr(seq_offsets), (int32_t *)getAtTensorPtr(cu_seqs),
                        (int32_t *)getAtTensorPtr(block_cu_seqs), batch, head_num, head_size,
                        max_seq_len, block_seq_len, bs_stride, seq_stride, head_stride,
                        block_bs_stride, block_seq_stride, block_head_stride, packed, data_type);
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,75 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "utils.h"
#include "common/utils.h"
namespace tmo {
namespace torch_api {
// Translates an ATen scalar type into the corresponding CNNL data type.
// The 64-bit types are mapped onto their 32-bit CNNL counterparts
// (kLong -> INT32, kDouble -> FLOAT, kComplexDouble -> COMPLEX_FLOAT).
// Throws std::runtime_error for any scalar type not listed below.
cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type) {
  switch (data_type) {
    case at::kFloat:
    case at::kDouble:
      return CNNL_DTYPE_FLOAT;
    case at::kBFloat16:
      return CNNL_DTYPE_BFLOAT16;
    case at::kHalf:
      return CNNL_DTYPE_HALF;
    case at::kInt:
    case at::kLong:
      return CNNL_DTYPE_INT32;
    case at::kChar:
      return CNNL_DTYPE_INT8;
    case at::kByte:
      return CNNL_DTYPE_UINT8;
    case at::kBool:
      return CNNL_DTYPE_BOOL;
    case at::kShort:
      return CNNL_DTYPE_INT16;
    case at::kComplexHalf:
      return CNNL_DTYPE_COMPLEX_HALF;
    case at::kComplexFloat:
    case at::kComplexDouble:
      return CNNL_DTYPE_COMPLEX_FLOAT;
    default:
      break;
  }
  const std::string prefix("getCnnlDataType() not supported for ");
  throw std::runtime_error(prefix + c10::toString(data_type));
}
// Builds one CNNL tensor descriptor per input tensor, preserving order.
//
// @param tensors  tensors to describe; entries may be undefined.
// @return a vector with exactly tensors.size() elements. The element for an
//         undefined tensor holds a null descriptor; every other element owns
//         a descriptor configured with CNNL_LAYOUT_ARRAY, the tensor's CNNL
//         data type, its sizes and (when it has any) its strides.
// @throws std::runtime_error (from getCnnlDataType) for unsupported dtypes.
std::vector<TensorDesc> createTensorDescs(const std::initializer_list<at::Tensor> &tensors) {
  std::vector<TensorDesc> descs;
  // The final size is known up front, so allocate once.
  descs.reserve(tensors.size());

  // Iterate by const reference: copying an at::Tensor only to read its
  // metadata costs a needless reference-count round trip.
  for (const at::Tensor &tensor : tensors) {
    // Always append a slot so indices line up with the input list.
    descs.emplace_back(TensorDesc{nullptr, cnnlDestroyTensorDescriptor});
    if (!tensor.defined()) {
      continue;
    }

    cnnlTensorDescriptor_t raw_desc;
    CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&raw_desc));
    // Hand ownership to the smart pointer before any further call can fail.
    TensorDesc &desc = descs.back();
    desc.reset(raw_desc);

    const cnnlDataType_t data_type = getCnnlDataType(tensor.scalar_type());
    const auto sizes = tensor.sizes();
    const auto strides = tensor.strides();
    if (strides.size() == 0) {
      // No strides to report: use the sizes-only setter.
      CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(desc.get(), CNNL_LAYOUT_ARRAY, data_type,
                                                  sizes.size(), sizes.data()));
    } else {
      CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(desc.get(), CNNL_LAYOUT_ARRAY, data_type,
                                                    sizes.size(), sizes.data(),
                                                    strides.data()));
    }
  }
  return descs;
}
} // namespace torch_api
} // namespace tmo

View File

@@ -0,0 +1,164 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_TORCH_API_UTILS_H_
#define CSRC_TORCH_API_UTILS_H_
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include "ATen/ScalarType.h"
#include "ATen/Tensor.h"
#include "aten/cnnl/cnnlHandle.h"
#include "c10/util/Exception.h"
#include "cnnl.h"
#include "framework/core/MLUStream.h"
#include "framework/core/caching_allocator.h"
#include "framework/core/device.h"
#include "framework/core/mlu_guard.h"
#include "torch/torch.h"
#include "torch/version.h"
namespace tmo {
namespace torch_api {
// Owning handle for a CNNL tensor descriptor: a unique_ptr over the object
// that cnnlTensorDescriptor_t points to, whose deleter has the type of a
// pointer to cnnlDestroyTensorDescriptor (the deleter value itself must be
// supplied when the handle is constructed).
using TensorDesc = std::unique_ptr<std::remove_pointer_t<cnnlTensorDescriptor_t>,
decltype(&cnnlDestroyTensorDescriptor)>;
// Creates one TensorDesc per tensor in `tensors`, returned in the same order.
std::vector<TensorDesc> createTensorDescs(const std::initializer_list<at::Tensor> &tensors);
// Maps a torch scalar type to the corresponding cnnlDataType_t.
cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type);
// Reports whether `tensor` lives on an MLU device.
//
// With torch >= 2.1 the answer comes from the tensor's device
// (`device().is_privateuseone()`); with older versions it comes from the
// tensor's own `is_mlu()` member.
template <typename T>
bool isMlu(const T &tensor) {
#if (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1) || TORCH_VERSION_MAJOR > 2
  const bool on_mlu = tensor.device().is_privateuseone();
#else
  const bool on_mlu = tensor.is_mlu();
#endif
  return on_mlu;
}
// Selects which tensor attribute(s) checkTensorAttr() compares:
// DEVICE -> device id only, DTYPE -> scalar type only, ALL -> both.
enum class TensorAttr { DEVICE, DTYPE, ALL };
struct attr_t {
int64_t device_id;
at::ScalarType dtype;
};
// Verifies that `tensor` is on the device recorded in `device_id`.
//
// @param device_id  in/out: -1 means no device has been recorded yet, in
//                   which case the tensor's device id is stored and the
//                   check passes trivially.
// @param tensor     tensor whose device id (tensor.get_device()) is compared.
// @throws c10::Error (via TORCH_CHECK) when the device ids differ.
inline void checkDevice(int64_t &device_id, const at::Tensor &tensor) {
  const auto tensor_device_id = tensor.get_device();
  if (device_id == -1) {
    device_id = tensor_device_id;
    return;
  }
  // The ", " before "now" keeps the first value from running into the next
  // label in the rendered message.
  TORCH_CHECK(tensor_device_id == device_id,
              "Tensor device id is not same, original device_id: ", device_id,
              ", now device_id is: ", tensor_device_id);
}
// Verifies that `tensor` has the scalar type recorded in `dtype`.
//
// @param dtype   in/out: at::ScalarType::Undefined means no dtype has been
//                recorded yet, in which case the tensor's scalar type is
//                stored and the check passes trivially.
// @param tensor  tensor whose scalar_type() is compared.
// @throws c10::Error (via TORCH_CHECK) when the scalar types differ.
inline void checkDtype(at::ScalarType &dtype, const at::Tensor &tensor) {
  const auto tensor_dtype = tensor.scalar_type();
  if (dtype == at::ScalarType::Undefined) {
    dtype = tensor_dtype;
    return;
  }
  // The ", " before "now" keeps the first value from running into the next
  // label in the rendered message.
  TORCH_CHECK(tensor_dtype == dtype, "Tensor dtype is not same. original dtype: ", dtype,
              ", now dtype is: ", tensor_dtype);
}
// Runs the comparison(s) selected by the compile-time `attr` on `tensor`,
// reading and updating the recorded values in `attr_states`.
//   DEVICE -> checkDevice only
//   DTYPE  -> checkDtype only
//   ALL    -> checkDevice, then checkDtype
template <TensorAttr attr>
inline void checkTensorAttr(attr_t &attr_states, const at::Tensor &tensor) {
  switch (attr) {
    case TensorAttr::DEVICE:
      checkDevice(attr_states.device_id, tensor);
      break;
    case TensorAttr::DTYPE:
      checkDtype(attr_states.dtype, tensor);
      break;
    case TensorAttr::ALL:
      checkDevice(attr_states.device_id, tensor);
      checkDtype(attr_states.dtype, tensor);
      break;
  }
}
// Optional-tensor overload: an empty optional or one holding an undefined
// tensor is ignored. Otherwise the tensor must be an MLU tensor and is
// compared against (or recorded into) `attr_states` for the attribute(s)
// chosen by `attr`.
//
// Only participates in overload resolution when T decays to at::Tensor.
template <TensorAttr attr,
          typename T,
          typename = typename std::enable_if<
              std::is_same<typename std::decay<T>::type, at::Tensor>::value>::type>
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const c10::optional<T> &tensor) {
  const bool usable = tensor.has_value() && tensor->defined();
  if (usable) {
    const auto &held = tensor.value();
    TORCH_CHECK(isMlu(held), "Only support mlu tensor.");
    checkTensorAttr<attr>(attr_states, held);
  }
}
// Plain-tensor overload: an undefined tensor is ignored. A defined tensor
// must be an MLU tensor and is compared against (or recorded into)
// `attr_states` for the attribute(s) chosen by `attr`.
//
// Only participates in overload resolution when T decays to at::Tensor.
template <TensorAttr attr,
          typename T,
          typename = typename std::enable_if<
              std::is_same<typename std::decay<T>::type, at::Tensor>::value>::type>
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor) {
  if (tensor.defined()) {
    TORCH_CHECK(isMlu(tensor), "Only support mlu tensor.");
    checkTensorAttr<attr>(attr_states, tensor);
  }
}
// Variadic overload: checks the first argument, then recurses on the rest so
// that every argument is compared against the same `attr_states`.
// Each single-argument call is resolved against the other overloads of this
// function, so arguments may be tensors or optional tensors.
template <TensorAttr attr, typename T, typename... Args>
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor, Args &&...args) {
// Head of the pack first ...
checkTensorSameWithSpecificAttr<attr>(attr_states, tensor);
// ... then the tail, forwarded unchanged.
checkTensorSameWithSpecificAttr<attr>(attr_states, std::forward<Args>(args)...);
}
template <TensorAttr attr, typename... Args>
void checkTensorSameAttr(Args &&...args) {
attr_t attr_states = {-1, at::ScalarType::Undefined};
checkTensorSameWithSpecificAttr<attr>(attr_states, std::forward<Args>(args)...);
}
// Converts a dtype name to the torch scalar type.
//
// Recognized names: "float", "half", "bfloat16", "int32", "int8".
// @throws std::out_of_range (from std::map::at) for any other name.
inline at::ScalarType str2TorchDtype(const std::string &type) {
  using NameToDtype = std::map<std::string, at::ScalarType>;
  static const NameToDtype kNameToDtype = {
      {"float", torch::kFloat32},
      {"half", torch::kHalf},
      {"bfloat16", torch::kBFloat16},
      {"int32", torch::kInt32},
      {"int8", torch::kInt8},
  };
  return kNameToDtype.at(type);
}
// Converts a torch scalar type to its dtype name.
//
// Recognized types: kFloat32, kHalf, kBFloat16, kInt32, kInt8.
// @return reference to the name stored in a function-local static table.
// @throws std::out_of_range (from std::map::at) for any other type.
//
// NOTE(review): the returned reference is non-const and points into the
// shared table, so a caller that writes through it changes what later
// lookups return. Treat the result as read-only.
inline std::string &torchDtype2Str(const at::ScalarType type) {
  using DtypeToName = std::map<at::ScalarType, std::string>;
  static DtypeToName name_table = {
      {torch::kFloat32, "float"},
      {torch::kHalf, "half"},
      {torch::kBFloat16, "bfloat16"},
      {torch::kInt32, "int32"},
      {torch::kInt8, "int8"},
  };
  std::string &name = name_table.at(type);
  return name;
}
// Converts a dtype name to the CNNL data type.
//
// Recognized names: "float", "half", "bfloat16", "int32", "int8".
// @throws std::out_of_range (from std::map::at) for any other name.
inline cnnlDataType_t str2CnnlDtype(const std::string &type) {
  using NameToCnnl = std::map<std::string, cnnlDataType_t>;
  static const NameToCnnl kNameToCnnl = {
      {"float", CNNL_DTYPE_FLOAT},
      {"half", CNNL_DTYPE_HALF},
      {"bfloat16", CNNL_DTYPE_BFLOAT16},
      {"int32", CNNL_DTYPE_INT32},
      {"int8", CNNL_DTYPE_INT8},
  };
  return kNameToCnnl.at(type);
}
// Fails (via TORCH_CHECK) unless `x.sizes()` equals the dimension list given
// as the remaining arguments, e.g. CHECK_SHAPE(t, batch, heads).
// The message quotes the argument expressions as written at the call site.
// `x` is parenthesized so that any expression can be passed as the tensor
// argument without operator-precedence surprises.
#define CHECK_SHAPE(x, ...)                                     \
  TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), \
              #x " must have shape (" #__VA_ARGS__ ")")
// Fails (via TORCH_CHECK) unless tensor `x` is contiguous. `x` is
// parenthesized so that any expression can be passed safely.
#define CHECK_TENSOR_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous.")
// For an optional tensor `x`: does nothing when it is empty, otherwise fails
// (via TORCH_CHECK) unless the held tensor is contiguous. `x` is
// parenthesized so that any expression can be passed safely.
//
// NOTE(review): this expands to a bare `if`, so do not use it as the
// un-braced body of an outer `if ... else`; the `else` would attach to the
// macro's own `if`. It is deliberately not wrapped in do { } while (0),
// presumably some call sites omit the trailing semicolon, as TORCH_CHECK
// call sites elsewhere in this code base do; confirm before changing.
#define CHECK_OPTIONAL_TENSOR_CONTIGUOUS(x) \
  if ((x).has_value()) TORCH_CHECK((x).value().is_contiguous(), #x " must be contiguous.")
// Returns the raw data pointer of an optional tensor.
//
// @param tensor  optional tensor; may be empty or hold an undefined tensor.
// @return tensor->data_ptr() when a defined tensor is present, otherwise
//         nullptr.
//
// An engaged optional that holds an undefined tensor is treated the same as
// an empty one. This matches the at::Tensor overload of getAtTensorPtr and
// the optional overload of checkTensorSameWithSpecificAttr, both of which
// treat an undefined tensor as "absent".
inline void *getAtTensorPtr(const c10::optional<at::Tensor> &tensor) {
  const bool usable = tensor.has_value() && tensor->defined();
  return usable ? tensor->data_ptr() : nullptr;
}
// Returns the raw data pointer of `tensor`, or nullptr when the tensor is
// undefined.
inline void *getAtTensorPtr(const at::Tensor &tensor) {
  if (!tensor.defined()) {
    return nullptr;
  }
  return tensor.data_ptr();
}
} // namespace torch_api
} // namespace tmo
#endif // CSRC_TORCH_API_UTILS_H_