forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
58
torch_mlu_ops-v1.3.2/csrc/torch_api/active.cpp
Normal file
58
torch_mlu_ops-v1.3.2/csrc/torch_api/active.cpp
Normal file
@@ -0,0 +1,58 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/moe/add_bias_activation.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
/**
 * Torch entry point that forwards to tmo::invokeGroupAddBiasActivationKernel.
 *
 * @param input              Source tensor; must have at least 2 dims and a positive last dim.
 * @param output             Destination tensor; its stride(-2) is forwarded to the kernel.
 * @param bias               Optional bias tensor, forwarded as a raw pointer.
 * @param cusum_token_count  Optional tensor; when present, num_expert is size(0) - 1.
 * @param act_mode           One of "silu", "gelu", "quick_gelu" or "swish".
 * @param is_gated           When true the last dim must be even and is halved for the kernel.
 * @param start_expert_id    Forwarded to the kernel unchanged.
 * @param expert_size        Forwarded to the kernel unchanged.
 * @param active_coef        Activation coefficient; overridden for "quick_gelu" and "silu".
 */
void active(const torch::Tensor &input,
            const torch::Tensor &output,
            const c10::optional<torch::Tensor> &bias,
            const c10::optional<torch::Tensor> &cusum_token_count,
            const std::string &act_mode,
            bool is_gated,
            int64_t start_expert_id,
            int64_t expert_size,
            double active_coef) {
  const bool mode_is_gelu = (act_mode == "gelu");
  const bool mode_is_quick_gelu = (act_mode == "quick_gelu");
  const bool mode_is_silu = (act_mode == "silu");
  TORCH_CHECK(mode_is_silu || mode_is_gelu || mode_is_quick_gelu || act_mode == "swish",
              "act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.");

  // "gelu" is the only mode with its own CNNL activation; every other accepted
  // mode is dispatched as SWISH and differs only through active_coef.
  cnnlActivationMode_t act_type = CNNL_ACTIVATION_SWISH;
  if (mode_is_gelu) {
    act_type = CNNL_ACTIVATION_GELU;
  }
  // "quick_gelu" and "silu" pin the coefficient; "gelu" and "swish" keep the caller's value.
  if (mode_is_quick_gelu) {
    active_coef = 1.702f;
  } else if (mode_is_silu) {
    active_coef = 1.0f;
  }

  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2");
  const int64_t in_channel = input.sizes().back();
  TORCH_CHECK(in_channel > 0, "in_channel > 0");
  if (is_gated) {
    TORCH_CHECK(in_channel % 2 == 0, "in_channel % 2 == 0 if is_gated is true");
  }

  // All leading dims are flattened into a token count.
  const int64_t total_tokens = input.numel() / in_channel;
  int64_t inner_size = in_channel;
  if (is_gated) {
    inner_size = in_channel / 2;
  }
  int64_t num_expert = 0;
  if (cusum_token_count.has_value()) {
    num_expert = cusum_token_count.value().size(0) - 1;
  }

  // The guard must be alive before the current stream is queried.
  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  const int64_t output_stride = output.stride(-2);
  auto data_dtype = getCnnlDataType(input.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();

  tmo::invokeGroupAddBiasActivationKernel(queue,
                                          getAtTensorPtr(output),
                                          getAtTensorPtr(input),
                                          getAtTensorPtr(bias),
                                          (int *)getAtTensorPtr(cusum_token_count),
                                          num_expert,
                                          total_tokens,
                                          inner_size,
                                          output_stride,
                                          data_dtype,
                                          is_gated,
                                          act_type,
                                          start_expert_id,
                                          expert_size,
                                          active_coef);
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
145
torch_mlu_ops-v1.3.2/csrc/torch_api/apply_rotary.cpp
Normal file
145
torch_mlu_ops-v1.3.2/csrc/torch_api/apply_rotary.cpp
Normal file
@@ -0,0 +1,145 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/rotary_embedding.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
/**
 * Validates the arguments and launches invokeRotaryEmbedding on `input`.
 *
 * The result is written through a tensor handle copied from `input`, so the
 * operation is in place (at::Tensor copies share storage).
 *
 * @param input         3-D (total_seq_len, head_num, head_size) in pack mode, or
 *                      4-D (batch_size, seq_len, head_num, head_size) in pad mode.
 * @param sin_cache     (rope_seqlen, rope_dim), or (batch, rope_seqlen, rope_dim) when dynamic_ntk.
 * @param cos_cache     Same shape as sin_cache; the sequence stride must match sin_cache.
 * @param position_ids  Optional contiguous int32 tensor of shape (total_seqlen) when
 *                      `discrete`, otherwise (batch_size).
 * @param cu_seqlens    Contiguous int32 tensor, required in pack mode, forbidden in pad mode.
 * @param interleaved   Forwarded to the kernel.
 * @param discrete      Requires position_ids; forwarded to the kernel.
 * @param dynamic_ntk   Selects the batched sin/cos cache layout.
 * @param max_seqlen    Must equal input.size(1) in pad mode and, unless discrete
 *                      position ids are used, must not exceed rope_seqlen.
 */
void apply_rotary(const torch::Tensor &input,
                  const torch::Tensor &sin_cache,
                  const torch::Tensor &cos_cache,
                  const c10::optional<torch::Tensor> &position_ids,
                  const c10::optional<torch::Tensor> &cu_seqlens,
                  bool interleaved,
                  bool discrete,
                  bool dynamic_ntk,
                  int64_t max_seqlen) {
  auto output = input;

  // 1. check device and tensor type
  checkTensorSameAttr<TensorAttr::ALL>(input, sin_cache, cos_cache);

  const bool has_position_ids = position_ids.has_value();
  const bool has_cu_seqlens = cu_seqlens.has_value();
  const int origin_device_id = input.get_device();
  void *position_ids_ptr = nullptr;
  void *cu_seqlens_ptr = nullptr;

  if (has_position_ids) {
    TORCH_CHECK(position_ids.value().dtype() == torch::kInt32,
                "position_ids type need be torch::kInt32");
    TORCH_CHECK(position_ids.value().get_device() == origin_device_id,
                "Tensor device index is not the same, original index: ", origin_device_id,
                ", now index is: ", position_ids.value().get_device());
    position_ids_ptr = position_ids.value().data_ptr();
  }

  if (has_cu_seqlens) {
    TORCH_CHECK(cu_seqlens.value().dtype() == torch::kInt32,
                "cu_seqlens type need be torch::kInt32");
    TORCH_CHECK(cu_seqlens.value().get_device() == origin_device_id,
                "Tensor device index is not the same, original index: ", origin_device_id,
                ", now index is: ", cu_seqlens.value().get_device());
    cu_seqlens_ptr = cu_seqlens.value().data_ptr();
  }

  // 2. check shape
  int total_seqlen = 0;
  int batch_size = 0;
  int head_size = input.size(-1);
  if (input.dim() == 3) {  // pack mode
    TORCH_CHECK(has_cu_seqlens,
                "input has 3 dims: (total_seq_len, head_num, head_size),"
                " which means pack mode, cu_seqlens should not be None");
    total_seqlen = input.size(0);
    batch_size = cu_seqlens.value().size(0) - 1;
  } else if (input.dim() == 4) {  // pad mode
    TORCH_CHECK(!has_cu_seqlens,
                "input has 4 dims: (batch_size, seq_len, head_num, head_size),"
                " which means pad mode, cu_seqlens should be None");
    TORCH_CHECK(max_seqlen == input.size(1),
                "input has 4 dims: (batch_size, seq_len, head_num, head_size),"
                " which means pad mode, max_seqlen must be equal to input.size(1)");
    batch_size = input.size(0);
    total_seqlen = batch_size * input.size(1);
  } else {
    TORCH_CHECK(false, "input only support 3 or 4 dims");
  }
  TORCH_CHECK(head_size <= 256, "only support input head_size <= 256");

  // With dynamic_ntk the caches carry a leading batch dimension.
  const int rope_seqlen = dynamic_ntk ? sin_cache.size(1) : sin_cache.size(0);
  const int rope_dim = dynamic_ntk ? sin_cache.size(2) : sin_cache.size(1);

  if (has_position_ids) {
    if (discrete) {
      CHECK_SHAPE(position_ids.value(), total_seqlen);
    } else {
      CHECK_SHAPE(position_ids.value(), batch_size);
    }
  } else {
    TORCH_CHECK(!discrete, "discrete must be false if position ids is null.");
  }

  // The max_seqlen bound is skipped only when discrete position ids are supplied.
  if (!(has_position_ids && discrete)) {
    TORCH_CHECK(max_seqlen <= rope_seqlen, "max_seqlen must less than or equal to rope_seqlen.");
  }

  if (dynamic_ntk) {
    CHECK_SHAPE(sin_cache, batch_size, rope_seqlen, rope_dim);
    CHECK_SHAPE(cos_cache, batch_size, rope_seqlen, rope_dim);
  } else {
    CHECK_SHAPE(sin_cache, rope_seqlen, rope_dim);
    CHECK_SHAPE(cos_cache, rope_seqlen, rope_dim);
  }

  // 3. check strides
  TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous");

  // Only one sequence stride is passed to the kernel, so sin and cos must agree on it.
  if (dynamic_ntk) {
    TORCH_CHECK(sin_cache.stride(1) == cos_cache.stride(1),
                "sin_cache second stride must be equal to cos_cache second stride");
  } else {
    TORCH_CHECK(sin_cache.stride(0) == cos_cache.stride(0),
                "sin_cache first stride must be equal to cos_cache first stride");
  }

  if (has_position_ids) {
    TORCH_CHECK(position_ids.value().is_contiguous(), "position_ids must be contiguous");
  }

  if (has_cu_seqlens) {
    TORCH_CHECK(cu_seqlens.value().is_contiguous(), "cu_seqlens must be contiguous");
  }

  // prepare inputs
  const auto dims = input.dim();
  const int64_t num_heads = input.size(dims - 2);
  const int64_t head_dim = input.size(dims - 1);
  const int64_t input_seq_stride = input.strides()[dims - 3];
  const int64_t input_head_stride = input.strides()[dims - 2];
  const int64_t output_seq_stride = output.strides()[dims - 3];
  const int64_t output_head_stride = output.strides()[dims - 2];

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  auto queue = torch_mlu::getCurMLUStream();
  auto data_type = getCnnlDataType(input.scalar_type());

  invokeRotaryEmbedding(queue, output.data_ptr(), input.data_ptr(), sin_cache.data_ptr(),
                        cos_cache.data_ptr(), static_cast<int *>(position_ids_ptr),
                        static_cast<int *>(cu_seqlens_ptr), batch_size, max_seqlen, num_heads,
                        head_dim, rope_seqlen, rope_dim,
                        dynamic_ntk ? sin_cache.strides()[1] : sin_cache.strides()[0],
                        input_seq_stride, input_head_stride, output_seq_stride,
                        output_head_stride, interleaved, discrete, dynamic_ntk, data_type);
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
151
torch_mlu_ops-v1.3.2/csrc/torch_api/attn_proj.cpp
Normal file
151
torch_mlu_ops-v1.3.2/csrc/torch_api/attn_proj.cpp
Normal file
@@ -0,0 +1,151 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
/**
 * Runs cnnlTransformerAttnProj to project `input` into Q (and optionally K / V),
 * with an optional pre-layernorm or residual input.
 *
 * @param input        2-D or 3-D contiguous tensor; a 2-D input is treated as a batch of one.
 * @param q_weight     Q projection weight; size(0) is the Q hidden size.
 * @param q_bias       Optional Q bias; its presence also sets has_bias for the whole op.
 * @param k_weight     Optional K projection weight; enables the K output.
 * @param k_bias       Optional K bias.
 * @param v_weight     Optional V projection weight; enables the V output.
 * @param v_bias       Optional V bias.
 * @param norm_weight  Optional layernorm weight; enables pre-layernorm mode.
 * @param norm_bias    Optional layernorm bias.
 * @param residual     Optional contiguous residual; cannot be combined with layernorm.
 * @param out_layout   "nthc", or "nhtc" (3-D input only) for head-major outputs.
 * @param head_size    Per-head width; only used, and must be positive, for "nhtc".
 * @param eps          Forwarded to the op descriptor.
 * @param alpha        Forwarded to the op descriptor.
 * @param beta         Forwarded to the op descriptor.
 * @param norm_out     When true the layernorm output is also allocated and returned.
 * @return {q, [k], [v], [layernorm_out]} in that order; optional entries appear only
 *         when the corresponding weight / flag is set.
 */
std::vector<at::Tensor> attention_project(const at::Tensor &input,
                                          const at::Tensor &q_weight,
                                          const c10::optional<at::Tensor> &q_bias,
                                          const c10::optional<at::Tensor> &k_weight,
                                          const c10::optional<at::Tensor> &k_bias,
                                          const c10::optional<at::Tensor> &v_weight,
                                          const c10::optional<at::Tensor> &v_bias,
                                          const c10::optional<at::Tensor> &norm_weight,
                                          const c10::optional<at::Tensor> &norm_bias,
                                          const c10::optional<at::Tensor> &residual,
                                          const std::string &out_layout,
                                          int64_t head_size,
                                          double eps,
                                          double alpha,
                                          double beta,
                                          bool norm_out) {
  // check device and dtype
  checkTensorSameAttr<TensorAttr::ALL>(input, q_weight, q_bias, k_weight, k_bias, v_weight, v_bias,
                                       norm_weight, norm_bias, residual);

  // check contiguous
  CHECK_TENSOR_CONTIGUOUS(input)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(residual)

  // check size: a 2-D input is handled as a single-batch 3-D view
  const int64_t nDim = input.dim();
  at::Tensor input_view = input;
  if (nDim == 2) {
    input_view = input.unsqueeze(0);
  }

  const bool has_k = k_weight.has_value();
  const bool has_v = v_weight.has_value();
  const int64_t n = input_view.size(0);
  const int64_t t = input_view.size(1);
  const int64_t hidden_size_q = q_weight.size(0);
  const int64_t hidden_size_k = has_k ? k_weight.value().size(0) : 0;
  const int64_t hidden_size_v = has_v ? v_weight.value().size(0) : 0;

  // check params
  bool has_bias = q_bias.has_value();
  bool has_ln = norm_weight.has_value();
  bool has_residual = residual.has_value();
  TORCH_CHECK(!(has_ln && has_residual), "cannot support layernorm and residual at the same time.");
  TORCH_CHECK(out_layout == "nthc" || (out_layout == "nhtc" && nDim == 3),
              "input must be 3-D if out_layout is 'nhtc'");
  bool trans_out = (out_layout == "nhtc");
  if (trans_out) {
    // head_size is a divisor only in the head-major layout; reject zero/negative
    // values here instead of hitting an integer division by zero below.
    TORCH_CHECK(head_size > 0, "head_size must be positive if out_layout is 'nhtc'");
  }

  // Output shape for one projection of the given hidden width.
  auto make_out_shape = [&](int64_t hidden_size) -> std::vector<int64_t> {
    if (trans_out) {
      return {n, hidden_size / head_size, t, head_size};
    }
    return {n, t, hidden_size};
  };

  const torch_mlu::mlu::MLUGuard device_guard(input_view.device());
  auto out_q = at::empty(make_out_shape(hidden_size_q), input_view.options());
  auto out_k =
      has_k ? at::empty(make_out_shape(hidden_size_k), input_view.options()) : at::Tensor();
  // V must be sized from the V weight, not from Q: the two widths can differ.
  auto out_v =
      has_v ? at::empty(make_out_shape(hidden_size_v), input_view.options()) : at::Tensor();
  // Allocate from the 3-D view so the squeeze_(0) below always restores the
  // caller's rank, including a 2-D input with a single row.
  auto out_ln = norm_out ? at::empty(input_view.sizes(), input_view.options()) : at::Tensor();

  // create tensor descs; the indices used below follow this order
  auto descs =
      createTensorDescs({input_view, q_weight, q_bias.value_or(at::Tensor()),
                         k_weight.value_or(at::Tensor()), k_bias.value_or(at::Tensor()),
                         v_weight.value_or(at::Tensor()), v_bias.value_or(at::Tensor()),
                         norm_weight.value_or(at::Tensor()), norm_bias.value_or(at::Tensor()),
                         residual.value_or(at::Tensor()), out_q, out_k, out_v, out_ln});

  // create and set attn_proj_desc
  cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode =
      has_ln ? CNNL_TRANSFORMER_PRE_LAYERNORM_NO_RESIDUAL
             : (has_residual ? CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL
                             : CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL);
  auto compute_dtype = getCnnlDataType(input_view.scalar_type());
  auto attn_proj_desc = tmo::op_desc::AttnProjDesc();
  attn_proj_desc.setDesc(layernorm_residual_mode, /*layernorm_residual_mode*/
                         compute_dtype, true,     /*has_q*/
                         has_k,                   /*has_k*/
                         has_v,                   /*has_v*/
                         has_bias, false, 0,      /*no packed*/
                         trans_out,               /*trans_out*/
                         norm_out,                /*store_layernorm out*/
                         alpha, beta,             /*alpha && beta */
                         eps /*eps*/);

  auto handle = torch_mlu::getCurrentHandle();
  // get workspace size
  size_t workspace_size = 0;
  CNNL_CHECK_FATAL(cnnlGetTransformerAttnProjWorkspaceSize(handle, attn_proj_desc, nullptr,
                                                           descs[0].get(), descs[1].get(),
                                                           descs[10].get(), &workspace_size));
  auto workspace =
      at::empty({static_cast<int64_t>(workspace_size)}, input.options().dtype(at::kByte));

  // run forward
  CNNL_CHECK_FATAL(cnnlTransformerAttnProj(
      handle, attn_proj_desc, nullptr,             /*quant_desc*/
      descs[0].get(), getAtTensorPtr(input_view),  /* input */
      descs[9].get(), getAtTensorPtr(residual),    /* residual */
      descs[1].get(), getAtTensorPtr(q_weight),    /* q weight */
      descs[3].get(), getAtTensorPtr(k_weight),    /* k weight */
      descs[5].get(), getAtTensorPtr(v_weight),    /* v weight */
      descs[2].get(), getAtTensorPtr(q_bias),      /* q bias */
      descs[4].get(), getAtTensorPtr(k_bias),      /* k bias */
      descs[6].get(), getAtTensorPtr(v_bias),      /* v bias */
      nullptr, nullptr,                            /*no valid token*/
      descs[7].get(), getAtTensorPtr(norm_weight), /* layernorm weight */
      descs[8].get(), getAtTensorPtr(norm_bias),   /* layernorm bias */
      getAtTensorPtr(workspace), workspace_size,   /* workspace */
      descs[10].get(), getAtTensorPtr(out_q),      /* q out */
      descs[11].get(), getAtTensorPtr(out_k),      /* k out */
      descs[12].get(), getAtTensorPtr(out_v),      /* v out */
      descs[13].get(), getAtTensorPtr(out_ln)      /* layernorm out */
      ));

  // return: drop the synthetic batch dim again for 2-D callers
  if (nDim == 2) {
    out_q.squeeze_(0);
    if (has_k) out_k.squeeze_(0);
    if (has_v) out_v.squeeze_(0);
    if (norm_out) out_ln.squeeze_(0);
  }
  std::vector<at::Tensor> output_list;
  output_list.reserve(4);
  output_list.emplace_back(out_q);
  if (has_k) output_list.emplace_back(out_k);
  if (has_v) output_list.emplace_back(out_v);
  if (norm_out) output_list.emplace_back(out_ln);
  return output_list;
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
91
torch_mlu_ops-v1.3.2/csrc/torch_api/batch_matmul.cpp
Normal file
91
torch_mlu_ops-v1.3.2/csrc/torch_api/batch_matmul.cpp
Normal file
@@ -0,0 +1,91 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "torch_ops_api.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
/**
 * Batched matrix multiply through cnnlBatchMatMulEx, writing into `c`.
 *
 * @param a        Contiguous left operand; size(0) is the batch.
 * @param b        Contiguous right operand; must have the same dtype as `a`.
 * @param c        Contiguous result tensor of shape (batch, m, n); alpha, beta and this
 *                 tensor are passed straight to cnnlBatchMatMulEx.
 * @param alpha    Scale passed to CNNL as float.
 * @param beta     Scale passed to CNNL as float; a non-zero value enables use_beta.
 * @param a_scale  Quantization scale of `a`, only read when `a` is int8.
 * @param b_scale  Quantization scale of `b`, only read when `a` is int8.
 * @param trans_a  Treat the last two dims of `a` as transposed.
 * @param trans_b  Treat the last two dims of `b` as transposed.
 */
void batch_matmul(const at::Tensor &a,
                  const at::Tensor &b,
                  const at::Tensor &c,
                  double alpha,
                  double beta,
                  double a_scale,
                  double b_scale,
                  bool trans_a,
                  bool trans_b) {
  bool use_beta = (beta != 0);
  int batch = a.size(0);
  int m = trans_a ? a.size(-1) : a.size(-2);
  int n = trans_b ? b.size(-2) : b.size(-1);
  checkTensorSameAttr<TensorAttr::DEVICE>(a, b, c);

  // All three operands are handed to CNNL as dense buffers.
  TORCH_CHECK(a.is_contiguous(), "a must be contiguous.");
  TORCH_CHECK(b.is_contiguous(), "b must be contiguous.");
  TORCH_CHECK(c.is_contiguous(), "c must be contiguous.");
  CHECK_SHAPE(c, batch, m, n);

  // Resolve CNNL dtypes; bfloat16 inputs use a float on-chip dtype for c.
  auto lhs_dtype = getCnnlDataType(a.scalar_type());
  auto rhs_dtype = getCnnlDataType(b.scalar_type());
  TORCH_CHECK(lhs_dtype == rhs_dtype, "a, b must be same dtype.");
  auto out_dtype = getCnnlDataType(c.scalar_type());
  if (lhs_dtype == CNNL_DTYPE_BFLOAT16) {
    out_dtype = CNNL_DTYPE_FLOAT;
  }

  // Tensor descriptors, in the order {a, b, c}.
  auto descs = createTensorDescs({a, b, c});
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), lhs_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), rhs_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[2].get(), out_dtype));

  if (lhs_dtype == CNNL_DTYPE_INT8) {
    // Convert each plain scale into the (position, scale) pair the descriptor takes.
    const float int_max = 127.0;
    const int quant_bit = 8;
    const float max_a = int_max / a_scale;
    const float max_b = int_max / b_scale;
    const int pos_a = std::floor(std::log2(max_a) - (quant_bit - 2));
    const int pos_b = std::floor(std::log2(max_b) - (quant_bit - 2));
    const float new_a_scale = std::pow(2.0f, pos_a) * a_scale;
    const float new_b_scale = std::pow(2.0f, pos_b) * b_scale;
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[0].get(), pos_a, new_a_scale));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[1].get(), pos_b, new_b_scale));
    TORCH_CHECK(out_dtype != CNNL_DTYPE_BFLOAT16,
                "output dtype cannot be bfloat16 when a/b is fixed-point");
  }

  // Op descriptor; heuristic_result and algo are handed to the descriptor constructor.
  cnnlMatMulHeuristicResult_t heuristic_result;
  cnnlMatMulAlgo_t algo;
  auto bmm_ex_desc =
      tmo::op_desc::BatchMatMulDesc(heuristic_result, algo, use_beta, trans_a, trans_b);

  const torch_mlu::mlu::MLUGuard device_guard(a.device());
  auto handle = torch_mlu::getCurrentHandle();

  // Ask for a single heuristic candidate and size the workspace from it.
  size_t workspace_size = 0;
  const int requested_algo_count = 1;
  int returned_algo_count = 0;
  CNNL_CHECK_FATAL(cnnlGetBatchMatMulExAlgoHeuristic(
      handle, bmm_ex_desc, descs[0].get(), descs[1].get(), descs[2].get(), nullptr,
      requested_algo_count, &heuristic_result, &returned_algo_count));
  CNNL_CHECK_FATAL(cnnlGetBatchMatMulExHeuristicResult(heuristic_result, algo, &workspace_size));
  auto workspace = at::empty({static_cast<int64_t>(workspace_size)}, a.options().dtype(at::kByte));

  // CNNL takes the scaling factors as float pointers.
  float alpha_value = alpha;
  float beta_value = beta;
  CNNL_CHECK_FATAL(cnnlBatchMatMulEx(handle, bmm_ex_desc, algo, &alpha_value, descs[0].get(),
                                     getAtTensorPtr(a), descs[1].get(), getAtTensorPtr(b),
                                     &beta_value, descs[2].get(), getAtTensorPtr(c),
                                     getAtTensorPtr(workspace), workspace_size));
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
45
torch_mlu_ops-v1.3.2/csrc/torch_api/cnpx.cpp
Normal file
45
torch_mlu_ops-v1.3.2/csrc/torch_api/cnpx.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "cnpx.h"
|
||||
#include <sstream>
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Profiling domain that all TMO ranges are pushed to / popped from.
static cnpxDomainHandle_t domain = cnpxDomainCreate("CNPERF_KERNEL_TMO");
// Range annotation is enabled whenever CNPERF_KERNEL_ANALYSIS is set in the
// environment, regardless of its value.
static bool cnperf_kernel_analysis = (getenv("CNPERF_KERNEL_ANALYSIS") != nullptr);
|
||||
|
||||
void cnpxPush(const OpTheory &op) {
|
||||
if (cnperf_kernel_analysis) {
|
||||
size_t calc = op.getTheoryCalc();
|
||||
cnnlDataType_t calc_dtype = op.getCalcDtype();
|
||||
size_t io = op.getTheoryIO();
|
||||
auto op_name = op.getOpName();
|
||||
std::ostringstream jsonStream;
|
||||
jsonStream << "{\"name\":\"" << op_name << "\", \"theo_calc\":" << calc
|
||||
<< ", \"theo_bytes\":" << io << ", \"calc_type\":" << calc_dtype << "}";
|
||||
std::string json = jsonStream.str();
|
||||
// std::cout << json.c_str() << std::endl;
|
||||
cnpxDomainRangePush(domain, json.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void cnpxPop() {
|
||||
if (cnperf_kernel_analysis) {
|
||||
cnpxDomainRangePop(domain);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,39 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "comm_overlap.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
#define CNCL_TYPE_AND_SCALAR_TYPE(_) \
|
||||
_(cnclFloat, at::kFloat) \
|
||||
_(cnclBfloat16, at::kBFloat16) \
|
||||
_(cnclHalf, at::kHalf) \
|
||||
_(cnclInt32, at::kInt) \
|
||||
_(cnclInt64, at::kLong) \
|
||||
_(cnclInt8, at::kChar) \
|
||||
_(cnclUint8, at::kByte) \
|
||||
_(cnclInt16, at::kShort)
|
||||
|
||||
cnclDataType_t getCnclDataType(const at::ScalarType &data_type) {
|
||||
switch (data_type) {
|
||||
#define DEFINE_CASE(cncl_dtype, scalar_type) \
|
||||
case scalar_type: \
|
||||
return cncl_dtype;
|
||||
|
||||
CNCL_TYPE_AND_SCALAR_TYPE(DEFINE_CASE)
|
||||
#undef DEFINE_CASE
|
||||
default:
|
||||
std::string msg("getCnclDataType() not supported for ");
|
||||
throw std::runtime_error(msg + c10::toString(data_type));
|
||||
}
|
||||
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,69 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_
|
||||
#define CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
#include "cncl.h"
|
||||
#include "framework/core/MLUEvent.h"
|
||||
#include "framework/core/MLUStream.h"
|
||||
#include "torch/extension.h"
|
||||
#include "torch_api/utils.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
cnclDataType_t getCnclDataType(const at::ScalarType &data_type);
|
||||
|
||||
template <typename MmOp>
|
||||
struct ParallelAllReduce {
|
||||
static std::vector<torch_mlu::MLUEvent> event_;
|
||||
static std::optional<torch_mlu::MLUStream> stream_;
|
||||
cnclComm_t cncl_comm_;
|
||||
std::vector<at::Tensor> d_list_;
|
||||
|
||||
ParallelAllReduce(int64_t cncl_comm) { cncl_comm_ = reinterpret_cast<cnclComm_t>(cncl_comm); }
|
||||
|
||||
at::Tensor operator()(MmOp &mm) {
|
||||
auto compute_stream_ = torch_mlu::getCurrentMLUStream();
|
||||
if (!stream_.has_value()) {
|
||||
stream_.emplace(torch_mlu::getStreamFromPool());
|
||||
}
|
||||
auto comm_stream_ = stream_.value();
|
||||
uint64_t loop = mm.getLoopNum();
|
||||
if (event_.size() < loop + 1) {
|
||||
event_.resize(loop + 1);
|
||||
}
|
||||
for (uint64_t i = 0; i < loop; i++) {
|
||||
d_list_.push_back(mm.forward(i));
|
||||
event_[i].place(compute_stream_);
|
||||
event_[i].wait(comm_stream_);
|
||||
// reduce_sum d_send
|
||||
CNCL_CHECK(cnclAllReduce(getAtTensorPtr(d_list_[i]), getAtTensorPtr(d_list_[i]),
|
||||
d_list_[i].numel(), getCnclDataType(d_list_[i].scalar_type()),
|
||||
cnclSum, cncl_comm_, comm_stream_.stream()));
|
||||
}
|
||||
event_[loop].place(comm_stream_);
|
||||
event_[loop].wait(compute_stream_);
|
||||
return mm.getOutput();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename MmOp>
|
||||
std::vector<torch_mlu::MLUEvent> ParallelAllReduce<MmOp>::event_;
|
||||
|
||||
template <typename MmOp>
|
||||
std::optional<torch_mlu::MLUStream> ParallelAllReduce<MmOp>::stream_;
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
#endif // CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_
|
||||
@@ -0,0 +1,180 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "comm_overlap.h"
|
||||
#include "torch_api/torch_ops_api.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
using namespace torch::indexing;
|
||||
|
||||
struct FlashAttnSmoothQuantMatmul {
|
||||
at::Tensor q_;
|
||||
at::Tensor k_;
|
||||
at::Tensor v_;
|
||||
const c10::optional<at::Tensor> &cu_seq_lens_q_;
|
||||
const c10::optional<at::Tensor> &cu_seq_lens_kv_;
|
||||
const at::Tensor &smooth_;
|
||||
const at::Tensor &weight_;
|
||||
const at::Tensor &weight_scale_;
|
||||
const c10::optional<at::Tensor> &bias_;
|
||||
const std::string &compute_dtype_;
|
||||
std::string input_dtype_str_;
|
||||
int64_t max_seq_len_q_;
|
||||
int64_t max_seq_len_kv_;
|
||||
double softmax_scale_;
|
||||
bool is_causal_;
|
||||
int64_t block_seq_;
|
||||
std::vector<int> cumsum_seq_;
|
||||
at::Tensor output_;
|
||||
|
||||
FlashAttnSmoothQuantMatmul(const at::Tensor &q,
|
||||
const at::Tensor &k,
|
||||
const at::Tensor &v,
|
||||
const c10::optional<at::Tensor> &cu_seq_lens_q,
|
||||
const c10::optional<at::Tensor> &cu_seq_lens_kv,
|
||||
const at::Tensor &smooth,
|
||||
const at::Tensor &weight,
|
||||
const at::Tensor &weight_scale,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const int64_t max_seq_len_q,
|
||||
const int64_t max_seq_len_kv,
|
||||
const double softmax_scale,
|
||||
const bool is_causal,
|
||||
const std::string &compute_dtype,
|
||||
const int64_t block_seq)
|
||||
: q_(q),
|
||||
k_(k),
|
||||
v_(v),
|
||||
cu_seq_lens_q_(cu_seq_lens_q),
|
||||
cu_seq_lens_kv_(cu_seq_lens_kv),
|
||||
smooth_(smooth),
|
||||
weight_(weight),
|
||||
weight_scale_(weight_scale),
|
||||
bias_(bias),
|
||||
compute_dtype_(compute_dtype),
|
||||
max_seq_len_q_(max_seq_len_q),
|
||||
max_seq_len_kv_(max_seq_len_kv),
|
||||
softmax_scale_(softmax_scale),
|
||||
is_causal_(is_causal) {
|
||||
input_dtype_str_ = q.scalar_type() == at::kBFloat16 ? "bfloat16"
|
||||
: q.scalar_type() == at::kHalf ? "half"
|
||||
: "float";
|
||||
bool is_pack = cu_seq_lens_q.has_value();
|
||||
if (is_pack) {
|
||||
TORCH_CHECK(cu_seq_lens_q.value().size(0) == 2 && cu_seq_lens_kv.value().size(0) == 2,
|
||||
"only support 1 batch.")
|
||||
TORCH_CHECK(q.dim() == 3 && k.dim() == 3 && v.dim() == 3, "q,k,v must be 3-d in pack mode.")
|
||||
// 1 batch pack to pad
|
||||
q_ = q_.unsqueeze(0);
|
||||
k_ = k_.unsqueeze(0);
|
||||
v_ = v_.unsqueeze(0);
|
||||
} else {
|
||||
TORCH_CHECK(q.size(0) == 1, "only support 1 batch.")
|
||||
TORCH_CHECK(q.dim() == 4, "q must be 4-d in pad mode.")
|
||||
}
|
||||
int total_seq_q = q_.size(0) * q_.size(1);
|
||||
output_ = at::empty({total_seq_q, weight.size(0)}, q.options());
|
||||
if (total_seq_q >= 4096) {
|
||||
block_seq_ = 4;
|
||||
split_4(total_seq_q);
|
||||
} else {
|
||||
block_seq_ = 1;
|
||||
split_1(total_seq_q);
|
||||
}
|
||||
}
|
||||
|
||||
void split_1(int seq) {
|
||||
cumsum_seq_.resize(2);
|
||||
cumsum_seq_[0] = 0;
|
||||
cumsum_seq_[1] = seq;
|
||||
}
|
||||
|
||||
// Partitions `seq` rows into four consecutive blocks and stores the five
// boundaries in cumsum_seq_.
// Blocks 3 and 4 each take seq / 8 rounded up to a multiple of 256, block 2
// takes seq / 4 rounded up to a multiple of 256, and block 1 receives
// whatever is left, so the four lengths always sum to `seq`.
void split_4(int seq) {
  constexpr int kAlign = 256;
  const auto round_up = [](int value, int multiple) -> int {
    return (value + multiple - 1) / multiple * multiple;
  };

  const int eighth = round_up(seq / 8, kAlign);
  const int quarter = round_up(seq / 4, kAlign);
  const int head = seq - quarter - eighth - eighth;

  // Block lengths in block order; their prefix sums are the boundaries.
  const int lengths[4] = {head, quarter, eighth, eighth};
  cumsum_seq_.resize(5);
  cumsum_seq_[0] = 0;
  for (int blk = 0; blk < 4; ++blk) {
    cumsum_seq_[blk + 1] = cumsum_seq_[blk] + lengths[blk];
  }
}
|
||||
|
||||
// Returns the view of `a` restricted to the half-open range [start, end)
// along dimension `dim` (implemented with Tensor::narrow).
auto split_tensor(const at::Tensor &a, int64_t dim, int64_t start, int64_t end) {
  const int64_t length = end - start;
  return a.narrow(dim, start, length);
}
|
||||
|
||||
// Runs the whole pipeline for sequence block `block_id`:
//   1. flash attention on the query rows of this block,
//   2. per-token smooth quantization of the attention output,
//   3. quantized matmul written into the matching rows of output_.
// Returns the tensor produced by quant_matmul for this block.
at::Tensor forward(const int64_t block_id) {
  const int64_t q_begin = cumsum_seq_[block_id];
  const int64_t q_end = cumsum_seq_[block_id + 1];
  // Causal: keys/values stop at the end of this query block.
  // Non-causal: keys/values span every block.
  const int64_t kv_end = is_causal_ ? q_end : cumsum_seq_[block_seq_];

  // --- flash attention -----------------------------------------------------
  auto q_blk = split_tensor(q_, 1, q_begin, q_end);
  auto k_blk = split_tensor(k_, 1, 0, kv_end);
  auto v_blk = split_tensor(v_, 1, 0, kv_end);
  auto attn_out = at::empty_like(q_blk);
  flash_attention(q_blk, k_blk, v_blk, attn_out,
                  c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt,
                  c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt,
                  q_blk.size(1), k_blk.size(1), softmax_scale_, is_causal_, -1, -1,
                  compute_dtype_, false);

  // --- smooth quant --------------------------------------------------------
  // (b, s, hn, hs) -> (b*s, hn*hs)
  auto tokens = attn_out.flatten(-2, -1).flatten(0, 1);
  const auto int8_options =
      torch::TensorOptions().dtype(torch::kInt8).device(tokens.device());
  auto quant_out = at::empty(tokens.sizes(), int8_options);
  // One scale per token row.
  auto quant_out_scale = at::empty({tokens.size(0)}, smooth_.options());
  smooth_quant(tokens, smooth_, quant_out, quant_out_scale,
               c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt,
               "per_token", true);

  // --- quant matmul --------------------------------------------------------
  // The destination is the slice of output_ that shares this block's row range.
  auto out_rows = split_tensor(output_, 0, q_begin, q_end);
  return quant_matmul(quant_out, quant_out_scale, c10::nullopt, weight_, weight_scale_,
                      c10::nullopt, bias_, c10::nullopt, c10::nullopt, c10::nullopt,
                      weight_scale_, c10::nullopt, input_dtype_str_, out_rows,
                      "smooth_quant", "quantize_per_token", "quantize_per_channel", 8,
                      "none", false, 1.0, 1.0, 1.0, false, true);
}
|
||||
|
||||
// Number of sequence blocks, i.e. how many times forward() must be invoked
// (block ids 0 .. getLoopNum() - 1).
int64_t getLoopNum() const { return block_seq_; }
|
||||
// Returns the full output tensor; forward() hands row slices of it to
// quant_matmul as the destination of each block.
at::Tensor getOutput() const { return output_; }
|
||||
};
|
||||
|
||||
// Entry point: flash attention + smooth quant + quant matmul, executed block by
// block through ParallelAllReduce on communicator `cncl_comm`.
//
// NOTE(review): alibi_slope, attn_bias, window_size_left and window_size_right
// are part of the signature but are not forwarded to FlashAttnSmoothQuantMatmul,
// so they currently have no effect on the result.
// NOTE(review): ParallelAllReduce is declared in comm_overlap.h; presumably it
// calls forward() for every block and all-reduces the results - confirm there.
at::Tensor flash_attn_sq_mm_allreduce(const int64_t cncl_comm,
                                      const at::Tensor &q,
                                      const at::Tensor &k,
                                      const at::Tensor &v,
                                      const c10::optional<at::Tensor> &cu_seq_lens_q,
                                      const c10::optional<at::Tensor> &cu_seq_lens_kv,
                                      const c10::optional<at::Tensor> &alibi_slope,
                                      const c10::optional<at::Tensor> &attn_bias,
                                      const at::Tensor &smooth,
                                      const at::Tensor &weight,
                                      const at::Tensor &weight_scale,
                                      const c10::optional<at::Tensor> &bias,
                                      const int64_t max_seq_len_q,
                                      const int64_t max_seq_len_kv,
                                      const double softmax_scale,
                                      const bool is_causal,
                                      const int64_t window_size_left,
                                      const int64_t window_size_right,
                                      const std::string &compute_dtype,
                                      const int64_t block_seq) {
  // Per-block compute functor.
  FlashAttnSmoothQuantMatmul block_op(q, k, v, cu_seq_lens_q, cu_seq_lens_kv, smooth, weight,
                                      weight_scale, bias, max_seq_len_q, max_seq_len_kv,
                                      softmax_scale, is_causal, compute_dtype, block_seq);
  // Driver bound to the given communicator.
  ParallelAllReduce<FlashAttnSmoothQuantMatmul> runner(cncl_comm);
  return runner(block_op);
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,192 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "comm_overlap.h"
|
||||
#include "kernels/moe/combine_result.mluh"
|
||||
#include "torch_api/torch_ops_api.h"
|
||||
#include "torch_api/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
using namespace torch::indexing;
|
||||
|
||||
struct GroupGemmCombineResultSplitN {
|
||||
const at::Tensor &a_;
|
||||
const at::Tensor &b_;
|
||||
const at::Tensor &m_list_;
|
||||
const at::Tensor &combine_idx_;
|
||||
const at::Tensor &combine_weight_;
|
||||
const c10::optional<at::Tensor> &c_;
|
||||
const c10::optional<at::Tensor> &alpha_;
|
||||
const c10::optional<at::Tensor> &beta_;
|
||||
const c10::optional<at::Tensor> &a_scale_;
|
||||
const c10::optional<at::Tensor> &b_scale_;
|
||||
const c10::optional<std::string> &data_type_;
|
||||
int64_t num_token_;
|
||||
int64_t topk_;
|
||||
int64_t block_n_;
|
||||
int64_t expert_num_;
|
||||
int64_t n_;
|
||||
bool has_c_;
|
||||
bool has_alpha_;
|
||||
bool has_beta_;
|
||||
bool has_a_scale_;
|
||||
bool has_b_scale_;
|
||||
bool has_data_type_;
|
||||
int64_t n_per_blk_;
|
||||
cnrtQueue_t queue_;
|
||||
at::Tensor output_buff_;
|
||||
at::Tensor b_scale_trans_;
|
||||
c10::optional<at::Tensor> b_offset_ = c10::nullopt;
|
||||
cnnlDataType_t cnnl_dtype_;
|
||||
|
||||
GroupGemmCombineResultSplitN(const at::Tensor &a_tensor,
|
||||
const at::Tensor &b_tensor,
|
||||
const at::Tensor &m_list,
|
||||
const at::Tensor &combine_idx,
|
||||
const at::Tensor &combine_weight,
|
||||
const c10::optional<at::Tensor> &c_tensor,
|
||||
const c10::optional<at::Tensor> &alpha,
|
||||
const c10::optional<at::Tensor> &beta,
|
||||
const c10::optional<at::Tensor> &a_scale,
|
||||
const c10::optional<at::Tensor> &b_scale,
|
||||
const c10::optional<std::string> &data_type,
|
||||
const int64_t num_token,
|
||||
const int64_t topk,
|
||||
const int64_t block_n)
|
||||
: a_(a_tensor),
|
||||
b_(b_tensor),
|
||||
m_list_(m_list),
|
||||
combine_idx_(combine_idx),
|
||||
combine_weight_(combine_weight),
|
||||
c_(c_tensor),
|
||||
alpha_(alpha),
|
||||
beta_(beta),
|
||||
a_scale_(a_scale),
|
||||
b_scale_(b_scale),
|
||||
data_type_(data_type),
|
||||
num_token_(num_token),
|
||||
topk_(topk),
|
||||
block_n_(block_n) {
|
||||
has_c_ = c_.has_value();
|
||||
has_alpha_ = alpha_.has_value();
|
||||
has_beta_ = beta_.has_value();
|
||||
has_a_scale_ = a_scale_.has_value();
|
||||
has_b_scale_ = b_scale_.has_value();
|
||||
has_data_type_ = data_type_.has_value();
|
||||
auto b_shape = b_.sizes();
|
||||
expert_num_ = b_.dim() == 3 ? b_shape[0] : b_shape[0] * b_shape[2];
|
||||
n_ = b_shape[1];
|
||||
auto k = b_.size(-1);
|
||||
auto m = a_.size(0);
|
||||
set_block(m);
|
||||
|
||||
queue_ = torch_mlu::getCurMLUStream();
|
||||
auto a_options = a_.options();
|
||||
if (has_data_type_) {
|
||||
auto dtype = data_type_.value();
|
||||
TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16",
|
||||
"data type must be 'float', 'half' or 'bfloat16'");
|
||||
auto torch_dtype = str2TorchDtype(dtype);
|
||||
cnnl_dtype_ = str2CnnlDtype(dtype);
|
||||
output_buff_ = at::empty({block_n_, num_token_, n_per_blk_}, a_options.dtype(torch_dtype));
|
||||
} else {
|
||||
output_buff_ = at::empty({block_n_, num_token_, n_per_blk_}, a_options);
|
||||
cnnl_dtype_ = getCnnlDataType(a_.scalar_type());
|
||||
}
|
||||
if (has_b_scale_) {
|
||||
b_scale_trans_ = b_scale_.value().view({expert_num_, block_n_, -1});
|
||||
b_scale_trans_ = b_scale_trans_.transpose(0, 1).contiguous();
|
||||
}
|
||||
|
||||
auto b_offset = at::empty({expert_num_}, a_options.dtype(at::kLong).device(at::kCPU));
|
||||
auto element_size = b_.element_size();
|
||||
if (block_n_ > 1 && b_.dim() == 3) {
|
||||
for (int64_t i = 0; i < expert_num_; i++) {
|
||||
b_offset[i] = n_ * k * i * element_size;
|
||||
}
|
||||
b_offset_ = b_offset;
|
||||
} else if (b_.dim() == 4) {
|
||||
for (int64_t i = 0; i < expert_num_; i++) {
|
||||
b_offset[i] = (i / b_shape[2] * b_.stride(0) + i % b_shape[2] * b_shape[3]) * element_size;
|
||||
}
|
||||
b_offset_ = b_offset;
|
||||
}
|
||||
}
|
||||
|
||||
auto split(const at::Tensor &a, int64_t dim, int64_t start, int64_t block) {
|
||||
return a.narrow(dim, start, block);
|
||||
}
|
||||
|
||||
at::Tensor forward(const int64_t block_id) {
|
||||
auto gg_o =
|
||||
group_gemm(a_, split(b_, 1, block_id * n_per_blk_, n_per_blk_), m_list_, c10::nullopt,
|
||||
(has_c_) ? split(c_.value(), 1, block_id * n_per_blk_, n_per_blk_) : c_, alpha_,
|
||||
beta_, a_scale_, has_b_scale_ ? b_scale_trans_[block_id] : b_scale_,
|
||||
c10::nullopt, data_type_, c10::nullopt, b_offset_, num_token_);
|
||||
|
||||
tmo::invokeMoeCombineResultKernel(
|
||||
queue_, getAtTensorPtr(output_buff_[block_id]), getAtTensorPtr(gg_o), nullptr, nullptr,
|
||||
(float *)getAtTensorPtr(combine_weight_), nullptr, (int *)getAtTensorPtr(combine_idx_),
|
||||
num_token_, topk_, expert_num_, n_per_blk_, 0, expert_num_, cnnl_dtype_);
|
||||
return output_buff_[block_id];
|
||||
}
|
||||
|
||||
void set_block(int64_t m) {
|
||||
if (block_n_ < 1) {
|
||||
if (m >= 4096 && m < 8192 && n_ % 2048 == 0) {
|
||||
n_per_blk_ = 2048;
|
||||
block_n_ = n_ / n_per_blk_;
|
||||
} else if (m >= 8192 && n_ % 1024 == 0) {
|
||||
n_per_blk_ = 1024;
|
||||
block_n_ = n_ / n_per_blk_;
|
||||
} else {
|
||||
n_per_blk_ = n_;
|
||||
block_n_ = 1;
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(n_ % block_n_ == 0, "n must be divisible by block_n");
|
||||
n_per_blk_ = n_ / block_n_;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t getLoopNum() const { return block_n_; }
|
||||
at::Tensor getOutput() const { return output_buff_.transpose(0, 1).reshape({num_token_, n_}); }
|
||||
};
|
||||
|
||||
// Entry point: grouped GEMM + MoE combine-result, chunked along N and driven by
// ParallelAllReduce on communicator `cncl_comm`.
// block_n < 1 lets GroupGemmCombineResultSplitN pick the number of chunks.
// NOTE(review): ParallelAllReduce is declared in comm_overlap.h; presumably it
// all-reduces each chunk as it is produced - confirm there.
at::Tensor group_gemm_combine_result_allreduce(int64_t cncl_comm,
                                               const at::Tensor &a_tensor,
                                               const at::Tensor &b_tensor,
                                               const at::Tensor &m_list,
                                               const at::Tensor &combine_idx,
                                               const at::Tensor &combine_weight,
                                               const c10::optional<at::Tensor> &c_tensor,
                                               const c10::optional<at::Tensor> &alpha,
                                               const c10::optional<at::Tensor> &beta,
                                               const c10::optional<at::Tensor> &a_scale,
                                               const c10::optional<at::Tensor> &b_scale,
                                               const c10::optional<std::string> &data_type,
                                               const int64_t num_token,
                                               const int64_t topk,
                                               const int64_t block_n) {
  // Per-chunk compute functor.
  GroupGemmCombineResultSplitN block_op(a_tensor, b_tensor, m_list, combine_idx, combine_weight,
                                        c_tensor, alpha, beta, a_scale, b_scale, data_type,
                                        num_token, topk, block_n);
  // Driver bound to the given communicator.
  ParallelAllReduce<GroupGemmCombineResultSplitN> runner(cncl_comm);
  return runner(block_op);
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,94 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "comm_overlap.h"
|
||||
#include "torch_api/torch_ops_api.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
using namespace torch::indexing;
|
||||
|
||||
struct MatmulSplitM {
|
||||
const at::Tensor &a_;
|
||||
const at::Tensor &b_;
|
||||
const c10::optional<at::Tensor> &bias_;
|
||||
const c10::optional<at::Tensor> &c_;
|
||||
const c10::optional<at::Tensor> &d_;
|
||||
double alpha_;
|
||||
double beta_;
|
||||
int64_t block_m_;
|
||||
bool has_res_;
|
||||
bool has_output_;
|
||||
at::Tensor output_;
|
||||
|
||||
MatmulSplitM(const at::Tensor &a,
|
||||
const at::Tensor &b,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const c10::optional<at::Tensor> &c,
|
||||
const c10::optional<at::Tensor> &d,
|
||||
const double alpha,
|
||||
const double beta,
|
||||
const int64_t block_m)
|
||||
: a_(a), b_(b), bias_(bias), c_(c), d_(d), alpha_(alpha), beta_(beta), block_m_(block_m) {
|
||||
auto m = a_.size(0);
|
||||
if (block_m_ > m) {
|
||||
block_m_ = 1;
|
||||
} else if (block_m_ < 1) {
|
||||
block_m_ = m > 8192 ? 8 : (m < 2048 ? 1 : 4);
|
||||
}
|
||||
|
||||
has_res_ = c_.has_value();
|
||||
has_output_ = d_.has_value();
|
||||
if (has_output_) {
|
||||
output_ = d_.value();
|
||||
} else {
|
||||
output_ = at::empty({a.size(0), b.size(0)}, a.options());
|
||||
}
|
||||
}
|
||||
|
||||
auto split(const at::Tensor &a, const int64_t block_id) {
|
||||
auto m = a.size(0);
|
||||
auto m_per_blk = m / block_m_;
|
||||
auto remain = m % block_m_;
|
||||
auto start = block_id * m_per_blk + std::min(block_id, remain);
|
||||
auto end = (block_id + 1) * m_per_blk + std::min(block_id + 1, remain);
|
||||
return a.narrow(0, start, end - start);
|
||||
}
|
||||
|
||||
at::Tensor forward(const int64_t block_id) {
|
||||
auto Di = split(output_, block_id);
|
||||
matmul(split(a_, block_id), b_, Di, bias_, has_res_ ? split(c_.value(), block_id) : c_, None,
|
||||
"none", alpha_, beta_, true, true, 1.0, 1.0, false, true);
|
||||
return Di;
|
||||
}
|
||||
|
||||
int64_t getLoopNum() const { return block_m_; }
|
||||
at::Tensor getOutput() const { return output_; }
|
||||
at::Tensor getDSplit(const int64_t block_id) { return split(output_, block_id); }
|
||||
};
|
||||
|
||||
// Entry point: matmul split along M and driven by ParallelAllReduce on
// communicator `cncl_comm`. block_m < 1 lets MatmulSplitM choose the number of
// row blocks from the row count of `a`.
// NOTE(review): ParallelAllReduce is declared in comm_overlap.h; presumably it
// all-reduces each row block as it is produced - confirm there.
at::Tensor matmul_allreduce(const int64_t cncl_comm,
                            const at::Tensor &a,
                            const at::Tensor &b,
                            const c10::optional<at::Tensor> &bias,
                            const c10::optional<at::Tensor> &c,
                            const c10::optional<at::Tensor> &d,
                            const double alpha,
                            const double beta,
                            const int64_t block_m) {
  // Per-block compute functor.
  MatmulSplitM block_op(a, b, bias, c, d, alpha, beta, block_m);
  // Driver bound to the given communicator.
  ParallelAllReduce<MatmulSplitM> runner(cncl_comm);
  return runner(block_op);
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,198 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
#include "comm_overlap.h"
|
||||
#include "torch_api/torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
using namespace torch::indexing;
|
||||
|
||||
struct QuantMatmulSplitM {
|
||||
const at::Tensor &a_;
|
||||
const c10::optional<at::Tensor> &a_scale_;
|
||||
const c10::optional<at::Tensor> &a_zero_;
|
||||
const at::Tensor &b_;
|
||||
const c10::optional<at::Tensor> &b_scale_;
|
||||
const c10::optional<at::Tensor> &b_zero_;
|
||||
const c10::optional<at::Tensor> &bias_;
|
||||
const c10::optional<at::Tensor> &c_;
|
||||
const c10::optional<at::Tensor> &c_scale_;
|
||||
const c10::optional<at::Tensor> &c_zero_;
|
||||
const c10::optional<at::Tensor> &output_scale_;
|
||||
const c10::optional<at::Tensor> &output_zero_;
|
||||
const c10::optional<std::string> &data_type_;
|
||||
const c10::optional<at::Tensor> &d_;
|
||||
const std::string &quant_algo_;
|
||||
const std::string &a_quant_layout_;
|
||||
const std::string &b_quant_layout_;
|
||||
int64_t quant_bit_size_;
|
||||
double alpha_;
|
||||
double beta_;
|
||||
bool trans_a_;
|
||||
bool trans_b_;
|
||||
int64_t block_m_;
|
||||
bool has_a_scale_;
|
||||
bool has_a_zero_;
|
||||
bool has_b_scale_;
|
||||
bool has_b_zero_;
|
||||
bool has_bias_;
|
||||
bool has_c_;
|
||||
bool has_c_scale_;
|
||||
bool has_c_zero_;
|
||||
bool has_output_scale_;
|
||||
bool has_output_zero_;
|
||||
bool has_output_;
|
||||
bool has_dtype_;
|
||||
at::Tensor output_;
|
||||
|
||||
QuantMatmulSplitM(const at::Tensor &a_tensor,
|
||||
const c10::optional<at::Tensor> &a_scale,
|
||||
const c10::optional<at::Tensor> &a_zero,
|
||||
const at::Tensor &b_tensor,
|
||||
const c10::optional<at::Tensor> &b_scale,
|
||||
const c10::optional<at::Tensor> &b_zero,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const c10::optional<at::Tensor> &c_tensor,
|
||||
const c10::optional<at::Tensor> &c_scale,
|
||||
const c10::optional<at::Tensor> &c_zero,
|
||||
const c10::optional<at::Tensor> &gemm_output_scale,
|
||||
const c10::optional<at::Tensor> &gemm_output_zero,
|
||||
const c10::optional<std::string> &data_type,
|
||||
const c10::optional<at::Tensor> &d,
|
||||
const std::string &quant_algo,
|
||||
const std::string &a_quant_layout,
|
||||
const std::string &b_quant_layout,
|
||||
int64_t quant_bit_size,
|
||||
double alpha,
|
||||
double beta,
|
||||
bool trans_a,
|
||||
bool trans_b,
|
||||
const int64_t block_m)
|
||||
: a_(a_tensor),
|
||||
a_scale_(a_scale),
|
||||
a_zero_(a_zero),
|
||||
b_(b_tensor),
|
||||
b_scale_(b_scale),
|
||||
b_zero_(b_zero),
|
||||
bias_(bias),
|
||||
c_(c_tensor),
|
||||
c_scale_(c_scale),
|
||||
c_zero_(c_zero),
|
||||
output_scale_(gemm_output_scale),
|
||||
output_zero_(gemm_output_zero),
|
||||
data_type_(data_type),
|
||||
d_(d),
|
||||
quant_algo_(quant_algo),
|
||||
a_quant_layout_(a_quant_layout),
|
||||
b_quant_layout_(b_quant_layout),
|
||||
quant_bit_size_(quant_bit_size),
|
||||
alpha_(alpha),
|
||||
beta_(beta),
|
||||
trans_a_(trans_a),
|
||||
trans_b_(trans_b),
|
||||
block_m_(block_m) {
|
||||
auto m = a_.size(0);
|
||||
if (block_m_ > m) {
|
||||
block_m_ = 1;
|
||||
} else if (block_m_ < 1) {
|
||||
block_m_ = m > 8192 ? 8 : (m < 2048 ? 1 : 4);
|
||||
}
|
||||
has_a_scale_ = a_scale_.has_value();
|
||||
has_a_zero_ = a_zero_.has_value();
|
||||
has_b_scale_ = b_scale_.has_value();
|
||||
has_b_zero_ = b_zero_.has_value();
|
||||
has_bias_ = bias_.has_value();
|
||||
has_c_ = c_.has_value();
|
||||
has_c_scale_ = c_scale_.has_value();
|
||||
has_c_zero_ = c_zero_.has_value();
|
||||
has_output_scale_ = output_scale_.has_value();
|
||||
has_output_zero_ = output_zero_.has_value();
|
||||
has_output_ = d_.has_value();
|
||||
has_dtype_ = data_type_.has_value();
|
||||
auto a_options = a_.options();
|
||||
if (has_output_) {
|
||||
output_ = d_.value();
|
||||
} else if (has_dtype_) {
|
||||
auto dtype = data_type_.value();
|
||||
TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16",
|
||||
"data type must be 'float', 'half' or 'bfloat16'");
|
||||
auto torch_dtype = str2TorchDtype(dtype);
|
||||
output_ = at::empty({a_.size(0), b_.size(0)}, a_options.dtype(torch_dtype));
|
||||
} else {
|
||||
output_ = at::empty({a_.size(0), b_.size(0)}, a_options);
|
||||
}
|
||||
}
|
||||
|
||||
auto split(const at::Tensor &a, const int64_t block_id) {
|
||||
auto m = a.size(0);
|
||||
auto m_per_blk = m / block_m_;
|
||||
auto remain = m % block_m_;
|
||||
auto start = block_id * m_per_blk + std::min(block_id, remain);
|
||||
auto end = (block_id + 1) * m_per_blk + std::min(block_id + 1, remain);
|
||||
return a.narrow(0, start, end - start);
|
||||
}
|
||||
|
||||
at::Tensor forward(const int64_t block_id) {
|
||||
auto Di = quant_matmul(
|
||||
split(a_, block_id), has_a_scale_ ? split(a_scale_.value(), block_id) : a_scale_,
|
||||
has_a_zero_ ? split(a_zero_.value(), block_id) : a_zero_, b_, b_scale_, b_zero_, bias_,
|
||||
has_c_ ? split(c_.value(), block_id) : c_,
|
||||
has_c_scale_ ? split(c_scale_.value(), block_id) : c_scale_,
|
||||
has_c_zero_ ? split(c_zero_.value(), block_id) : c_zero_, output_scale_, output_zero_,
|
||||
data_type_, split(output_, block_id), quant_algo_, a_quant_layout_, b_quant_layout_,
|
||||
quant_bit_size_, "none", false, 1.0, alpha_, beta_, trans_a_, trans_b_);
|
||||
return Di;
|
||||
}
|
||||
|
||||
int64_t getLoopNum() const { return block_m_; }
|
||||
at::Tensor getOutput() const { return output_; }
|
||||
};
|
||||
|
||||
// Entry point: quantized matmul split along M and driven by ParallelAllReduce
// on communicator `cncl_comm`.
// Only the layout trans_a == false, trans_b == true is accepted. block_m < 1
// lets QuantMatmulSplitM choose the number of row blocks.
// NOTE(review): ParallelAllReduce is declared in comm_overlap.h; presumably it
// all-reduces each row block as it is produced - confirm there.
at::Tensor quant_matmul_allreduce(const int64_t cncl_comm,
                                  const at::Tensor &a_tensor,
                                  const c10::optional<at::Tensor> &a_scale,
                                  const c10::optional<at::Tensor> &a_zero,
                                  const at::Tensor &b_tensor,
                                  const c10::optional<at::Tensor> &b_scale,
                                  const c10::optional<at::Tensor> &b_zero,
                                  const c10::optional<at::Tensor> &bias,
                                  const c10::optional<at::Tensor> &c_tensor,
                                  const c10::optional<at::Tensor> &c_scale,
                                  const c10::optional<at::Tensor> &c_zero,
                                  const c10::optional<at::Tensor> &gemm_output_scale,
                                  const c10::optional<at::Tensor> &gemm_output_zero,
                                  const c10::optional<std::string> &data_type,
                                  const c10::optional<at::Tensor> &d,
                                  const std::string &quant_algo,
                                  const std::string &a_quant_layout,
                                  const std::string &b_quant_layout,
                                  int64_t quant_bit_size,
                                  double alpha,
                                  double beta,
                                  bool trans_a,
                                  bool trans_b,
                                  const int64_t block_m) {
  const bool layout_supported = !trans_a && trans_b;
  TORCH_CHECK(layout_supported, "trans_a must be false and trans_b must be true");

  // Per-block compute functor.
  QuantMatmulSplitM block_op(a_tensor, a_scale, a_zero, b_tensor, b_scale, b_zero, bias, c_tensor,
                             c_scale, c_zero, gemm_output_scale, gemm_output_zero, data_type, d,
                             quant_algo, a_quant_layout, b_quant_layout, quant_bit_size, alpha,
                             beta, trans_a, trans_b, block_m);
  // Driver bound to the given communicator.
  ParallelAllReduce<QuantMatmulSplitM> runner(cncl_comm);
  return runner(block_op);
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
89
torch_mlu_ops-v1.3.2/csrc/torch_api/copy_blocks.cpp
Normal file
89
torch_mlu_ops-v1.3.2/csrc/torch_api/copy_blocks.cpp
Normal file
@@ -0,0 +1,89 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/copy_blocks.mluh"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "torch_ops_api.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
// Copies cache blocks inside every layer according to `block_mapping_dict`,
// which maps a source block number to the list of destination block numbers.
//
// k_caches : one 4-d tensor per layer; must be non-empty. Device and dtype are
//            validated on layer 0 only.
// v_caches : either empty (only keys are copied) or one tensor per layer with
//            the same dtype, rank and per-block element count as the keys.
//
// Raises (TORCH_CHECK) on empty inputs, wrong rank, unsupported dtype,
// non-MLU tensors, or mismatching key/value layouts.
void copy_blocks(const std::vector<torch::Tensor> &k_caches,
                 const std::vector<torch::Tensor> &v_caches,
                 const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping_dict) {
  // Flatten the mapping into consecutive (src, dst) int32 pairs, the layout
  // handed to the kernel.
  std::vector<int32_t> block_mapping_vec;
  for (const auto &item : block_mapping_dict) {
    const int64_t src_block_number = item.key();
    const auto value_vec = item.value().vec();
    for (const int64_t dst_block_number : value_vec) {
      block_mapping_vec.push_back(int32_t(src_block_number));
      block_mapping_vec.push_back(int32_t(dst_block_number));
    }
  }

  TORCH_CHECK(!k_caches.empty() && !block_mapping_vec.empty(),
              "k_caches and block_mapping_vec can not be empty.");
  int32_t num_layers = k_caches.size();
  for (auto i = 0; i < num_layers; i++) {
    TORCH_CHECK(k_caches[i].dim() == 4, "every layer k_cache must be 4d.");
  }
  // check same device and tensor type
  TORCH_CHECK(isMlu(k_caches[0]), "k_caches must on mlu.");
  TORCH_CHECK(k_caches[0].dtype() == torch::kInt8 || k_caches[0].dtype() == torch::kUInt8 ||
                  k_caches[0].dtype() == torch::kInt16 || k_caches[0].dtype() == torch::kInt32 ||
                  k_caches[0].dtype() == torch::kLong || k_caches[0].dtype() == torch::kFloat16 ||
                  k_caches[0].dtype() == torch::kFloat32 || k_caches[0].dtype() == torch::kBFloat16,
              "data type only supports torch::kInt8, torch::kUInt8, torch::kInt16, torch::kInt32, "
              "torch::kLong, torch::kFloat16, torch::kFloat32 and torch::kBFloat16");
  if (!v_caches.empty()) {
    TORCH_CHECK(k_caches.size() == v_caches.size(),
                "k_caches size must equal to "
                "v_caches size if v_caches is not none.");
    TORCH_CHECK(isMlu(v_caches[0]), "v_caches must on mlu.");
    // A single block size (taken from layer 0 of the keys) is passed to the
    // kernel for every layer, so every key AND value layer has to match it.
    // The previous check compared each key layer with v_caches[0] only, which
    // left the value caches of layers 1..N-1 unvalidated.
    const auto block_numel = k_caches[0][0].numel();
    for (auto i = 0; i < num_layers; i++) {
      TORCH_CHECK(k_caches[i].dtype() == v_caches[i].dtype(),
                  "the data type of k_caches and v_caches are not the same.");
      TORCH_CHECK(k_caches[i].dim() == v_caches[i].dim(),
                  "every layer k_cache dim must equal to v_cache dim.");
      // check shape
      TORCH_CHECK(k_caches[i][0].numel() == block_numel && v_caches[i][0].numel() == block_numel,
                  "the block_size of k_caches and v_caches are not the same.");
    }
  }

  const torch_mlu::mlu::MLUGuard device_guard(k_caches[0].device());
  auto queue = torch_mlu::getCurMLUStream();

  // Bytes in one block: elements of k_caches[0][0] times the element size.
  size_t block_size_bytes = k_caches[0][0].numel() * k_caches[0].element_size();

  // Raw per-layer base pointers; the value list stays empty when no value
  // caches were given.
  std::vector<void *> new_key_caches, new_value_caches;
  for (auto i = 0; i < num_layers; ++i) {
    new_key_caches.push_back(k_caches[i].data_ptr());
    if (!v_caches.empty()) {
      new_value_caches.push_back(v_caches[i].data_ptr());
    }
  }
  TMO_KERNEL_CHECK_FATAL(invokeCopyBlocksKernel(queue, new_key_caches, new_value_caches,
                                                block_mapping_vec, block_size_bytes));
}
|
||||
|
||||
// Variant of copy_blocks that also returns the cache lists.
// copy_blocks is applied to the given tensors first; the returned tuple holds
// copies of the two input vectors (k_caches, v_caches).
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> copy_blocks_out_of_place(
    const std::vector<at::Tensor> &k_caches,
    const std::vector<at::Tensor> &v_caches,
    const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping) {
  copy_blocks(k_caches, v_caches, block_mapping);

  auto caches = std::make_tuple(k_caches, v_caches);
  return caches;
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,176 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/dequant_from_linear_cache.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Fails (TORCH_CHECK) unless tensor `x` has scalar type `expected_type`; the
// message names the tensor expression and the expected dtype.
// Both arguments are parenthesized so that compound expressions expand safely.
// The trailing semicolon is kept for compatibility with existing call sites.
#define CHECK_TENSOR_DTYPE(x, expected_type)                                          \
  TORCH_CHECK((x).scalar_type() == (expected_type), "Tensor " #x " type should be ", \
              torchDtype2Str(expected_type), ".");
|
||||
|
||||
void dequant_from_linear_cache(
|
||||
at::Tensor &key, // [total_seqlen, head_num, head_size]
|
||||
const c10::optional<at::Tensor> &value, // same as above
|
||||
const at::Tensor &key_cache, // [max_batch_size, head_num, cache_mem_len, head_size]
|
||||
const c10::optional<at::Tensor> &value_cache, // same as above
|
||||
const at::Tensor &key_quant_scale, // quant_mode is 0: [head_num, head_size]
|
||||
// quant_mode is 1:
|
||||
// [max_batch_size, head_num, cache_mem_len]
|
||||
const c10::optional<at::Tensor> &value_quant_scale, // same as above
|
||||
const at::Tensor &context_lengths,
|
||||
const int64_t max_context_len,
|
||||
const c10::optional<at::Tensor> &context_seq_offset,
|
||||
const c10::optional<at::Tensor> &cache_bs_id,
|
||||
const c10::optional<at::Tensor> &cache_seq_offset,
|
||||
const int64_t quant_mode, // 0:per_channel, 1:per_head
|
||||
const int64_t quant_bit) { // 4 or 8
|
||||
// check same attr for tensors
|
||||
checkTensorSameAttr<TensorAttr::DEVICE>(key, key_cache, key_quant_scale, value, value_cache,
|
||||
value_quant_scale, context_lengths, context_seq_offset,
|
||||
cache_bs_id, cache_seq_offset);
|
||||
checkTensorSameAttr<TensorAttr::DTYPE>(context_lengths, context_seq_offset, cache_bs_id,
|
||||
cache_seq_offset);
|
||||
|
||||
// check quant parameters first
|
||||
TORCH_CHECK(quant_mode >= 0 && quant_mode <= 1, "quantization mode support 0 and 1.");
|
||||
TORCH_CHECK(quant_bit == 4 || quant_bit == 8, "quantization bit width support 4 and 8.");
|
||||
|
||||
/****************************************check key***************************************/
|
||||
// check dtype
|
||||
TORCH_CHECK(key.scalar_type() == torch::kFloat16 || key.scalar_type() == torch::kBFloat16,
|
||||
"Tensor key type should be half or bfloat16.");
|
||||
CHECK_TENSOR_DTYPE(key_cache, torch::kInt8);
|
||||
CHECK_TENSOR_DTYPE(key_quant_scale, torch::kFloat32);
|
||||
|
||||
// check_contiguous
|
||||
TORCH_CHECK(key.stride(-1) == 1, "Tensor key last dim must be contiguous.");
|
||||
CHECK_TENSOR_CONTIGUOUS(key_cache);
|
||||
CHECK_TENSOR_CONTIGUOUS(key_quant_scale);
|
||||
|
||||
// check shape
|
||||
TORCH_CHECK(key.dim() == 3, "The dimensions of tensor key only support 3.");
|
||||
TORCH_CHECK(key_cache.dim() == 4, "The dimensions of tensor key_cache only supports 4.");
|
||||
TORCH_CHECK(context_lengths.dim() == 1,
|
||||
"The dimensions of tensor context_lengths only supports 1.");
|
||||
const int32_t total_seqlen = key.size(0);
|
||||
const int32_t head_num = key.size(1);
|
||||
const int32_t head_size = key.size(2);
|
||||
const int32_t max_batch_size = key_cache.size(0);
|
||||
const int32_t cache_mem_len = key_cache.size(2);
|
||||
const int32_t batch_size = context_lengths.size(0);
|
||||
CHECK_SHAPE(context_lengths, batch_size);
|
||||
TORCH_CHECK(max_batch_size >= batch_size,
|
||||
"max_batch_size should be greater than or equal to batch_size.");
|
||||
TORCH_CHECK(cache_mem_len % 2 == 0, "cache_mem_len should be a multiply of 2.");
|
||||
if (quant_mode == 0) {
|
||||
CHECK_SHAPE(key_quant_scale, head_num, head_size);
|
||||
} else if (quant_mode == 1) {
|
||||
CHECK_SHAPE(key_quant_scale, max_batch_size, head_num, cache_mem_len);
|
||||
}
|
||||
|
||||
if (quant_bit == 4) {
|
||||
TORCH_CHECK(head_size % 2 == 0, "head_size should be a multiply of 2 if quant_bit is 4.");
|
||||
CHECK_SHAPE(key_cache, max_batch_size, head_num, cache_mem_len, head_size >> 1);
|
||||
} else {
|
||||
CHECK_SHAPE(key_cache, max_batch_size, head_num, cache_mem_len, head_size);
|
||||
}
|
||||
|
||||
/***************************************check value***************************************/
|
||||
if (value.has_value() || value_cache.has_value() || value_quant_scale.has_value()) {
|
||||
TORCH_CHECK(value.has_value() && value_cache.has_value() && value_quant_scale.has_value(),
|
||||
"value, value_cache, and value_quant_scale must all exists.")
|
||||
}
|
||||
|
||||
if (value_cache.has_value()) {
|
||||
checkTensorSameAttr<TensorAttr::DTYPE>(key, value);
|
||||
checkTensorSameAttr<TensorAttr::DTYPE>(key_cache, value_cache);
|
||||
checkTensorSameAttr<TensorAttr::DTYPE>(key_quant_scale, value_quant_scale);
|
||||
CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache);
|
||||
CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_quant_scale);
|
||||
TORCH_CHECK(key_quant_scale.dim() == value_quant_scale.value().dim(),
|
||||
"value_cache_quant_scale dim should keep same with key_cache_quant_scale.");
|
||||
CHECK_SHAPE(value.value(), total_seqlen, head_num, head_size);
|
||||
if (quant_mode == 0) {
|
||||
CHECK_SHAPE(value_quant_scale.value(), head_num, head_size);
|
||||
} else {
|
||||
CHECK_SHAPE(value_quant_scale.value(), max_batch_size, head_num, cache_mem_len);
|
||||
}
|
||||
|
||||
if (quant_bit == 4) {
|
||||
CHECK_SHAPE(value_cache.value(), max_batch_size, head_num, cache_mem_len >> 1, head_size);
|
||||
} else {
|
||||
CHECK_SHAPE(value_cache.value(), max_batch_size, head_num, cache_mem_len, head_size);
|
||||
}
|
||||
|
||||
for (int i = 0; i < key.dim(); i++) {
|
||||
TORCH_CHECK(value.value().stride(i) == key.stride(i),
|
||||
"key and value must have same stride along axi ", i, ".");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32,
|
||||
"context_lengths type need be torch::kInt32.");
|
||||
/*********************************check optional tensor***********************************/
|
||||
const void *context_seq_offset_ptr = nullptr;
|
||||
if (context_seq_offset.has_value()) {
|
||||
CHECK_SHAPE(context_seq_offset.value(), batch_size);
|
||||
context_seq_offset_ptr = context_seq_offset.value().data_ptr();
|
||||
} else {
|
||||
TORCH_CHECK(batch_size <= 1024,
|
||||
"batch_size greater than 1024 not support when context_seq_offset is None.");
|
||||
}
|
||||
|
||||
const void *cache_bs_id_ptr = nullptr;
|
||||
if (cache_bs_id.has_value()) {
|
||||
CHECK_SHAPE(cache_bs_id.value(), batch_size);
|
||||
cache_bs_id_ptr = cache_bs_id.value().data_ptr();
|
||||
}
|
||||
|
||||
const void *cache_seq_offset_ptr = nullptr;
|
||||
if (cache_seq_offset.has_value()) {
|
||||
CHECK_SHAPE(cache_seq_offset.value(), batch_size);
|
||||
cache_seq_offset_ptr = cache_seq_offset.value().data_ptr();
|
||||
}
|
||||
|
||||
// propare parameters before calling invokeDequantFromLinearCache
|
||||
const int32_t key_group_num = 1;
|
||||
const int32_t value_group_num = 1;
|
||||
const size_t context_head_stride = key.stride(1);
|
||||
const size_t context_seq_stride = key.stride(0);
|
||||
const size_t cache_bs_stride = key_cache.stride(0);
|
||||
const size_t cache_head_stride = key_cache.stride(1);
|
||||
const size_t key_cache_seq_stride = key_cache.stride(2);
|
||||
const size_t value_cache_seq_stride = value_cache.has_value() ? value_cache.value().stride(2) : 0;
|
||||
const size_t cache_scale_bs_stride = key_quant_scale.dim() == 3 ? key_quant_scale.stride(0) : 0;
|
||||
const size_t cache_scale_head_stride =
|
||||
key_quant_scale.dim() == 3 ? key_quant_scale.stride(1) : key_quant_scale.stride(0);
|
||||
|
||||
const torch_mlu::mlu::MLUGuard device_guard(key.device());
|
||||
auto data_dtype = getCnnlDataType(key.scalar_type());
|
||||
auto queue = torch_mlu::getCurMLUStream();
|
||||
|
||||
// run forward
|
||||
TMO_KERNEL_CHECK_FATAL(invokeDequantFromLinearCache(
|
||||
queue, getAtTensorPtr(key), getAtTensorPtr(value), getAtTensorPtr(key_cache),
|
||||
getAtTensorPtr(value_cache), getAtTensorPtr(key_quant_scale),
|
||||
getAtTensorPtr(value_quant_scale), getAtTensorPtr(context_lengths), context_seq_offset_ptr,
|
||||
cache_bs_id_ptr, cache_seq_offset_ptr, (int)max_context_len, batch_size, head_num,
|
||||
key_group_num, value_group_num, cache_mem_len, head_size, quant_mode, quant_bit,
|
||||
context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride,
|
||||
key_cache_seq_stride, value_cache_seq_stride, cache_scale_bs_stride, cache_scale_head_stride,
|
||||
data_dtype));
|
||||
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
162
torch_mlu_ops-v1.3.2/csrc/torch_api/dequant_from_paged_cache.cpp
Normal file
162
torch_mlu_ops-v1.3.2/csrc/torch_api/dequant_from_paged_cache.cpp
Normal file
@@ -0,0 +1,162 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/dequant_from_paged_cache.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Raises (through TORCH_CHECK) unless tensor `x` has scalar type `expected_type`.
// The expected type is rendered into the message with torchDtype2Str.
// The macro deliberately does NOT end in a semicolon: call sites supply their own, so the
// expansion stays a single statement.
#define CHECK_TENSOR_DTYPE(x, expected_type)                                           \
  TORCH_CHECK((x).scalar_type() == (expected_type), "Tensor " #x " type should be ", \
              torchDtype2Str(expected_type), ".")
|
||||
|
||||
// Validates the arguments of a paged-cache dequantization and launches
// invokeDequantFromPagedCache on the current MLU stream.
//
// Layouts enforced by the checks in the body:
//   key / value              : [token_num, head_num, head_size], half or bfloat16,
//                              last dimension contiguous, identical strides for key and value
//   key_cache / value_cache  : [block_num, head_num, block_size, head_size], int8, contiguous
//   *_cache_quant_scale      : float32, contiguous;
//                              [head_num, head_size] when quant_mode == 0,
//                              [block_num, head_num, block_size] when quant_mode == 1
//   context_lengths          : [batch_size]
//   context_seq_offset       : [batch_size], optional
//   block_tables             : [batch_size, max_block_num], int32, contiguous
//
// value, value_cache and value_cache_quant_scale are optional but must be supplied together.
// Every violated precondition is reported through TORCH_CHECK / CHECK_SHAPE.
void dequant_from_paged_cache(at::Tensor &key,  // [token_num, head_num, head_size]
                              const c10::optional<at::Tensor> &value,  // same as key
                              // [block_num, head_num, block_size, head_size]
                              const at::Tensor &key_cache,
                              const c10::optional<at::Tensor> &value_cache,  // same as key_cache
                              // [block_num, head_num, block_size] for per-token
                              // [head_num, head_size] for per-channel
                              const at::Tensor &key_cache_quant_scale,
                              const c10::optional<at::Tensor> &value_cache_quant_scale,
                              const at::Tensor &context_lengths,  // [batch_size]
                              int64_t max_context_len,
                              const c10::optional<at::Tensor> &context_seq_offset,  // [batch_size]
                              const at::Tensor &block_tables,  // [batch_size, max_block_num]
                              int64_t quant_mode,  // 0 is per-channel, 1 is per-token
                              int64_t quant_bit) {  // only 8-bit quantization is supported
  // check same attr for tensors
  checkTensorSameAttr<TensorAttr::DEVICE>(key, key_cache, key_cache_quant_scale, value, value_cache,
                                          value_cache_quant_scale, context_lengths,
                                          context_seq_offset, block_tables);
  checkTensorSameAttr<TensorAttr::DTYPE>(context_lengths, context_seq_offset, block_tables);

  // check quant parameters first
  TORCH_CHECK(quant_bit == 8, "quantization bit width only supports 8.");
  TORCH_CHECK(quant_mode >= 0 && quant_mode <= 1, "quantization mode support 0 and 1.");

  /***************************************check key***************************************/
  // check dtype
  TORCH_CHECK(key.scalar_type() == torch::kFloat16 || key.scalar_type() == torch::kBFloat16,
              "Tensor key type should be half or bfloat16.");
  CHECK_TENSOR_DTYPE(key_cache, torch::kInt8);
  CHECK_TENSOR_DTYPE(key_cache_quant_scale, torch::kFloat32);
  CHECK_TENSOR_DTYPE(block_tables, torch::kInt32);

  // check contiguous
  TORCH_CHECK(key.stride(-1) == 1, "Tensor key last dim must be contiguous.");
  CHECK_TENSOR_CONTIGUOUS(key_cache);
  CHECK_TENSOR_CONTIGUOUS(key_cache_quant_scale);
  CHECK_TENSOR_CONTIGUOUS(block_tables);

  // check shape
  TORCH_CHECK(key.dim() == 3, "The dimensions of tensor key only supports 3.");
  TORCH_CHECK(key_cache.dim() == 4, "The dimensions of tensor key_cache only supports 4.");
  TORCH_CHECK(context_lengths.dim() == 1,
              "The dimensions of tensor context_lengths only supports 1.");
  TORCH_CHECK(block_tables.dim() == 2, "The dimensions of tensor block_tables only supports 2.");
  const int32_t token_num = key.size(0);
  const int32_t head_num = key.size(1);
  const int32_t head_size = key.size(2);
  const int32_t block_num = key_cache.size(0);
  const int32_t block_size = key_cache.size(2);
  const int32_t batch_size = context_lengths.size(0);
  const int32_t max_block_num = block_tables.size(1);
  CHECK_SHAPE(key_cache, block_num, head_num, block_size, head_size);
  const int64_t kv_cache_range = static_cast<int64_t>(block_num) * head_num * block_size *
                                 head_size * key_cache.element_size();
  // kernel use uint32_t to calculate offsets
  TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of key_cache cannot exceed 4GB.");
  if (quant_mode == 0) {
    CHECK_SHAPE(key_cache_quant_scale, head_num, head_size);
  } else if (quant_mode == 1) {
    CHECK_SHAPE(key_cache_quant_scale, block_num, head_num, block_size);
  }

  /***************************************check value***************************************/
  // The three value-side tensors are all-or-nothing.
  if (value.has_value() || value_cache.has_value() || value_cache_quant_scale.has_value()) {
    TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_quant_scale.has_value(),
                "value, value_cache, and value_cache_quant_scale must all exists.");
  }

  if (value_cache.has_value()) {
    checkTensorSameAttr<TensorAttr::DTYPE>(key, value);
    checkTensorSameAttr<TensorAttr::DTYPE>(key_cache, value_cache);
    checkTensorSameAttr<TensorAttr::DTYPE>(key_cache_quant_scale, value_cache_quant_scale);
    TORCH_CHECK(value.value().stride(-1) == 1, "Tensor value last dim must be contiguous.");
    CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache);
    CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache_quant_scale);
    CHECK_SHAPE(value.value(), token_num, head_num, head_size);
    TORCH_CHECK(value_cache_quant_scale.value().dim() == key_cache_quant_scale.dim(),
                "value_cache_quant_scale dim should keep same with key_cache_quant_scale.");
    if (quant_mode == 0) {
      CHECK_SHAPE(value_cache_quant_scale.value(), head_num, head_size);
    } else {
      CHECK_SHAPE(value_cache_quant_scale.value(), block_num, head_num, block_size);
    }

    CHECK_SHAPE(value_cache.value(), block_num, head_num, block_size, head_size);

    // Only the key strides are forwarded to the kernel, so value must use the same ones.
    for (int i = 0; i < key.dim(); i++) {
      TORCH_CHECK(value.value().stride(i) == key.stride(i),
                  "key and value must have same stride along axi ", i, ".");
    }
  }

  /*********************************check index tensor***********************************/
  CHECK_SHAPE(context_lengths, batch_size);
  CHECK_SHAPE(block_tables, batch_size, max_block_num);
  // Widen before multiplying: block_size * max_block_num evaluated in int32 could overflow.
  const int64_t max_cache_len = static_cast<int64_t>(block_size) * max_block_num;
  TORCH_CHECK(max_context_len <= max_cache_len,
              "max_context_len should smaller than or equal to block_size * max_block_num.");
  // The kernel receives max_context_len as an int; refuse values the cast would truncate.
  TORCH_CHECK(max_context_len <= INT32_MAX,
              "max_context_len must fit in a 32-bit signed integer.");

  /*********************************check optional tensor***********************************/
  const void *context_seq_offset_ptr = nullptr;
  if (context_seq_offset.has_value()) {
    CHECK_SHAPE(context_seq_offset.value(), batch_size);
    context_seq_offset_ptr = context_seq_offset.value().data_ptr();
  } else {
    TORCH_CHECK(batch_size <= 1024,
                "batch_size greater than 1024 not support when context_seq_offset is None.");
  }

  // prepare parameters before calling invokeDequantFromPagedCache
  const int32_t key_group_num = 1;
  const int32_t value_group_num = 1;
  const size_t context_head_stride = key.stride(1);
  const size_t context_seq_stride = key.stride(0);
  const size_t cache_bs_stride = key_cache.stride(0);
  const size_t cache_head_stride = key_cache.stride(1);
  const size_t key_cache_seq_stride = key_cache.stride(2);
  const size_t value_cache_seq_stride = value_cache.has_value() ? value_cache.value().stride(2) : 0;
  // A 3-D scale is the per-token layout; the 2-D per-channel scale has no block stride.
  const bool scale_is_3d = key_cache_quant_scale.dim() == 3;
  const size_t cache_scale_bs_stride = scale_is_3d ? key_cache_quant_scale.stride(0) : 0;
  const size_t cache_scale_head_stride =
      scale_is_3d ? key_cache_quant_scale.stride(1) : key_cache_quant_scale.stride(0);

  const torch_mlu::mlu::MLUGuard device_guard(key.device());
  auto queue = torch_mlu::getCurMLUStream();
  auto data_dtype = getCnnlDataType(key.scalar_type());

  // run forward
  TMO_KERNEL_CHECK_FATAL(invokeDequantFromPagedCache(
      queue, getAtTensorPtr(key), getAtTensorPtr(value), getAtTensorPtr(key_cache),
      getAtTensorPtr(value_cache), getAtTensorPtr(key_cache_quant_scale),
      getAtTensorPtr(value_cache_quant_scale), getAtTensorPtr(context_lengths),
      context_seq_offset_ptr, getAtTensorPtr(block_tables), static_cast<int>(max_context_len),
      max_block_num, batch_size, head_num, key_group_num, value_group_num, block_size, head_size,
      quant_mode, quant_bit, context_head_stride, context_seq_stride, cache_bs_stride,
      cache_head_stride, key_cache_seq_stride, value_cache_seq_stride, cache_scale_bs_stride,
      cache_scale_head_stride, data_dtype));
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
156
torch_mlu_ops-v1.3.2/csrc/torch_api/ffn.cpp
Normal file
156
torch_mlu_ops-v1.3.2/csrc/torch_api/ffn.cpp
Normal file
@@ -0,0 +1,156 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "common/utils.h"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
/*
|
||||
The torch_ffn_api API function, expressed as a pytorch op, is:
|
||||
class FeedForward(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, inner_size: int, act_mode: str,
|
||||
bias = False, gated = False):
|
||||
super(FeedForward, self).__init__()
|
||||
self.up_linear = torch.nn.Linear(hidden_size, inner_size, bias)
|
||||
self.gated = gated
|
||||
if self.gated:
|
||||
self.gated_linear = torch.nn.Linear(hidden_size, inner_size, bias)
|
||||
self.down_linear = torch.nn.Linear(inner_size, hidden_size, bias)
|
||||
self.act = act_mode_dict[act_mode]
|
||||
|
||||
def forward(self, x):
|
||||
act_out = self.act(self.up_linear(x).float()).to(x.dtype)
|
||||
return self.down_linear(act_out * self.gated_linear(x)) \
|
||||
if self.gated else self.down_linear(act_out)
|
||||
|
||||
Demo of pytorch python module in TGI:
|
||||
class LlamaBtMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.act = config.hidden_act
|
||||
self.gate_proj_weight = weights.get_multi_weights_col(
|
||||
f"{prefix}.gate_proj", quantize=config.quantize, dim=dim
|
||||
)
|
||||
self.gate_proj_bias = None
|
||||
self.up_proj_weight = weights.get_multi_weights_col(
|
||||
f"{prefix}.up_proj", quantize=config.quantize, dim=dim
|
||||
)
|
||||
self.up_proj_bias = None
|
||||
self.down_proj_weight = weights.get_multi_weights_col(
|
||||
f"{prefix}.down_proj", quantize=config.quantize, dim=dim
|
||||
)
|
||||
self.down_proj_bias = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return tmo.ffn(hidden_states, self.up_proj_weight, self.up_proj_bias,
|
||||
self.down_proj_weight, self.down_proj_bias, self.gate_proj_weight,
|
||||
self.gate_proj_bias, self.act)
|
||||
*/
|
||||
|
||||
// act_mode currently supports only silu, gelu and relu.
|
||||
// std::string pytorch func
|
||||
// silu --> nn.SiLU
|
||||
// gelu --> nn.functional.gelu
|
||||
// relu --> nn.ReLU
|
||||
// Maybe aligned with transformer act mode later.
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L200
|
||||
|
||||
// Dimensions of ffn.
|
||||
// Dimension of input: [batch_num, seq_len, hidden_size]
|
||||
// Dimension of up_fc_filters: [filter_size, hidden_size]
|
||||
// Dimension of up_fc_bias: [filter_size]
|
||||
// Dimension of down_fc_filters: [hidden_size, filter_size]
|
||||
// Dimension of down_fc_bias: [hidden_size]
|
||||
// Dimension of gated_fc_filters: [filter_size, hidden_size] only for gated ffn
|
||||
// Dimension of gated_fc_bias: [filter_size] only for gated ffn
|
||||
// Dimension of layer norm weight: [seq_len, hidden_size] only for layer norm fused
|
||||
// Dimension of layer norm bias: [seq_len, hidden_size] only for layer norm fused
|
||||
// Dimension of output is same as input.
|
||||
// Dimension of \b output: [batch_num, seq_len, hidden_size]
|
||||
|
||||
// Fused layer norm is not supported now. Only fused
|
||||
// per layer norm will be supported later.
|
||||
|
||||
// Runs the CNNL transformer feed-forward op (cnnlTransformerFeedForward) on `input` and returns
// a freshly allocated tensor shaped like `input`.
//
// - A 2-D `input` is temporarily given a leading dimension of size 1 and the result is squeezed
//   back before returning.
// - The gate weight/bias are attached to the op descriptor only when gate_up_proj_weight is
//   present and defined.
// - `residual_is` must be "input", "normed_input" or "none".
// - act_mode, eps, alpha and beta are forwarded unchanged to tmo::op_desc::FeedForwardDesc.
at::Tensor ffn(const at::Tensor &input,
               const at::Tensor &up_fc_weight,
               const c10::optional<at::Tensor> &up_fc_bias,
               const at::Tensor &down_proj_weight,
               const c10::optional<at::Tensor> &down_proj_bias,
               const c10::optional<at::Tensor> &gate_up_proj_weight,
               const c10::optional<at::Tensor> &gate_up_proj_bias,
               const c10::optional<at::Tensor> &layernorm_weight,
               const c10::optional<at::Tensor> &layernorm_bias,
               const std::string &act_mode,
               const std::string &residual_is,
               double eps,
               double alpha,
               double beta) {
  // Positions of the tensors inside the descriptor list built below.
  enum DescIndex {
    kInput = 0,
    kUpWeight,
    kUpBias,
    kDownWeight,
    kDownBias,
    kGateWeight,
    kGateBias,
    kLnWeight,
    kLnBias,
    kOutput
  };

  // Every tensor, optional ones included, must agree on the attributes covered by
  // TensorAttr::ALL.
  checkTensorSameAttr<TensorAttr::ALL>(input, up_fc_weight, up_fc_bias, down_proj_weight,
                                       down_proj_bias, gate_up_proj_weight, gate_up_proj_bias,
                                       layernorm_weight, layernorm_bias);

  // Promote a 2-D input to 3-D for the op; the extra dimension is removed on return.
  const bool input_is_2d = input.dim() == 2;
  at::Tensor input_view = input_is_2d ? input.unsqueeze(0) : input;

  CHECK_TENSOR_CONTIGUOUS(input_view);

  const torch_mlu::mlu::MLUGuard device_guard(input_view.device());
  at::Tensor output = at::empty_like(input_view);

  // Absent optional tensors are represented by an undefined at::Tensor.
  const at::Tensor undefined_tensor;
  auto descs = createTensorDescs({input_view,
                                  up_fc_weight,
                                  up_fc_bias.value_or(undefined_tensor),
                                  down_proj_weight,
                                  down_proj_bias.value_or(undefined_tensor),
                                  gate_up_proj_weight.value_or(undefined_tensor),
                                  gate_up_proj_bias.value_or(undefined_tensor),
                                  layernorm_weight.value_or(undefined_tensor),
                                  layernorm_bias.value_or(undefined_tensor),
                                  output});

  // Derive the layernorm/residual mode from the layernorm weight and `residual_is`.
  const bool with_layernorm = layernorm_weight.has_value();
  TORCH_CHECK(residual_is == "input" || residual_is == "normed_input" || residual_is == "none",
              "residual_is must be 'input' or 'normed_input' or 'none'.");
  const bool with_residual = residual_is != "none";
  const bool residual_from_input = residual_is == "input";
  auto ln_res_mode = tmo::lnres::makeLnresEnum(with_layernorm, with_residual, residual_from_input);
  auto compute_type = getCnnlDataType(input_view.scalar_type());

  tmo::op_desc::FeedForwardDesc ffn_desc(ln_res_mode, act_mode, compute_type, eps, alpha, beta);
  const bool is_gated = gate_up_proj_weight.has_value() && gate_up_proj_weight->defined();
  if (is_gated) {
    CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptorGateFiltersBias(
        ffn_desc, descs[kGateWeight].get(), getAtTensorPtr(gate_up_proj_weight),
        descs[kGateBias].get(), getAtTensorPtr(gate_up_proj_bias)));
  }

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();

  // Ask CNNL how much scratch memory is needed and allocate it as a byte tensor.
  size_t workspace_size = 0;
  CNNL_CHECK_FATAL(cnnlGetTransformerFeedForwardWorkspaceSize_v2(
      handle, ffn_desc, descs[kInput].get(), descs[kUpWeight].get(), nullptr, ffn_desc,
      &workspace_size));
  auto workspace =
      at::empty({static_cast<int64_t>(workspace_size)}, input.options().dtype(at::kByte));

  // forward
  CNNL_CHECK_FATAL(cnnlTransformerFeedForward(
      handle, ffn_desc, ffn_desc, nullptr, descs[kInput].get(), getAtTensorPtr(input_view),
      descs[kUpWeight].get(), getAtTensorPtr(up_fc_weight), descs[kUpBias].get(),
      getAtTensorPtr(up_fc_bias), descs[kDownWeight].get(), getAtTensorPtr(down_proj_weight),
      descs[kDownBias].get(), getAtTensorPtr(down_proj_bias), descs[kLnWeight].get(),
      getAtTensorPtr(layernorm_weight), descs[kLnBias].get(), getAtTensorPtr(layernorm_bias),
      getAtTensorPtr(workspace), workspace_size, descs[kOutput].get(), getAtTensorPtr(output)));

  if (input_is_2d) {
    output.squeeze_(0);
  }
  return output;
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
117
torch_mlu_ops-v1.3.2/csrc/torch_api/flash_attention.cpp
Normal file
117
torch_mlu_ops-v1.3.2/csrc/torch_api/flash_attention.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "torch_ops_api.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Validates its arguments and runs the CNNL scaled-dot-product attention op
// (cnnlScaledDotProductAttn_v3) for q/k/v. `output` and, if given, `output_lse` are handed to
// the op as its result buffers; the function itself returns nothing.
//
// - q must be 3-D (packed, requires cu_seq_lens_q) or 4-D.
// - With block_tables, k and v must be 4-D caches of shape
//   [num_blocks, head_num, block_size, head_size]; cu_seq_lens_kv is required when
//   block_tables has more than one block per sequence.
// - Without block_tables, k and v must be 3-D or 4-D; cu_seq_lens_kv is required when k is 3-D.
// - compute_dtype must be "float", "half" or "bfloat16".
// - k_cache_quant_scale / v_cache_quant_scale are reserved and must not be provided.
void flash_attention(const at::Tensor &q,
                     const at::Tensor &k,
                     const at::Tensor &v,
                     const at::Tensor &output,
                     const c10::optional<at::Tensor> &output_lse,
                     const c10::optional<at::Tensor> &cu_seq_lens_q,
                     const c10::optional<at::Tensor> &cu_seq_lens_kv,
                     const c10::optional<at::Tensor> &alibi_slope,
                     const c10::optional<at::Tensor> &attn_bias,
                     const c10::optional<at::Tensor> &k_cache_quant_scale,
                     const c10::optional<at::Tensor> &v_cache_quant_scale,
                     const c10::optional<at::Tensor> &block_tables,
                     const int64_t max_seq_len_q,
                     const int64_t max_seq_len_kv,
                     const double softmax_scale,
                     const bool is_causal,
                     const int64_t window_size_left,
                     const int64_t window_size_right,
                     const std::string &compute_dtype,
                     bool return_lse) {
  TORCH_CHECK(compute_dtype == "float" || compute_dtype == "half" || compute_dtype == "bfloat16",
              "compute_dtype must be 'float', 'half' or 'bfloat16'.");
  TORCH_CHECK(k_cache_quant_scale.has_value() == false, "k_cache_scale for reserve.");
  TORCH_CHECK(v_cache_quant_scale.has_value() == false, "v_cache_scale for reserve.");
  bool has_block_table = block_tables.has_value();
  bool is_pack = cu_seq_lens_q.has_value();
  // For packed q the batch size comes from the cumulative-length tensor (one extra entry).
  int64_t batch = is_pack ? cu_seq_lens_q.value().size(0) - 1 : q.size(0);
  int qk_head_size = q.size(-1);
  int v_head_size = v.size(-1);

  // Check tensor type and tensor device.
  checkTensorSameAttr<TensorAttr::ALL>(q, k, v, output);
  // 3d for packed
  TORCH_CHECK(q.dim() == 3 || q.dim() == 4, "query must be 3d or 4d.");
  if (has_block_table) {
    TORCH_CHECK(block_tables.value().dim() == 2, "block_tables must be 2d.");
    TORCH_CHECK(k.dim() == 4, "with block table, key_cache must be 4d.");
    TORCH_CHECK(v.dim() == 4, "with block table, value_cache must be 4d.");
    int max_num_blocks_per_seq = block_tables.value().size(1);
    int num_blocks = k.size(0);
    int block_size = k.size(2);
    int k_head_num = k.size(1);
    CHECK_SHAPE(k, num_blocks, k_head_num, block_size, qk_head_size);
    CHECK_SHAPE(v, num_blocks, k_head_num, block_size, v_head_size);
    if (max_num_blocks_per_seq > 1) {  // paged
      // Fail with a clear message instead of a bad-optional-access from .value().
      TORCH_CHECK(cu_seq_lens_kv.has_value(),
                  "cu_seq_lens_kv must be provided when block_tables has more than one block "
                  "per sequence.");
      CHECK_SHAPE(cu_seq_lens_kv.value(), batch + 1);
    }
  } else {
    // 3d for packed
    TORCH_CHECK(k.dim() == 3 || k.dim() == 4, "key_cache must be 3d or 4d.");
    TORCH_CHECK(v.dim() == 3 || v.dim() == 4, "value_cache must be 3d or 4d.");
    if (k.dim() == 3) {  // packed_kv
      // Fail with a clear message instead of a bad-optional-access from .value().
      TORCH_CHECK(cu_seq_lens_kv.has_value(),
                  "cu_seq_lens_kv must be provided when key is 3d (packed).");
      CHECK_SHAPE(cu_seq_lens_kv.value(), batch + 1);
    }
  }

  // Convert torch tensor to tensor descs. Index layout:
  //   0:q 1:k 2:v 3:cu_seq_lens_q 4:cu_seq_lens_kv 5:alibi_slope 6:attn_bias
  //   7:block_tables 8:output 9:output_lse
  // Absent optional tensors are represented by an undefined at::Tensor.
  auto descs = createTensorDescs(
      {q, k, v, cu_seq_lens_q.value_or(at::Tensor()), cu_seq_lens_kv.value_or(at::Tensor()),
       alibi_slope.value_or(at::Tensor()), attn_bias.value_or(at::Tensor()),
       block_tables.value_or(at::Tensor()), output, output_lse.value_or(at::Tensor())});

  // Get current handle.
  const torch_mlu::mlu::MLUGuard device_guard(q.device());
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();

  // compute_dtype was validated above, so anything that is not "float"/"half" is "bfloat16".
  cnnlDataType_t cnnl_compute_dtype = compute_dtype == "float"  ? CNNL_DTYPE_FLOAT
                                      : compute_dtype == "half" ? CNNL_DTYPE_HALF
                                                                : CNNL_DTYPE_BFLOAT16;

  // Get workspace size and malloc workspace.
  // Note the argument order: attn_bias (descs[6]) is passed before alibi_slope (descs[5]).
  size_t workspace_size = 0;
  CNNL_CHECK_FATAL(cnnlGetScaledDotProductAttnWorkspaceSize_v2(
      handle, nullptr /*op_desc*/, nullptr /*quant_desc*/, descs[0].get(), descs[1].get(),
      descs[2].get(), descs[3].get(), descs[4].get(), descs[6].get(), descs[5].get(),
      descs[7].get(), max_seq_len_q, max_seq_len_kv, is_causal, window_size_left, window_size_right,
      return_lse, CNNL_ACTIVATION_FAST, cnnl_compute_dtype, &workspace_size));
  auto workspace = at::empty({static_cast<int64_t>(workspace_size)}, q.options().dtype(at::kByte));

  // Problem sizes handed to FlashAttnTheory, which is pushed as a cnpx range around the op call.
  int64_t total_q = is_pack ? q.size(0) : q.size(0) * q.size(1);
  int64_t total_k =
      has_block_table ? k.size(0) * k.size(2) : (k.dim() == 3 ? k.size(0) : k.size(0) * k.size(1));
  int64_t head_q = q.size(-2);
  int64_t head_k = has_block_table ? k.size(1) : k.size(-2);
  cnnlDataType_t data_dtype = getCnnlDataType(q.scalar_type());
  FlashAttnTheory obj(batch, total_q, total_k, head_q, head_k, qk_head_size, v_head_size, is_causal,
                      data_dtype);
  cnpxPush(obj);
  // call cnnl extra op.
  CNNL_CHECK_FATAL(cnnlScaledDotProductAttn_v3(
      handle, nullptr, nullptr, descs[0].get(), getAtTensorPtr(q), descs[1].get(),
      getAtTensorPtr(k), descs[2].get(), getAtTensorPtr(v), descs[3].get(),
      getAtTensorPtr(cu_seq_lens_q), descs[4].get(), getAtTensorPtr(cu_seq_lens_kv), nullptr,
      nullptr, descs[6].get(), getAtTensorPtr(attn_bias), descs[5].get(),
      getAtTensorPtr(alibi_slope), descs[7].get(), getAtTensorPtr(block_tables), max_seq_len_q,
      max_seq_len_kv, is_causal, window_size_left, window_size_right, softmax_scale,
      CNNL_ACTIVATION_FAST, cnnl_compute_dtype, getAtTensorPtr(workspace), workspace_size,
      return_lse, descs[9].get(), getAtTensorPtr(output_lse), descs[8].get(),
      getAtTensorPtr(output)));
  cnpxPop();
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
134
torch_mlu_ops-v1.3.2/csrc/torch_api/fuse_norm.cpp
Normal file
134
torch_mlu_ops-v1.3.2/csrc/torch_api/fuse_norm.cpp
Normal file
@@ -0,0 +1,134 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
void fused_layernorm(const at::Tensor &input,
|
||||
const at::Tensor &output,
|
||||
const c10::optional<at::Tensor> &residual,
|
||||
const c10::optional<at::Tensor> &gamma,
|
||||
const c10::optional<at::Tensor> &beta,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const c10::optional<at::Tensor> &quant_scale,
|
||||
const c10::optional<at::Tensor> &residual_out,
|
||||
const c10::optional<at::Tensor> &smooth_quant_scale,
|
||||
const std::string &norm_mode,
|
||||
double eps,
|
||||
bool store_output_before_norm,
|
||||
bool dynamic_quant) {
|
||||
// check device and dtype
|
||||
checkTensorSameAttr<TensorAttr::ALL>(input, residual, gamma, beta, bias);
|
||||
checkTensorSameAttr<TensorAttr::DEVICE>(input, output, residual_out, smooth_quant_scale);
|
||||
cnnlQuantizeScheme_t output_quant_scheme = CNNL_QUANTIZE_NONE;
|
||||
if (dynamic_quant) {
|
||||
TORCH_CHECK(quant_scale.has_value(), "dynamic_quant output, must have quant_scale");
|
||||
output_quant_scheme = CNNL_QUANTIZE_PER_TOKEN;
|
||||
} else if (quant_scale.has_value()) {
|
||||
output_quant_scheme = CNNL_QUANTIZE_PER_CHANNEL;
|
||||
}
|
||||
// check params
|
||||
bool has_residual = residual.has_value();
|
||||
bool quant_out = quant_scale.has_value();
|
||||
int hidden_size = input.size(-1);
|
||||
TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2.");
|
||||
TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous.");
|
||||
TORCH_CHECK(input.sizes() == output.sizes(), "input and output must have the same shape");
|
||||
if (has_residual) {
|
||||
TORCH_CHECK(input.sizes() == residual.value().sizes(),
|
||||
"input and residual must have the same shape");
|
||||
}
|
||||
TORCH_CHECK(norm_mode == "layernorm" || norm_mode == "rmsnorm",
|
||||
"norm_mode must be 'layernorm' or 'rmsnorm'.");
|
||||
cnnlTransformerNormType_t mode =
|
||||
norm_mode == "layernorm" ? CNNL_TRANSFORMER_LAYERNORM : CNNL_TRANSFORMER_RMSNORM;
|
||||
if (norm_mode == "layernorm") {
|
||||
TORCH_CHECK(gamma.has_value() && beta.has_value(), "layernorm mode need gamma and beta.");
|
||||
TORCH_CHECK(gamma.value().sizes() == beta.value().sizes(),
|
||||
"gamma and beta must have the same shape")
|
||||
TORCH_CHECK(gamma.value().dim() == 1 && gamma.value().size(0) == hidden_size,
|
||||
"layernorm mode, gamma and beta size must be hidden_size.");
|
||||
} else {
|
||||
TORCH_CHECK(gamma.has_value(), "rmsnorm mode need gamma.");
|
||||
TORCH_CHECK(gamma.value().dim() == 1 && gamma.value().size(0) == hidden_size,
|
||||
"rmsnorm mode, gamma size must be hidden_size.");
|
||||
}
|
||||
if (quant_out) {
|
||||
TORCH_CHECK(quant_scale.value().dim() == 1 && quant_scale.value().size(0) == hidden_size,
|
||||
"quant_scale shape must be [hidden_size]");
|
||||
}
|
||||
const torch_mlu::mlu::MLUGuard device_guard(input.device());
|
||||
at::Tensor smooth_quant_scale_flat =
|
||||
dynamic_quant ? smooth_quant_scale.value().flatten() : at::Tensor();
|
||||
at::Tensor input_flat;
|
||||
at::Tensor output_flat;
|
||||
at::Tensor residual_flat;
|
||||
at::Tensor residual_out_flat;
|
||||
if (quant_out) {
|
||||
// input must be 2-dim, output dim must be same as input
|
||||
TORCH_CHECK(input.is_contiguous(), "quant_out is not support when input has stride.");
|
||||
input_flat = input.flatten(0, -2);
|
||||
output_flat = output.flatten(0, -2);
|
||||
residual_flat = has_residual ? residual.value().flatten(0, -2) : residual_flat;
|
||||
residual_out_flat =
|
||||
store_output_before_norm ? residual_out.value().flatten(0, -2) : residual_out_flat;
|
||||
} else if (input.dim() > 3) {
|
||||
input_flat = input.flatten(0, -3);
|
||||
output_flat = output.flatten(0, -3);
|
||||
residual_flat = has_residual ? residual.value().flatten(0, -3) : residual_flat;
|
||||
residual_out_flat =
|
||||
store_output_before_norm ? residual_out.value().flatten(0, -3) : residual_out_flat;
|
||||
} else { // 2-dim or 3-dim
|
||||
input_flat = input;
|
||||
output_flat = output;
|
||||
residual_flat = has_residual ? residual.value() : residual_flat;
|
||||
residual_out_flat = store_output_before_norm ? residual_out.value() : residual_out_flat;
|
||||
}
|
||||
TORCH_CHECK(input.data_ptr() == input_flat.data_ptr(), "check the strides of input.");
|
||||
TORCH_CHECK(output_flat.data_ptr() == output.data_ptr(), "check the strides of output.");
|
||||
if (has_residual)
|
||||
TORCH_CHECK(residual.value().data_ptr() == residual_flat.data_ptr(),
|
||||
"check the strides of residual.");
|
||||
if (store_output_before_norm)
|
||||
TORCH_CHECK(residual_out.value().data_ptr() == residual_out_flat.data_ptr(),
|
||||
"check the strides of residual_out.");
|
||||
|
||||
// create tensor desc
|
||||
auto descs = createTensorDescs({input_flat, gamma.value_or(at::Tensor()),
|
||||
beta.value_or(at::Tensor()), bias.value_or(at::Tensor()),
|
||||
residual_flat, quant_scale.value_or(at::Tensor()),
|
||||
residual_out_flat, output_flat, smooth_quant_scale_flat});
|
||||
auto compute_dtype = getCnnlDataType(input_flat.scalar_type());
|
||||
// forward
|
||||
auto handle = torch_mlu::getCurrentHandle();
|
||||
FusedNormTheory obj(input_flat.size(0), input_flat.size(-1), has_residual, bias.has_value(),
|
||||
quant_out, dynamic_quant, store_output_before_norm, compute_dtype, norm_mode);
|
||||
cnpxPush(obj);
|
||||
CNNL_CHECK_FATAL(cnnlFuseNorm_v3(handle, descs[0].get(), getAtTensorPtr(input_flat), // input
|
||||
descs[5].get(), getAtTensorPtr(quant_scale), // input_scale
|
||||
descs[1].get(), getAtTensorPtr(gamma), // norm_scale
|
||||
descs[2].get(), getAtTensorPtr(beta), // norm_bias
|
||||
descs[4].get(), getAtTensorPtr(residual_flat), // residual
|
||||
descs[3].get(), getAtTensorPtr(bias), // bias
|
||||
eps, output_quant_scheme, store_output_before_norm, mode,
|
||||
compute_dtype, nullptr, 0, // set workspace
|
||||
descs[7].get(), getAtTensorPtr(output_flat), // output
|
||||
descs[6].get(),
|
||||
getAtTensorPtr(residual_out_flat), // residual_out
|
||||
descs[8].get(), getAtTensorPtr(smooth_quant_scale_flat)));
|
||||
cnpxPop();
|
||||
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
373
torch_mlu_ops-v1.3.2/csrc/torch_api/fused_moe.cpp
Normal file
373
torch_mlu_ops-v1.3.2/csrc/torch_api/fused_moe.cpp
Normal file
@@ -0,0 +1,373 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include "kernels/moe/moe.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Reference architecture name. fused_moe compares characters [3, 6) of this string ("370")
// with the same positions of the current device name to derive its is_mlu370 flag.
const std::string arch_370 = "MLU370";
|
||||
// Shorthand for the grouped-GEMM descriptor type that fused_moe builds for both GEMM stages.
using GroupGemmDesc = tmo::op_desc::GroupGemmDesc;
|
||||
// Weight quantization mode type nested in GroupGemmDesc; fused_moe uses the values
// noQuant, W8, W4 and W4W8.
using QuantMode = tmo::op_desc::GroupGemmDesc::QuantMode;
|
||||
|
||||
/**
 * @brief Fused mixture-of-experts forward pass.
 *
 * Runs, in order: softmax + top-k routing, index generation, input expansion (optionally fused
 * with per-token smooth quantization), grouped GEMM 1, activation, optional per-token smooth
 * quantization of the activation, grouped GEMM 2 and the weighted combine of expert results.
 * When cncl_comm > 0 the second GEMM and the combine are delegated to
 * group_gemm_combine_result_allreduce().
 *
 * @param hidden_states   2-D or 3-D contiguous input; the last dim is hidden_size.
 * @param gating_output   Router output with the same rank and leading dims as hidden_states; the
 *                        last dim is num_expert.
 * @param w1              First expert weight, 3-D, or 1-D when w1_quant_flag/w2_quant_flag are
 *                        given.
 * @param w2              Second expert weight, 3-D or 4-D, or 1-D together with w1.
 * @param bias1           Must be empty (bias is currently not supported).
 * @param bias2           Must be empty (bias is currently not supported).
 * @param residual        Optional tensor with the same shape as hidden_states, added to the
 *                        result.
 * @param input_smooth    Optional float smooth factor checked against
 *                        (expert_size, hidden_size).
 * @param act_smooth      Optional smooth factor checked against (expert_size, inner_size).
 * @param w1_scale        Optional scale of w1; 2-D (expert_size, gemm1_co) or 3-D (grouped).
 * @param w2_scale        Optional scale of w2; 2-D (expert_size, hidden_size) or 3-D (grouped).
 * @param w1_quant_flag   Optional per-group quantization flags of w1 (selects QuantMode::W4W8).
 * @param w2_quant_flag   Optional per-group quantization flags of w2.
 * @param topk            Number of experts selected per token; must be <= num_expert.
 * @param renormalize     Whether the top-k routing weights are renormalized.
 * @param gated           Whether the activation is gated (gemm1 output is split in two halves).
 * @param act_mode        "silu" or "gelu".
 * @param start_expert_id First global expert id handled by the weights passed in.
 * @param block_n         Forwarded to group_gemm_combine_result_allreduce() when cncl_comm > 0.
 * @param cncl_comm       Communicator handle; values > 0 enable the all-reduce path.
 * @return Tensor with the same shape as hidden_states.
 *
 * input_smooth, act_smooth, w1_scale and w2_scale must be all present or all absent.
 */
at::Tensor fused_moe(const at::Tensor &hidden_states,
                     const at::Tensor &gating_output,
                     const at::Tensor &w1,
                     const at::Tensor &w2,
                     const c10::optional<at::Tensor> &bias1,
                     const c10::optional<at::Tensor> &bias2,
                     const c10::optional<at::Tensor> &residual,
                     const c10::optional<at::Tensor> &input_smooth,
                     const c10::optional<at::Tensor> &act_smooth,
                     const c10::optional<at::Tensor> &w1_scale,
                     const c10::optional<at::Tensor> &w2_scale,
                     const c10::optional<at::List<int64_t>> &w1_quant_flag,
                     const c10::optional<at::List<int64_t>> &w2_quant_flag,
                     const int64_t topk,
                     const bool renormalize,
                     const bool gated,
                     const std::string &act_mode,
                     const int64_t start_expert_id,
                     const int64_t block_n,
                     const int64_t cncl_comm) {
  auto sizes_1 = hidden_states.sizes();
  auto sizes_2 = gating_output.sizes();
  auto w2_shape = w2.sizes();

  // ---- rank / shape checks on the routing inputs ----
  TORCH_CHECK(sizes_1.size() == sizes_2.size(),
              "hidden_states and gating_output must have the same rank.");
  TORCH_CHECK(sizes_1.size() == 2 || sizes_1.size() == 3, "hidden_states must be 2-D or 3-D.");
  TORCH_CHECK(sizes_1[0] == sizes_2[0],
              "hidden_states and gating_output must have the same batch.");
  if (sizes_1.size() == 3) {
    TORCH_CHECK(sizes_1[1] == sizes_2[1],
                "hidden_states and gating_output must have the same seq.");
  }
  if (residual.has_value()) {
    TORCH_CHECK(residual.value().sizes() == sizes_1,
                "hidden_states and residual must have the same shape.");
  }

  // ---- weight layout checks ----
  TORCH_CHECK(!bias1.has_value() && !bias2.has_value(), "Currently not support bias1 and bias2.");
  TORCH_CHECK(w1.dim() == 3 || w1.dim() == 1, "w1 should be 1-D or 3-D.");
  TORCH_CHECK(w2.dim() == 3 || w2.dim() == 4 || w2.dim() == 1, "w2 should be 1-D or 3-D or 4-D.");
  TORCH_CHECK((w1.dim() == 1 && w2.dim() == 1) || (w1.dim() != 1 && w2.dim() != 1),
              "w1 and w2 should be both 1-D or not both 1-D.");
  if (w1.dim() == 1) {
    TORCH_CHECK(w1_quant_flag.has_value() && w2_quant_flag.has_value(),
                "w1_quant_flag and w2_quant_flag need to exist simultaneously.");
  }

  // check contiguous
  CHECK_TENSOR_CONTIGUOUS(hidden_states)
  CHECK_TENSOR_CONTIGUOUS(gating_output)
  CHECK_TENSOR_CONTIGUOUS(w1)
  CHECK_TENSOR_CONTIGUOUS(w2)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(act_smooth)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(w1_scale)

  // With w1_quant_flag the expert count and gemm1 output width are read from dims 1 and 2 of
  // w1_scale, so it must exist and have at least three dims before it is dereferenced below.
  const bool has_w1_quant_flag = w1_quant_flag.has_value();
  if (has_w1_quant_flag) {
    TORCH_CHECK(w1_scale.has_value() && w1_scale.value().dim() >= 3,
                "w1_scale must be provided with at least 3 dims when w1_quant_flag is set.");
  }

  const int64_t hidden_size = hidden_states.size(-1);
  const int64_t num_expert = gating_output.size(-1);
  const int64_t expert_size = has_w1_quant_flag ? w1_scale.value().size(1) : w1.size(0);
  // The local experts [start_expert_id, start_expert_id + expert_size) index into buffers sized
  // by num_expert (token_count / cumsum_token_count); reject windows that fall outside of them.
  TORCH_CHECK(start_expert_id >= 0 && start_expert_id + expert_size <= num_expert,
              "start_expert_id + expert_size must be within [0, num_expert], but got "
              "start_expert_id = ",
              start_expert_id, ", expert_size = ", expert_size, ", num_expert = ", num_expert);

  auto hidden_states_ = hidden_states.view({-1, hidden_size});
  auto gating_output_ = gating_output.view({-1, num_expert});
  const int64_t num_token = hidden_states_.size(0);
  const int64_t num_expand_token = num_token * topk;
  const int64_t gemm1_co = has_w1_quant_flag ? w1_scale.value().size(2) : w1.size(1);
  // A gated activation consumes two halves of the gemm1 output per result element.
  const int64_t inner_size = gemm1_co / (1 + gated);

  bool has_input_smooth = input_smooth.has_value();
  bool has_act_smooth = act_smooth.has_value();
  bool has_w1_scale = w1_scale.has_value();
  bool has_w2_scale = w2_scale.has_value();
  int opt_num = has_input_smooth + has_act_smooth + has_w1_scale + has_w2_scale;
  QuantMode quant_mode = QuantMode::noQuant;
  bool quant_grouped = false;

  // Per-token smooth quantization is enabled only when all four quantization tensors are given.
  bool per_token_sq = opt_num == 4;
  TORCH_CHECK(opt_num == 0 || opt_num == 4,
              "input_smooth, act_smooth, w1_scale and w2_scale must be present and absent at the "
              "same time.");
  if (per_token_sq) {
    TORCH_CHECK(input_smooth.value().dtype() == torch::kFloat32,
                "the data type of input_smooth must be float.");
    checkTensorSameAttr<TensorAttr::DTYPE>(input_smooth, act_smooth, w1_scale, w2_scale);
    CHECK_SHAPE(input_smooth.value(), expert_size, hidden_size);
    CHECK_SHAPE(act_smooth.value(), expert_size, inner_size);

    // A 3-D w1_scale carries one scale set per quantization group along its first dim.
    quant_grouped = w1_scale.value().dim() == 3;
    if (quant_grouped) {
      CHECK_SHAPE(w1_scale.value(), w1_scale.value().size(0), expert_size, gemm1_co);
      CHECK_SHAPE(w2_scale.value(), w2_scale.value().size(0), expert_size, hidden_size);
    } else {
      CHECK_SHAPE(w1_scale.value(), expert_size, gemm1_co);
      CHECK_SHAPE(w2_scale.value(), expert_size, hidden_size);
    }

    if (has_w1_quant_flag) {
      quant_mode = QuantMode::W4W8;
      quant_grouped = true;
    } else {
      // W4 weights pack two values per element, so the last dim of w1 is half of hidden_size.
      TORCH_CHECK(hidden_size == w1.size(-1) || hidden_size == w1.size(-1) * 2,
                  "hidden_size == w1.size(-1) || hidden_size == w1.size(-1) * 2.");
      quant_mode = hidden_size == w1.size(-1) ? QuantMode::W8 : QuantMode::W4;
    }
  }

  if (quant_mode == QuantMode::W8 || quant_mode == QuantMode::noQuant) {
    CHECK_SHAPE(w1, expert_size, gemm1_co, hidden_size);
    if (w2.dim() == 3) {
      CHECK_SHAPE(w2, expert_size, hidden_size, inner_size);
    } else {
      TORCH_CHECK(w2_shape[0] * w2_shape[2] == expert_size,
                  "w2_shape[0] * w2_shape[2] == expert_size");
      CHECK_SHAPE(w2, w2_shape[0], hidden_size, w2_shape[2], inner_size);
    }
  } else if (quant_mode == QuantMode::W4) {
    CHECK_SHAPE(w1, expert_size, gemm1_co, hidden_size / 2);
    CHECK_SHAPE(w2, expert_size, hidden_size, inner_size / 2);
  }

  if (bias1.has_value()) {
    CHECK_SHAPE(bias1.value(), num_expert, gemm1_co);
  }
  if (bias2.has_value()) {
    CHECK_SHAPE(bias2.value(), num_expert, hidden_size);
  }
  TORCH_CHECK(topk <= num_expert, "topk <= num_expert.");
  TORCH_CHECK(act_mode == "silu" || act_mode == "gelu",
              "act_mode must be 'silu' or 'gelu', but got ", act_mode);

  checkTensorSameAttr<TensorAttr::DTYPE>(hidden_states, bias1, bias2);
  checkTensorSameAttr<TensorAttr::DEVICE>(hidden_states, gating_output, w1, w2, residual,
                                          input_smooth, act_smooth, w1_scale, w2_scale);

  const torch_mlu::mlu::MLUGuard device_guard(hidden_states_.device());
  torch_mlu::DeviceProp *dev_prop = torch_mlu::getDeviceProperties(hidden_states.get_device());
  std::string dev_name = dev_prop->name;
  // Lexicographic compare of "370" with characters [3, 6) of the device name; the flag is set
  // when "370" is not smaller than the device's number.
  // NOTE(review): presumably meant to select MLU370 and older parts - confirm.
  const bool is_mlu370 = arch_370.compare(3, 3, dev_name, 3, 3) >= 0;
  auto handle = torch_mlu::getCurrentHandle();
  auto gating_output_dtype = getCnnlDataType(gating_output.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();
  auto tensor_options = hidden_states_.options();
  auto data_dtype = getCnnlDataType(hidden_states.scalar_type());
  auto weight_dtype = getCnnlDataType(w1.scalar_type());
  // input_dtype is taken from the weight dtype before the W4 override below.
  auto input_dtype = per_token_sq ? CNNL_DTYPE_INT8 : weight_dtype;
  if (quant_mode == QuantMode::W4) {
    weight_dtype = CNNL_DTYPE_INT4X2;
  }

  // create tensors
  auto reduce_weight = at::empty({num_token, topk}, tensor_options.dtype(torch::kFloat));
  auto expert_id = at::empty({num_token, topk}, tensor_options.dtype(torch::kInt32));
  // One allocation backs both index buffers.
  auto int32_idx = at::empty({2, num_token, topk}, tensor_options.dtype(torch::kInt32));
  auto gather_expand_idx = int32_idx[0];
  auto gather_combine_idx = int32_idx[1];
  auto token_count = at::empty({num_expert}, tensor_options.dtype(torch::kInt32));
  auto cumsum_token_count = at::empty({num_expert + 1}, tensor_options.dtype(torch::kInt32));
  auto gen_idx_workspace =
      at::empty({num_expert + 1 + num_expand_token}, tensor_options.dtype(torch::kInt32));
  // gemm1_output is allocated with w1's options; mid_dim doubles the row when quantizing.
  // NOTE(review): presumably so the int8-typed buffer can hold data_dtype results - confirm.
  int64_t mid_dim = (1 + gated) * (1 + per_token_sq);
  auto gemm1_output = at::empty({num_expand_token, mid_dim, inner_size}, w1.options());
  auto gemm2_output = at::empty({num_expand_token, hidden_size}, tensor_options);
  auto quant_input = at::Tensor();
  auto input_scale = at::Tensor();
  auto act_scale = at::Tensor();
  auto expand_hidden_state = at::Tensor();
  if (is_mlu370 || !per_token_sq) {
    expand_hidden_state = at::empty({num_expand_token, hidden_size}, tensor_options);
  }
  if (per_token_sq) {
    quant_input = at::empty({num_expand_token, hidden_size}, tensor_options.dtype(torch::kInt8));
    input_scale = at::empty({num_expand_token}, tensor_options.dtype(torch::kFloat));
    act_scale = at::empty({num_expand_token}, tensor_options.dtype(torch::kFloat));
  }

  //=========================================1:topk_softmax=========================================
  int normalize_mode = renormalize ? 1 : 0;
  tmo::invokeMoeSoftmaxTopkKernel(queue, (float *)getAtTensorPtr(reduce_weight),
                                  (int *)getAtTensorPtr(expert_id), getAtTensorPtr(gating_output_),
                                  nullptr, num_token, num_expert, 1, topk, -1, 0,
                                  gating_output_dtype, normalize_mode);
  //=========================================2:generate_idx=========================================
  tmo::invokeMoeGenIdxKernel(
      queue, (int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(gather_combine_idx),
      (int *)getAtTensorPtr(token_count), (int *)getAtTensorPtr(cumsum_token_count),
      getAtTensorPtr(gen_idx_workspace), getAtTensorPtr(expert_id), num_token, num_expert, topk);
  if (per_token_sq) {
    //========================================3:pertoken_sq(optional)=================================
    if (is_mlu370) {
      // Expand first, then quantize the already expanded tokens (no gather inside SmoothQuant).
      tmo::invokeMoeExpandInputKernel(
          queue, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(hidden_states_),
          (int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(cumsum_token_count),
          num_token, hidden_size, topk, data_dtype, num_expert, start_expert_id, expert_size);
      SmoothQuantTheory sm_obj1(num_expand_token, hidden_size, data_dtype, "per_token");
      cnpxPush(sm_obj1);
      tmo::ops::SmoothQuant(
          handle, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(input_smooth),
          getAtTensorPtr(token_count[start_expert_id]), nullptr, nullptr,
          getAtTensorPtr(quant_input), getAtTensorPtr(input_scale), num_expand_token, hidden_size,
          expert_size, hidden_size, hidden_size, 1, data_dtype);
      cnpxPop();
    } else {
      // SmoothQuant receives the gather indices and expands the input itself.
      SmoothQuantTheory sm_obj1(num_expand_token, hidden_size, data_dtype, "per_token");
      cnpxPush(sm_obj1);
      tmo::ops::SmoothQuant(handle, getAtTensorPtr(hidden_states_), getAtTensorPtr(input_smooth),
                            getAtTensorPtr(token_count[start_expert_id]),
                            getAtTensorPtr(gather_expand_idx),
                            getAtTensorPtr(cumsum_token_count[start_expert_id]),
                            getAtTensorPtr(quant_input), getAtTensorPtr(input_scale), num_token,
                            hidden_size, expert_size, hidden_size, hidden_size, topk, data_dtype);
      cnpxPop();
    }
  } else {
    //========================================3:expand_input========================================
    tmo::invokeMoeExpandInputKernel(
        queue, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(hidden_states_),
        (int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(cumsum_token_count),
        num_token, hidden_size, topk, data_dtype, num_expert, start_expert_id, expert_size);
  }
  //========================================4:group_gemm1===========================================
  GroupGemmDesc group_gemm_desc1(
      expert_size, num_token /*the maximum number of token that may be processed by each expert*/,
      gemm1_co, hidden_size, data_dtype, false, quant_mode);
  group_gemm_desc1.setInputOutputTensor(
      input_dtype, weight_dtype, data_dtype, CNNL_DTYPE_INT32,
      per_token_sq ? getAtTensorPtr(quant_input) : getAtTensorPtr(expand_hidden_state),
      getAtTensorPtr(w1), nullptr, getAtTensorPtr(gemm1_output), nullptr, hidden_size, hidden_size,
      num_expand_token, false);
  // Declared outside the branch so the buffer handed to the descriptor outlives the GEMM call.
  std::vector<int> w1_flag_vec;
  if (per_token_sq) {
    if (w1_quant_flag.has_value()) {
      auto vec = w1_quant_flag.value().vec();
      w1_flag_vec.assign(vec.begin(), vec.end());
    }
    group_gemm_desc1.setPerRowColScaleBiasAct(
        getAtTensorPtr(input_scale), getAtTensorPtr(w1_scale), w1_flag_vec.data(), nullptr,
        data_dtype, quant_grouped ? hidden_size : 0,
        quant_grouped ? hidden_size / w1_scale.value().size(0) : 0, 0);
  }
  size_t group_gemm1_wsize =
      tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc1, expert_size);
  auto group_gemm1_workspace = at::empty({static_cast<int64_t>(group_gemm1_wsize)},
                                         hidden_states.options().dtype(at::kByte));
  // Left empty in W4W8 mode; otherwise one leading dimension per expert.
  std::vector<int> ldb_array;
  if (quant_mode != QuantMode::W4W8) ldb_array.assign(expert_size, hidden_size);
  GroupGemmTheory gg_obj1(num_expand_token, expert_size, hidden_size, gemm1_co, false /*has_res*/,
                          input_dtype, data_dtype);
  cnpxPush(gg_obj1);
  tmo::ops::GroupGemm(handle, group_gemm_desc1, getAtTensorPtr(token_count[start_expert_id]),
                      nullptr, nullptr, getAtTensorPtr(group_gemm1_workspace), group_gemm1_wsize,
                      expert_size, hidden_size, /*k*/
                      gemm1_co,                 /*n*/
                      hidden_size /*lda*/, ldb_array /*ldb*/);
  cnpxPop();
  //========================================5:activation============================================
  cnnlActivationMode_t act_type = act_mode == "silu" ? CNNL_ACTIVATION_SWISH : CNNL_ACTIVATION_GELU;
  GroupAddBiasActiveTheory add_bias_obj(expert_size, num_expand_token, inner_size, gated,
                                        bias1.has_value(), data_dtype, act_mode);
  cnpxPush(add_bias_obj);
  // In-place on gemm1_output; bias is not passed because bias1 is rejected above.
  tmo::invokeGroupAddBiasActivationKernel(
      queue, getAtTensorPtr(gemm1_output), getAtTensorPtr(gemm1_output),
      nullptr /*getAtTensorPtr(bias1)*/, (int *)getAtTensorPtr(cumsum_token_count), num_expert,
      num_expand_token, inner_size, gemm1_co, data_dtype, gated, act_type, start_expert_id,
      expert_size, 1.0f);
  cnpxPop();
  //========================================6:smooth_quant==========================================
  if (per_token_sq) {
    SmoothQuantTheory sm_obj2(num_expand_token, inner_size * expert_size, data_dtype, "per_token");
    cnpxPush(sm_obj2);
    // Quantizes the activation in place: input and output both point at gemm1_output.
    tmo::ops::SmoothQuant(handle, getAtTensorPtr(gemm1_output), getAtTensorPtr(act_smooth),
                          getAtTensorPtr(token_count[start_expert_id]), nullptr /*gather_idx*/,
                          nullptr, getAtTensorPtr(gemm1_output), getAtTensorPtr(act_scale),
                          num_expand_token, inner_size, expert_size, gemm1_co /*input_stride*/,
                          2 * gemm1_co /*output_stride*/, 1 /*topk*/, data_dtype);
    cnpxPop();
  }
  if (cncl_comm > 0) {
    TORCH_CHECK(num_expert == expert_size, "expert_size must be num_expert when cncl_comm > 0");
    c10::optional<at::Tensor> act_scale_opt(act_scale);
    c10::optional<at::Tensor> w2_scale_opt(w2_scale);
    std::string dtype_s = torchDtype2Str(hidden_states.scalar_type());
    // View of the first inner_size elements of every gemm1_output row.
    auto gg_input =
        gemm1_output.as_strided({num_expand_token, inner_size}, {mid_dim * inner_size, 1});
    auto output = group_gemm_combine_result_allreduce(
        cncl_comm, gg_input, w2, token_count, gather_combine_idx, reduce_weight, c10::nullopt,
        c10::nullopt, c10::nullopt, per_token_sq ? act_scale_opt : c10::nullopt,
        per_token_sq ? w2_scale_opt : c10::nullopt, dtype_s, num_token, topk, block_n);
    if (residual.has_value()) {
      output.view(sizes_1) += residual.value();
    }
    return output.view(sizes_1);
  } else {
    //========================================8:group_gemm2===========================================
    GroupGemmDesc group_gemm_desc2(expert_size, num_token, hidden_size, inner_size, data_dtype,
                                   false, quant_mode);
    std::vector<int64_t> b_offset(expert_size);
    int w2_ldb = inner_size;
    if (w2.dim() == 4) {
      // 4-D w2 is checked above as (w2_shape[0], hidden_size, w2_shape[2], inner_size): compute
      // the byte offset of every expert's slice and widen the leading dimension accordingly.
      w2_ldb = w2_shape[2] * inner_size;
      auto elem_size = w2.element_size();
      for (int64_t i = 0; i < expert_size; i++) {
        b_offset[i] = (i / w2_shape[2] * w2.stride(0) + i % w2_shape[2] * w2_shape[3]) * elem_size;
      }
    }
    group_gemm_desc2.setInputOutputTensor(input_dtype, weight_dtype, data_dtype, CNNL_DTYPE_INT32,
                                          getAtTensorPtr(gemm1_output), getAtTensorPtr(w2), nullptr,
                                          getAtTensorPtr(gemm2_output), nullptr, inner_size,
                                          gemm1_co * (per_token_sq + 1), num_expand_token, false,
                                          w2.dim() == 4 ? b_offset.data() : nullptr);
    // Declared outside the branch so the buffer handed to the descriptor outlives the GEMM call.
    std::vector<int> w2_flag_vec;
    if (per_token_sq) {
      if (w2_quant_flag.has_value()) {
        auto vec = w2_quant_flag.value().vec();
        w2_flag_vec.assign(vec.begin(), vec.end());
      }
      group_gemm_desc2.setPerRowColScaleBiasAct(
          getAtTensorPtr(act_scale), getAtTensorPtr(w2_scale), w2_flag_vec.data(), nullptr,
          data_dtype, quant_grouped ? inner_size : 0,
          quant_grouped ? inner_size / w2_scale.value().size(0) : 0, 0);
    }

    size_t group_gemm2_wsize =
        tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc2, expert_size);
    auto group_gemm2_workspace = at::empty({static_cast<int64_t>(group_gemm2_wsize)},
                                           hidden_states.options().dtype(at::kByte));
    std::vector<int> w2_ldb_array;
    if (quant_mode != QuantMode::W4W8) w2_ldb_array.assign(expert_size, w2_ldb);
    GroupGemmTheory gg_obj2(num_expand_token, expert_size, inner_size, hidden_size,
                            false /*has_res*/, input_dtype, data_dtype);
    cnpxPush(gg_obj2);
    tmo::ops::GroupGemm(handle, group_gemm_desc2, getAtTensorPtr(token_count[start_expert_id]),
                        nullptr, nullptr, getAtTensorPtr(group_gemm2_workspace), group_gemm2_wsize,
                        expert_size, inner_size, /*k*/
                        hidden_size,             /*n*/
                        gemm1_co * (per_token_sq + 1) /*lda*/, w2_ldb_array /*ldb*/);
    cnpxPop();
    //========================================9:combine_result=======================================
    auto output = at::empty(hidden_states_.sizes(), hidden_states_.options());
    MoeCombineResultTheory cr_obj(num_token, topk, hidden_size, expert_size, bias2.has_value(),
                                  residual.has_value(), data_dtype);
    cnpxPush(cr_obj);
    tmo::invokeMoeCombineResultKernel(
        queue, getAtTensorPtr(output), getAtTensorPtr(gemm2_output), nullptr,
        getAtTensorPtr(residual), (float *)getAtTensorPtr(reduce_weight),
        (int *)getAtTensorPtr(cumsum_token_count), (int *)getAtTensorPtr(gather_combine_idx),
        num_token, topk, num_expert, hidden_size, start_expert_id, expert_size, data_dtype);
    cnpxPop();
    return output.view(sizes_1);
  }
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
281
torch_mlu_ops-v1.3.2/csrc/torch_api/fused_rope.cpp
Normal file
281
torch_mlu_ops-v1.3.2/csrc/torch_api/fused_rope.cpp
Normal file
@@ -0,0 +1,281 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "kernels/fused_rope.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
void fused_rope(at::Tensor &qkv,
|
||||
at::Tensor &key_cache_hp,
|
||||
at::Tensor &value_cache_hp,
|
||||
const c10::optional<at::Tensor> &key_cache_lp,
|
||||
const c10::optional<at::Tensor> &value_cache_lp,
|
||||
const at::Tensor &sin_table,
|
||||
const at::Tensor &cos_table,
|
||||
const at::Tensor &position_ids,
|
||||
const at::Tensor &gamma,
|
||||
const at::Tensor &beta,
|
||||
const c10::optional<at::Tensor> &key_scale_hp,
|
||||
const c10::optional<at::Tensor> &value_scale_hp,
|
||||
const c10::optional<at::Tensor> &key_scale_lp,
|
||||
const c10::optional<at::Tensor> &value_scale_lp,
|
||||
const c10::optional<at::Tensor> &cache_bs_id_hp,
|
||||
const c10::optional<at::Tensor> &cache_seq_offsets_hp,
|
||||
const c10::optional<at::Tensor> &cache_bs_id_lp,
|
||||
const c10::optional<at::Tensor> &cache_seq_offsets_lp,
|
||||
const c10::optional<at::Tensor> &slot_mapping_hp,
|
||||
const c10::optional<at::Tensor> &slot_mapping_lp,
|
||||
const double eps) {
|
||||
bool paged_cache_hp = slot_mapping_hp.has_value();
|
||||
bool paged_cache_lp = slot_mapping_lp.has_value();
|
||||
bool mixed_cache = key_cache_lp.has_value() && value_cache_lp.has_value();
|
||||
const int origin_device_id = qkv.get_device();
|
||||
|
||||
checkTensorSameAttr<TensorAttr::ALL>(qkv, sin_table, cos_table, gamma, beta);
|
||||
TORCH_CHECK(qkv.is_contiguous(), "qkv tensor must be contiguous.")
|
||||
TORCH_CHECK(key_cache_hp.is_contiguous(), "key_cache_hp tensor must be contiguous")
|
||||
TORCH_CHECK(value_cache_hp.is_contiguous(), "value_cache_hp tensor must be contiguous")
|
||||
if (mixed_cache) {
|
||||
TORCH_CHECK(key_cache_lp.value().is_contiguous(), "key_cache_lp tensor must be contiguous")
|
||||
TORCH_CHECK(key_scale_hp.has_value() && value_scale_hp.has_value(),
|
||||
"key_scale_hp and value_scale_hp must not be null under mixed cache.")
|
||||
}
|
||||
if (mixed_cache) {
|
||||
TORCH_CHECK(value_cache_lp.value().is_contiguous(), "value_cache_lp tensor must be contiguous")
|
||||
}
|
||||
TORCH_CHECK(position_ids.is_contiguous(), "position_ids tensor must be contiguous");
|
||||
TORCH_CHECK(gamma.is_contiguous(), "gamma tensor must be contiguous");
|
||||
TORCH_CHECK(beta.is_contiguous(), "beta tensor must be contiguous");
|
||||
if (cache_bs_id_hp.has_value()) {
|
||||
TORCH_CHECK(cache_bs_id_hp.value().is_contiguous(), "cache_bs_id_hp tensor must be contiguous.")
|
||||
}
|
||||
if (cache_seq_offsets_hp.has_value()) {
|
||||
TORCH_CHECK(cache_seq_offsets_hp.value().is_contiguous(),
|
||||
"cache_seq_offsets_hp tensor must be contiguous")
|
||||
}
|
||||
if (cache_bs_id_lp.has_value()) {
|
||||
TORCH_CHECK(cache_bs_id_lp.value().is_contiguous(), "cache_bs_id_lp tensor must be contiguous.")
|
||||
}
|
||||
if (cache_seq_offsets_lp.has_value()) {
|
||||
TORCH_CHECK(cache_seq_offsets_lp.value().is_contiguous(),
|
||||
"cache_seq_offsets_lp tensor must be contiguous")
|
||||
}
|
||||
if (paged_cache_hp) {
|
||||
TORCH_CHECK(slot_mapping_hp.value().is_contiguous(),
|
||||
"slot_mapping_hp tensor must be contiguous.")
|
||||
}
|
||||
if (paged_cache_lp) {
|
||||
TORCH_CHECK(slot_mapping_lp.value().is_contiguous(),
|
||||
"slot_mapping_lp tensor must be contiguous.")
|
||||
}
|
||||
|
||||
// check qkv key_cache value_cache
|
||||
auto dtype = qkv.dtype();
|
||||
TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
|
||||
"qkv dtype must be half or bfloat16.")
|
||||
TORCH_CHECK(qkv.dim() == 4,
|
||||
"qkv tensor dim must be 4: (batch_size, 1, head_num_q + head_num_kv * 2, head_size).")
|
||||
TORCH_CHECK(qkv.size(1) == 1, "only support seq_len == 1.");
|
||||
TORCH_CHECK(key_cache_hp.dim() == 4,
|
||||
"key_cache_hp dim must be 4: (max_bs, head_num_kv, max_decode_len_hp, head_size),"
|
||||
"or (num_blocks, head_num_kv, block_size_hp, head_size).");
|
||||
TORCH_CHECK(value_cache_hp.dim() == 4,
|
||||
"value_cache_hp dim must be 4: (max_bs, head_num_kv, max_decode_len_hp, head_size),"
|
||||
"or (num_blocks, head_num_kv, block_size_hp, head_size).");
|
||||
void *key_cache_lp_ptr = nullptr;
|
||||
void *value_cache_lp_ptr = nullptr;
|
||||
if (mixed_cache) {
|
||||
TORCH_CHECK(
|
||||
key_cache_lp.value().dim() == 4,
|
||||
"key_cache_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp, head_size / 2),"
|
||||
"or (num_blocks, head_num_kv, block_size_lp, head_size / 2).");
|
||||
TORCH_CHECK(
|
||||
value_cache_lp.value().dim() == 4,
|
||||
"value_cache_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp / 2, head_size),"
|
||||
"or (num_blocks, head_num_kv, block_size_lp / 2, head_size).");
|
||||
TORCH_CHECK(key_cache_lp.value().get_device() == origin_device_id,
|
||||
"key_scale_hp tensor device index is not same, original index: ", origin_device_id,
|
||||
"now index is: ", key_cache_lp.value().get_device());
|
||||
TORCH_CHECK(value_cache_lp.value().get_device() == origin_device_id,
|
||||
"value_scale_hp tensor device index is not same, original index: ",
|
||||
origin_device_id, "now index is: ", value_cache_lp.value().get_device());
|
||||
key_cache_lp_ptr = key_cache_lp.value().data_ptr();
|
||||
value_cache_lp_ptr = value_cache_lp.value().data_ptr();
|
||||
}
|
||||
|
||||
int batch_size = qkv.size(0);
|
||||
int head_num_qkv = qkv.size(-2);
|
||||
int head_size = qkv.size(-1);
|
||||
int head_num_kv = key_cache_hp.size(1); // linear cache and paged cache same
|
||||
int max_bs_hp = key_cache_hp.size(0); // for linear cache, paged cache no use
|
||||
int max_bs_lp = mixed_cache ? key_cache_lp.value().size(0) : 0;
|
||||
int max_decode_len_hp = key_cache_hp.size(2); // for linear cache, paged cache no use
|
||||
int max_decode_len_lp = mixed_cache ? key_cache_lp.value().size(2) : 0;
|
||||
int num_blocks_hp = key_cache_hp.size(0);
|
||||
int num_blocks_lp = mixed_cache ? key_cache_lp.value().size(0) : 0;
|
||||
int block_size_hp = key_cache_hp.size(2); // for paged cache, linear cache no use
|
||||
int block_size_lp = mixed_cache ? key_cache_lp.value().size(2) : 0;
|
||||
int head_num_q = head_num_qkv - 2 * head_num_kv;
|
||||
int group_size = head_size;
|
||||
|
||||
TORCH_CHECK(head_size <= 128, "only support qkv head_size <= 128.");
|
||||
TORCH_CHECK(head_size % 2 == 0, "head_size must be divided by 2.");
|
||||
TORCH_CHECK(head_num_q <= 32, "only support head_num_q <= 32.");
|
||||
TORCH_CHECK(head_num_kv <= 32, "only support head_num_kv <= 32.");
|
||||
int kv_out_size = key_scale_hp.has_value() && value_scale_hp.has_value() ? 1 : 2;
|
||||
size_t hp_cache_total_bytes =
|
||||
paged_cache_hp
|
||||
? (size_t)num_blocks_hp * head_num_kv * block_size_hp * head_size * kv_out_size
|
||||
: (size_t)max_bs_hp * max_decode_len_hp * head_num_kv * head_size * kv_out_size;
|
||||
TORCH_CHECK(hp_cache_total_bytes <= INT32_MAX,
|
||||
"hp_cache memory btyes must be less than or equal to INT32_MAX.");
|
||||
if (mixed_cache) {
|
||||
size_t lp_cache_total_bytes =
|
||||
paged_cache_lp ? (size_t)num_blocks_lp * head_num_kv * block_size_lp * head_size / 2
|
||||
: (size_t)max_bs_lp * max_decode_len_lp * head_num_kv * head_size / 2;
|
||||
TORCH_CHECK(lp_cache_total_bytes <= INT32_MAX,
|
||||
"lp_cache memory btyes must be less than or equal to INT32_MAX.");
|
||||
}
|
||||
|
||||
// check sin_table cos_table gamma beta
|
||||
TORCH_CHECK(sin_table.dim() == 2 && cos_table.dim() == 2,
|
||||
"sin_table and cos_table tensor dim must be 2: (rotary_seqlen, head_size).");
|
||||
TORCH_CHECK(sin_table.size(1) == head_size && cos_table.size(1) == head_size,
|
||||
"rotary_dim must be same with head_size.");
|
||||
TORCH_CHECK(sin_table.stride(0) == cos_table.stride(0),
|
||||
"sin_table first stride must be equal to cos_table first_stride.");
|
||||
int rotary_stride = sin_table.stride(0);
|
||||
|
||||
CHECK_SHAPE(gamma, head_size);
|
||||
CHECK_SHAPE(beta, head_size);
|
||||
|
||||
// check key value scale
|
||||
bool has_kv_scale = key_scale_hp.has_value() && value_scale_hp.has_value();
|
||||
const void *key_scale_hp_ptr = nullptr;
|
||||
const void *value_scale_hp_ptr = nullptr;
|
||||
if (has_kv_scale) {
|
||||
CHECK_SHAPE(key_scale_hp.value(), head_num_kv, head_size);
|
||||
CHECK_SHAPE(value_scale_hp.value(), head_num_kv, head_size);
|
||||
TORCH_CHECK(key_scale_hp.value().scalar_type() == torch::kFloat32,
|
||||
"key_scale_hp dtype must be float.");
|
||||
TORCH_CHECK(value_scale_hp.value().scalar_type() == torch::kFloat32,
|
||||
"value_scale_hp dtype must be float.");
|
||||
TORCH_CHECK(key_scale_hp.value().get_device() == origin_device_id,
|
||||
"key_scale_hp tensor device index is not same, original index: ", origin_device_id,
|
||||
"now index is: ", key_scale_hp.value().get_device());
|
||||
TORCH_CHECK(value_scale_hp.value().get_device() == origin_device_id,
|
||||
"value_scale_hp tensor device index is not same, original index: ",
|
||||
origin_device_id, "now index is: ", value_scale_hp.value().get_device());
|
||||
key_scale_hp_ptr = key_scale_hp.value().data_ptr();
|
||||
value_scale_hp_ptr = value_scale_hp.value().data_ptr();
|
||||
}
|
||||
|
||||
void *key_scale_lp_ptr = nullptr;
|
||||
void *value_scale_lp_ptr = nullptr;
|
||||
if (mixed_cache) {
|
||||
TORCH_CHECK(key_scale_lp.value().dim() == 4,
|
||||
"key_scale_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp, group_num),"
|
||||
"or (num_blocks, head_num_kv, block_size_lp, group_num).");
|
||||
TORCH_CHECK(key_scale_lp.value().scalar_type() == torch::kFloat32,
|
||||
"key_scale_hp dtype must be float.");
|
||||
TORCH_CHECK(value_scale_lp.value().scalar_type() == torch::kFloat32,
|
||||
"value_scale_hp dtype must be float.");
|
||||
TORCH_CHECK(key_scale_lp.value().get_device() == origin_device_id,
|
||||
"key_scale_hp tensor device index is not same, original index: ", origin_device_id,
|
||||
"now index is: ", key_scale_lp.value().get_device());
|
||||
TORCH_CHECK(value_scale_lp.value().get_device() == origin_device_id,
|
||||
"value_scale_hp tensor device index is not same, original index: ",
|
||||
origin_device_id, "now index is: ", value_scale_lp.value().get_device());
|
||||
group_size = head_size / key_scale_lp.value().size(-1);
|
||||
key_scale_lp_ptr = key_scale_lp.value().data_ptr();
|
||||
value_scale_lp_ptr = value_scale_lp.value().data_ptr();
|
||||
}
|
||||
|
||||
// check cache_bs_id cache_seq_offsets
|
||||
const void *cache_bs_id_hp_ptr = nullptr;
|
||||
const void *cache_bs_id_lp_ptr = nullptr;
|
||||
bool has_cache_bs_id_hp = cache_bs_id_hp.has_value();
|
||||
bool has_cache_bs_id_lp = cache_bs_id_lp.has_value();
|
||||
if (has_cache_bs_id_hp) {
|
||||
CHECK_SHAPE(cache_bs_id_hp.value(), batch_size);
|
||||
TORCH_CHECK(cache_bs_id_hp.value().scalar_type() == torch::kInt32,
|
||||
"cache_bs_id_hp dtype need be torch::kInt32");
|
||||
TORCH_CHECK(cache_bs_id_hp.value().get_device() == origin_device_id,
|
||||
"cache_bs_id_hp tensor device index is not same, original index: ",
|
||||
origin_device_id, "now index is: ", cache_bs_id_hp.value().get_device());
|
||||
cache_bs_id_hp_ptr = cache_bs_id_hp.value().data_ptr();
|
||||
}
|
||||
if (mixed_cache && has_cache_bs_id_lp) {
|
||||
CHECK_SHAPE(cache_bs_id_lp.value(), batch_size);
|
||||
TORCH_CHECK(cache_bs_id_lp.value().scalar_type() == torch::kInt32,
|
||||
"cache_bs_id_lp dtype need be torch::kInt32");
|
||||
TORCH_CHECK(cache_bs_id_lp.value().get_device() == origin_device_id,
|
||||
"cache_bs_id_lp tensor device index is not same, original index: ",
|
||||
origin_device_id, "now index is: ", cache_bs_id_lp.value().get_device());
|
||||
cache_bs_id_lp_ptr = cache_bs_id_lp.value().data_ptr();
|
||||
}
|
||||
const void *slot_mapping_hp_ptr = nullptr;
|
||||
const void *slot_mapping_lp_ptr = nullptr;
|
||||
if (paged_cache_hp) {
|
||||
CHECK_SHAPE(slot_mapping_hp.value(), batch_size);
|
||||
TORCH_CHECK(slot_mapping_hp.value().scalar_type() == torch::kInt32,
|
||||
"slot_mapping_hp dtype need be torch::kInt32");
|
||||
TORCH_CHECK(slot_mapping_hp.value().get_device() == origin_device_id,
|
||||
"slot_mapping_hp tensor device index is not same, original index: ",
|
||||
origin_device_id, "now index is: ", slot_mapping_hp.value().get_device());
|
||||
slot_mapping_hp_ptr = slot_mapping_hp.value().data_ptr();
|
||||
}
|
||||
if (mixed_cache && paged_cache_lp) {
|
||||
CHECK_SHAPE(slot_mapping_lp.value(), batch_size);
|
||||
TORCH_CHECK(slot_mapping_lp.value().scalar_type() == torch::kInt32,
|
||||
"slot_mapping_lp dtype need be torch::kInt32");
|
||||
TORCH_CHECK(slot_mapping_lp.value().get_device() == origin_device_id,
|
||||
"slot_mapping_lp tensor device index is not same, original index: ",
|
||||
origin_device_id, "now index is: ", slot_mapping_lp.value().get_device());
|
||||
slot_mapping_lp_ptr = slot_mapping_lp.value().data_ptr();
|
||||
}
|
||||
const void *cache_seq_offsets_hp_ptr = nullptr;
|
||||
const void *cache_seq_offsets_lp_ptr = nullptr;
|
||||
if (!paged_cache_hp) {
|
||||
CHECK_SHAPE(cache_seq_offsets_hp.value(), batch_size);
|
||||
TORCH_CHECK(cache_seq_offsets_hp.value().dtype() == torch::kInt32,
|
||||
"cache_seq_offsets_hp dtype need be torch::kInt32");
|
||||
TORCH_CHECK(cache_seq_offsets_hp.value().get_device() == origin_device_id,
|
||||
"cache_seq_offsets_hp tensor device index is not same, original index:",
|
||||
origin_device_id, "now index is: ", cache_seq_offsets_hp.value().get_device());
|
||||
cache_seq_offsets_hp_ptr = cache_seq_offsets_hp.value().data_ptr();
|
||||
}
|
||||
if (mixed_cache && !paged_cache_lp) {
|
||||
CHECK_SHAPE(cache_seq_offsets_lp.value(), batch_size);
|
||||
TORCH_CHECK(cache_seq_offsets_lp.value().dtype() == torch::kInt32,
|
||||
"cache_seq_offsets_lp dtype need be torch::kInt32");
|
||||
TORCH_CHECK(cache_seq_offsets_lp.value().get_device() == origin_device_id,
|
||||
"cache_seq_offsets_lp tensor device index is not same, original index:",
|
||||
origin_device_id, "now index is: ", cache_seq_offsets_lp.value().get_device());
|
||||
cache_seq_offsets_lp_ptr = cache_seq_offsets_lp.value().data_ptr();
|
||||
}
|
||||
|
||||
auto queue = torch_mlu::getCurrentMLUStream();
|
||||
auto data_type = getCnnlDataType(qkv.scalar_type());
|
||||
invokeFusedRope(queue, getAtTensorPtr(qkv), getAtTensorPtr(key_cache_hp),
|
||||
getAtTensorPtr(value_cache_hp), key_cache_lp_ptr, value_cache_lp_ptr,
|
||||
getAtTensorPtr(sin_table), getAtTensorPtr(cos_table),
|
||||
getAtTensorPtr(position_ids), getAtTensorPtr(gamma), getAtTensorPtr(beta),
|
||||
key_scale_hp_ptr, value_scale_hp_ptr, key_scale_lp_ptr, value_scale_lp_ptr,
|
||||
cache_bs_id_hp_ptr, cache_seq_offsets_hp_ptr, cache_bs_id_lp_ptr,
|
||||
cache_seq_offsets_lp_ptr, slot_mapping_hp_ptr, slot_mapping_lp_ptr, rotary_stride,
|
||||
batch_size, head_num_q, head_num_kv, head_size, max_decode_len_hp,
|
||||
max_decode_len_lp, block_size_hp, block_size_lp, group_size, data_type, eps);
|
||||
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
46
torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.cpp
Normal file
46
torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.cpp
Normal file
@@ -0,0 +1,46 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "glue_ops.h"
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
|
||||
void copy_blocks__functionalization_glue(
|
||||
const std::vector<torch::Tensor> &k_caches,
|
||||
const std::vector<torch::Tensor> &v_caches,
|
||||
const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping) {
|
||||
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(k_caches));
|
||||
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(v_caches));
|
||||
at::functionalization::impl::sync(k_caches);
|
||||
at::functionalization::impl::sync(v_caches);
|
||||
auto new_k_caches = at::functionalization::impl::from_functional_tensor(k_caches);
|
||||
auto new_v_caches = at::functionalization::impl::from_functional_tensor(v_caches);
|
||||
|
||||
// Grab the dispatcher entry corresponding to the out-of-place op, "foo"
|
||||
static auto op_handle =
|
||||
c10::Dispatcher::singleton()
|
||||
// specify namespace::op_name, op_overload_name
|
||||
.findSchemaOrThrow("torch_mlu_ops::copy_blocks_out_of_place", "")
|
||||
// Specify the C++ schema of the out-of-place op.
|
||||
.typed<std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>>(
|
||||
const std::vector<at::Tensor> &k_caches, const std::vector<at::Tensor> &v_caches,
|
||||
const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping)>();
|
||||
// Next, redispatch to the out-of-place op, foo() (user called foo_, we call fooo)
|
||||
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> tmp_output;
|
||||
{
|
||||
at::AutoDispatchSkipFunctionalize guard;
|
||||
tmp_output = op_handle.call(new_k_caches, new_v_caches, block_mapping);
|
||||
}
|
||||
// Finally, tell functionalization about this mutation.
|
||||
at::functionalization::impl::replace_(k_caches, std::get<0>(tmp_output));
|
||||
at::functionalization::impl::replace_(v_caches, std::get<1>(tmp_output));
|
||||
at::functionalization::impl::commit_update(k_caches);
|
||||
at::functionalization::impl::commit_update(v_caches);
|
||||
}
|
||||
22
torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.h
Normal file
22
torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.h
Normal file
@@ -0,0 +1,22 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_TORCH_API_GLUE_OPS_H_
|
||||
#define CSRC_TORCH_API_GLUE_OPS_H_
|
||||
#include "torch/extension.h"
|
||||
#include "utils.h"
|
||||
|
||||
// Functionalization glue for the in-place copy_blocks op.
//
// The definition asserts that k_caches and v_caches are functional tensors,
// syncs and unwraps them, calls `torch_mlu_ops::copy_blocks_out_of_place` on the
// unwrapped tensors with functionalization skipped, and records the returned
// tensors on the original wrappers (replace_ + commit_update).
// block_mapping is forwarded unchanged to the out-of-place op.
void copy_blocks__functionalization_glue(
|
||||
const std::vector<torch::Tensor> &k_caches,
|
||||
const std::vector<torch::Tensor> &v_caches,
|
||||
const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping);
|
||||
|
||||
#endif // CSRC_TORCH_API_GLUE_OPS_H_
|
||||
231
torch_mlu_ops-v1.3.2/csrc/torch_api/group_gemm.cpp
Normal file
231
torch_mlu_ops-v1.3.2/csrc/torch_api/group_gemm.cpp
Normal file
@@ -0,0 +1,231 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include "torch_ops_api.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
using GroupGemmDesc = op_desc::GroupGemmDesc;
|
||||
using QuantMode = GroupGemmDesc::QuantMode;
|
||||
|
||||
// Grouped GEMM over `experts_num = m_list.size(0)` groups.
//
// Inputs (as enforced by the checks below):
//   a_tensor   : 2-D [num_token, k], last dim contiguous.
//   b_tensor   : 3-D [experts_num, n, k], 4-D [b0, n, b2, k] with b0 * b2 == experts_num,
//                [experts_num, n, k / 2] for W4, or 1-D when quant_flag is given.
//   m_list     : 1-D int32, contiguous.
//   gather_idx : optional 1-D int32; when present its length is the output row count.
//   c_tensor   : optional 2-D tensor.
//   alpha/beta : optional 1-D [experts_num].
//   a_scale    : optional 1-D [total_m]; requires data_type ('float', 'half', 'bfloat16'),
//                which then selects the output dtype.
//   b_scale    : optional [experts_num, n] or grouped [quant_group, experts_num, n];
//                mandatory and 3-D when quant_flag is given.
//   bias, b_offset : optional, forwarded to the descriptor as raw pointers.
//   max_m      : forwarded to GroupGemmDesc.
//
// Returns a newly allocated [total_m, n] tensor on a_tensor's device, where
// total_m is gather_idx.size(0) if gather_idx is given and num_token otherwise.
at::Tensor group_gemm(const at::Tensor &a_tensor,
                      const at::Tensor &b_tensor,
                      const at::Tensor &m_list,
                      const c10::optional<at::Tensor> &gather_idx,
                      const c10::optional<at::Tensor> &c_tensor,
                      const c10::optional<at::Tensor> &alpha,
                      const c10::optional<at::Tensor> &beta,
                      const c10::optional<at::Tensor> &a_scale,
                      const c10::optional<at::Tensor> &b_scale,
                      const c10::optional<at::Tensor> &bias,
                      const c10::optional<std::string> &data_type,
                      const c10::optional<at::List<int64_t>> &quant_flag,
                      const c10::optional<at::Tensor> &b_offset,
                      const int64_t max_m) {
  bool has_dtype = data_type.has_value();
  bool has_c_tensor = c_tensor.has_value();
  bool has_gather_idx = gather_idx.has_value();
  bool has_a_scale = a_scale.has_value();
  bool has_b_scale = b_scale.has_value();
  bool has_alpha = alpha.has_value();
  bool has_beta = beta.has_value();
  bool has_bias = bias.has_value();
  bool has_quant_flag = quant_flag.has_value();
  QuantMode quant_mode = QuantMode::noQuant;
  bool quant_grouped = false;

  // check stride
  TORCH_CHECK(a_tensor.stride(-1) == 1, "a_tensor last dim must be contiguous");
  CHECK_TENSOR_CONTIGUOUS(m_list);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(c_tensor);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_idx);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(a_scale);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(b_scale);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(alpha);
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(beta);

  // check dim
  TORCH_CHECK(m_list.dim() == 1, "the shape of m_list should be 1D");
  TORCH_CHECK(a_tensor.dim() == 2, "the shape of a_tensor should be 2D");
  TORCH_CHECK(b_tensor.dim() == 3 || b_tensor.dim() == 4 || (b_tensor.dim() == 1 && has_quant_flag),
              "the shape of b_tensor should be 3D or 4D (or 1D when quant_flag is given)");
  if (has_c_tensor) {
    TORCH_CHECK(c_tensor.value().dim() == 2, "the shape of c_tensor should be 2D");
  }
  if (has_gather_idx) {
    TORCH_CHECK(gather_idx.value().dim() == 1, "the shape of gather_idx should be 1D");
  }
  if (has_a_scale) {
    TORCH_CHECK(a_scale.value().dim() == 1, "the shape of a_scale should be 1D");
  }
  if (has_b_scale) {
    TORCH_CHECK(b_scale.value().dim() == 2 || b_scale.value().dim() == 3,
                "the shape of b_scale should be 2D or 3D");
  }
  if (has_alpha) {
    TORCH_CHECK(alpha.value().dim() == 1, "the shape of alpha should be 1D");
  }
  if (has_beta) {
    TORCH_CHECK(beta.value().dim() == 1, "the shape of beta should be 1D");
  }
  if (has_quant_flag) {
    // n is read from b_scale.size(2) below, so b_scale must be validated first;
    // otherwise a missing or 2-D b_scale fails with an unrelated optional/index error.
    TORCH_CHECK(has_b_scale, "if quant_flag given, b_scale must exist");
    TORCH_CHECK(b_scale.value().dim() == 3,
                "if quant_flag given, the shape of b_scale should be 3D");
  }

  // check shape
  int64_t experts_num = m_list.size(0);
  int64_t num_token = a_tensor.size(0);
  int64_t k = a_tensor.size(-1);
  int64_t n = has_quant_flag ? b_scale.value().size(2) : b_tensor.size(1);
  int64_t lda = a_tensor.stride(0);
  int64_t total_m = has_gather_idx ? gather_idx.value().size(0) : num_token;

  TORCH_CHECK(experts_num > 0, "number of groups must be large than zero");
  CHECK_SHAPE(a_tensor, num_token, k);
  if (has_a_scale) {
    CHECK_SHAPE(a_scale.value(), total_m);
  }
  if (has_quant_flag) {
    quant_grouped = true;
    quant_mode = QuantMode::W4W8;
    CHECK_SHAPE(b_scale.value(), b_scale.value().size(0), experts_num, n);
  } else if (has_b_scale) {
    int64_t b_k = b_tensor.size(-1);
    quant_grouped = b_scale.value().dim() == 3 ? true : false;
    if (quant_grouped) {
      CHECK_SHAPE(b_scale.value(), b_scale.value().size(0), experts_num, n);
    } else {
      CHECK_SHAPE(b_scale.value(), experts_num, n);
    }
    TORCH_CHECK(k == b_k || k == b_k * 2, "k == b_k || k == b_k * 2.");
    // Same k means 8-bit weights; half of k means two 4-bit weights per byte.
    quant_mode = (b_k == k) ? QuantMode::W8 : QuantMode::W4;
  }
  if (quant_grouped) {
    // The quant group size is computed later as k / b_scale.size(0); guard the division.
    TORCH_CHECK(b_scale.value().size(0) > 0,
                "the first dim of grouped b_scale must be large than zero");
  }
  if (has_alpha) {
    CHECK_SHAPE(alpha.value(), experts_num);
  }
  if (has_beta) {
    CHECK_SHAPE(beta.value(), experts_num);
  }
  auto b_shape = b_tensor.sizes();
  cnnlDataType_t a_dtype = getCnnlDataType(a_tensor.scalar_type());
  cnnlDataType_t b_dtype = getCnnlDataType(b_tensor.scalar_type());
  torch::Dtype dtype = a_tensor.scalar_type();
  if (quant_mode == QuantMode::W8 || quant_mode == QuantMode::noQuant) {
    if (b_tensor.dim() == 3) {
      CHECK_SHAPE(b_tensor, experts_num, n, k);
    } else {
      TORCH_CHECK(b_shape[0] * b_shape[2] == experts_num, "b_shape[0] * b_shape[2] == experts_num");
      CHECK_SHAPE(b_tensor, b_shape[0], n, b_shape[2], k);
    }
  } else if (quant_mode == QuantMode::W4) {
    CHECK_SHAPE(b_tensor, experts_num, n, k / 2);
    b_dtype = CNNL_DTYPE_INT4X2;
  }

  TORCH_CHECK(m_list.dtype() == torch::kInt32, "data type of m_list should be int32");
  if (has_gather_idx) {
    TORCH_CHECK(gather_idx.value().dtype() == torch::kInt32,
                "data type of gather_idx should be int32");
  }

  // With a_scale present the output dtype is taken from data_type instead of a_tensor.
  if (has_a_scale) {
    TORCH_CHECK(has_dtype, "data_type must given when a_scale and b_scale are given");
    TORCH_CHECK(data_type.value() == "float" || data_type.value() == "half" ||
                    data_type.value() == "bfloat16",
                "data_type must be 'float', 'half' or 'bfloat16'.");
    dtype = data_type.value() == "float"  ? torch::kFloat32
            : data_type.value() == "half" ? torch::kFloat16
                                          : torch::kBFloat16;
  }

  const torch_mlu::mlu::MLUGuard device_guard(a_tensor.device());
  at::Tensor d_tensor = at::empty({total_m, n}, a_tensor.options().dtype(dtype));
  cnnlDataType_t d_dtype = getCnnlDataType(d_tensor.scalar_type());

  // check device
  TORCH_CHECK(isMlu(m_list), "m_list must on mlu");
  TORCH_CHECK(isMlu(a_tensor), "a_tensor must on mlu");
  TORCH_CHECK(isMlu(b_tensor), "b_tensor must on mlu");
  TORCH_CHECK(isMlu(d_tensor), "d_tensor must on mlu");
  if (has_c_tensor) {
    TORCH_CHECK(isMlu(c_tensor.value()), "c_tensor must on mlu");
  }
  if (has_gather_idx) {
    TORCH_CHECK(isMlu(gather_idx.value()), "gather_idx must on mlu");
  }
  if (has_a_scale) {
    TORCH_CHECK(isMlu(a_scale.value()), "a_scale must on mlu");
  }
  if (has_b_scale) {
    TORCH_CHECK(isMlu(b_scale.value()), "b_scale must on mlu");
  }
  if (has_alpha) {
    TORCH_CHECK(isMlu(alpha.value()), "alpha must on mlu");
  }
  if (has_beta) {
    TORCH_CHECK(isMlu(beta.value()), "beta must on mlu");
  }

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  GroupGemmDesc group_gemm_desc(experts_num, max_m, n, k, d_dtype, has_gather_idx, quant_mode);
  group_gemm_desc.setInputOutputTensor(a_dtype, b_dtype, d_dtype, CNNL_DTYPE_INT32,
                                       getAtTensorPtr(a_tensor), getAtTensorPtr(b_tensor),
                                       getAtTensorPtr(c_tensor), getAtTensorPtr(d_tensor),
                                       getAtTensorPtr(gather_idx), k, lda, total_m, has_c_tensor,
                                       (int64_t *)getAtTensorPtr(b_offset));
  // quant_flag entries narrowed to int; stays empty when quant_flag is absent.
  std::vector<int> flag_vec;
  if (quant_mode != QuantMode::noQuant || has_bias) {
    if (has_quant_flag) {
      auto vec = quant_flag.value().vec();
      flag_vec.resize(vec.size());
      std::copy(vec.begin(), vec.end(), flag_vec.begin());
    }
    group_gemm_desc.setPerRowColScaleBiasAct(
        has_a_scale ? getAtTensorPtr(a_scale) : nullptr,
        has_b_scale ? getAtTensorPtr(b_scale) : nullptr, flag_vec.data(),
        has_bias ? getAtTensorPtr(bias) : nullptr, d_dtype, quant_grouped ? k : 0,
        quant_grouped ? k / b_scale.value().size(0) : 0, 0);
  }

  size_t group_gemm_wsize =
      tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc, experts_num);
  auto group_gemm_workspace =
      at::empty({static_cast<int64_t>(group_gemm_wsize)}, a_tensor.options().dtype(at::kByte));

  // ldb_array is left empty for W4W8; the per-expert computation below is disabled.
  std::vector<int> ldb_array;
  // if (quant_mode == QuantMode::W4W8) {
  //   ldb_array.resize(experts_num);
  //   auto q_group = b_scale.value().size(0);
  //   for (auto i = 0; i < experts_num; ++i) {
  //     int sum = 0;
  //     for (auto j = 0; j < q_group; j++) {
  //       sum += flag_vec[i * q_group + j];
  //     }
  //     ldb_array[i] = (sum / 4) * (k / q_group / 2);
  //   }
  // }
  if (quant_mode != QuantMode::W4W8) {
    int ldb = b_tensor.dim() == 3 ? k : b_shape[2] * k;
    ldb_array.assign(experts_num, ldb);
  }
  GroupGemmTheory obj(total_m, experts_num, k, n, has_c_tensor, a_dtype, getCnnlDataType(dtype));
  cnpxPush(obj);
  ops::GroupGemm(handle, group_gemm_desc, getAtTensorPtr(m_list), getAtTensorPtr(alpha),
                 getAtTensorPtr(beta), getAtTensorPtr(group_gemm_workspace), group_gemm_wsize,
                 experts_num, k, n, lda, ldb_array);
  cnpxPop();
  return d_tensor;
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
129
torch_mlu_ops-v1.3.2/csrc/torch_api/matmul.cpp
Normal file
129
torch_mlu_ops-v1.3.2/csrc/torch_api/matmul.cpp
Normal file
@@ -0,0 +1,129 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "torch_ops_api.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Matrix multiply wrapper around cnnlMatMulEx.
//
//   a, b      : operands; `trans_a` / `trans_b` select which dim is m/k and k/n.
//               Both must have a contiguous last dim and the same dtype.
//   d         : optional preallocated output; when absent an [m, n] tensor is created.
//   bias      : optional contiguous 1-D tensor, passed on as a [1, N] view.
//   c         : optional contiguous [m, n] tensor; its presence enables use_beta.
//   dtype     : optional output dtype name ('float', 'half', 'bfloat16'), only consulted
//               when `d` is absent. Either `d` or `dtype` is required for int8 inputs.
//   act_mode, fast_act, approximate : forwarded to MatMulDesc.
//   alpha, beta : narrowed to float and forwarded to cnnlMatMulEx.
//   a_scale, b_scale : only used for int8 inputs to derive position/scale.
//
// Returns `d` when it was supplied, otherwise the newly allocated output.
at::Tensor matmul(const at::Tensor &a,
                  const at::Tensor &b,
                  const c10::optional<at::Tensor> &d,
                  const c10::optional<at::Tensor> &bias,
                  const c10::optional<at::Tensor> &c,
                  const c10::optional<std::string> &dtype,
                  const std::string &act_mode,
                  double alpha,
                  double beta,
                  bool fast_act,
                  bool approximate,
                  double a_scale,
                  double b_scale,
                  bool trans_a,
                  bool trans_b) {
  const bool with_bias = bias.has_value();
  const bool has_res = c.has_value();
  // use_beta is enabled exactly when the optional tensor c is supplied.
  const bool use_beta = has_res;

  // Problem sizes, honouring the transpose flags.
  int m = a.size(trans_a ? 1 : 0);
  int n = b.size(trans_b ? 0 : 1);
  int k = a.size(trans_a ? 0 : 1);

  // check device and dtype
  checkTensorSameAttr<TensorAttr::DEVICE>(a, b, c, bias);

  // check contiguous
  TORCH_CHECK(a.stride(-1) == 1, "a last dim must be contiguous.");
  TORCH_CHECK(b.stride(-1) == 1, "b last dim must be contiguous.");
  if (with_bias) {
    TORCH_CHECK(bias.value().is_contiguous(), "bias must be contiguous.");
  }
  if (has_res) {
    TORCH_CHECK(c.value().is_contiguous(), "c must be contiguous.");
    CHECK_SHAPE(c.value(), m, n);
  }

  // Bias goes to the kernel as a 2-D view; it stays undefined when no bias was given.
  at::Tensor bias_2d;
  if (with_bias) {
    TORCH_CHECK(bias.value().dim() == 1, "bias must be 1-D tensor.");
    bias_2d = bias.value().unsqueeze(0);
  }

  // get cnnl data type and init output
  auto a_dtype = getCnnlDataType(a.scalar_type());
  auto b_dtype = getCnnlDataType(b.scalar_type());
  TORCH_CHECK(a_dtype == b_dtype, "a, b must be same dtype.");
  const bool fixed_point_input = (a_dtype == CNNL_DTYPE_INT8);
  if (fixed_point_input) {
    TORCH_CHECK(d.has_value() || dtype.has_value(),
                "d and d_dtype cant be none at the same time when input dtype is int8");
  }

  at::Tensor output;
  if (!d.has_value()) {
    // No output supplied: allocate one, in the requested dtype if any, else a's dtype.
    torch::Dtype out_dtype = a.scalar_type();
    if (dtype.has_value()) {
      const std::string &requested = dtype.value();
      TORCH_CHECK(requested == "float" || requested == "half" || requested == "bfloat16",
                  "data_type must be 'float', 'half' or 'bfloat16'.");
      if (requested == "float") {
        out_dtype = torch::kFloat32;
      } else if (requested == "half") {
        out_dtype = torch::kFloat16;
      } else {
        out_dtype = torch::kBFloat16;
      }
    }
    output = at::empty({m, n}, a.options().dtype(out_dtype));
  } else {
    output = d.value();
  }

  // On-chip dtype for c/d follows the output, except bfloat16 inputs which use float.
  auto cd_dtype = getCnnlDataType(output.scalar_type());
  if (a_dtype == CNNL_DTYPE_BFLOAT16) {
    cd_dtype = CNNL_DTYPE_FLOAT;
  }

  // create tensor desc: [0]=a, [1]=b, [2]=bias view, [3]=c, [4]=output
  auto descs = createTensorDescs({a, b, bias_2d, c.value_or(at::Tensor()), output});
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), a_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), b_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[4].get(), cd_dtype));
  if (has_res) {
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[3].get(), cd_dtype));
  }
  if (fixed_point_input) {
    const float int_max = 127.0;
    const int quant_bit = 8;
    // Position derived from the largest magnitude representable at the given scale.
    const auto position_of = [quant_bit](float max_value) -> int {
      return std::floor(std::log2(max_value) - (quant_bit - 2));
    };
    float max_a = int_max / a_scale;
    float max_b = int_max / b_scale;
    int pos_a = position_of(max_a);
    int pos_b = position_of(max_b);
    float new_a_scale = std::pow(2.0f, pos_a) * a_scale;
    float new_b_scale = std::pow(2.0f, pos_b) * b_scale;
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[0].get(), pos_a, new_a_scale));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[1].get(), pos_b, new_b_scale));
    TORCH_CHECK(cd_dtype != CNNL_DTYPE_BFLOAT16,
                "output dtype cannot be bfloat16 when a/b is fixed-point");
  }

  // get && set Op desc
  int lda = a.stride(0);
  int ldb = b.stride(0);
  int ldc = output.stride(0);
  size_t workspace_size = 0;
  auto matmul_desc =
      tmo::op_desc::MatMulDesc(descs[2].get(), getAtTensorPtr(bias_2d), workspace_size, lda, ldb,
                               ldc, act_mode, use_beta, trans_a, trans_b, fast_act, approximate);
  const torch_mlu::mlu::MLUGuard device_guard(a.device());
  auto handle = torch_mlu::getCurrentHandle();
  // Byte buffer sized by workspace_size as it stands after the descriptor was built.
  auto workspace = at::empty({static_cast<int64_t>(workspace_size)}, a.options().dtype(at::kByte));

  // run forward
  float alpha_f = alpha;
  float beta_f = beta;
  MatmulTheory obj(m, k, n, has_res, a_dtype, cd_dtype);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlMatMulEx(handle, matmul_desc, &alpha_f, descs[0].get(), getAtTensorPtr(a),
                                descs[1].get(), getAtTensorPtr(b), &beta_f, descs[3].get(),
                                getAtTensorPtr(c), descs[4].get(), getAtTensorPtr(output), nullptr,
                                getAtTensorPtr(workspace), workspace_size));
  cnpxPop();
  return output;
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
52
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_cast_gating.cpp
Normal file
52
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_cast_gating.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include <numeric>
|
||||
#include "kernels/moe/cast_gating.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Runs the cast-gating kernel for MoE routing.
//
//   input  : contiguous [..., hidden_size] tensor of float16 or bfloat16.
//   weight : contiguous 2-D float32 tensor [expert_num, hidden_size].
//
// hidden_size must be in (0, 16384] and expert_num in (0, 128].
// Returns a new tensor with weight's options and shape [..., expert_num].
at::Tensor moe_cast_gating(const at::Tensor &input, const at::Tensor &weight) {
  // check
  checkTensorSameAttr<TensorAttr::DEVICE>(input, weight);
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
  TORCH_CHECK(weight.dim() == 2, "weight.dim() == 2");
  TORCH_CHECK(input.size(-1) == weight.size(-1), "input.size(-1) == weight.size(-1)");
  TORCH_CHECK(input.scalar_type() == torch::kFloat16 || input.scalar_type() == torch::kBFloat16,
              "input type need be torch::kFloat16 or torch::kBFloat16");
  TORCH_CHECK(weight.scalar_type() == torch::kFloat32, "weight type need be torch::kFloat32");

  auto hidden_size = input.size(-1);
  auto expert_num = weight.size(0);
  TORCH_CHECK(hidden_size > 0 && hidden_size <= 16384, "hidden_size > 0 && hidden_size <= 16384");
  TORCH_CHECK(expert_num > 0 && expert_num <= 128, "expert_num > 0 && expert_num <= 128");
  auto input_shape = input.sizes();  // [..., hidden_size]

  // Number of rows = product of all dims except the last. Accumulate in 64 bits:
  // an int accumulator would silently truncate or overflow the int64 dimensions.
  const int64_t total_rows = std::accumulate(input_shape.begin(), input_shape.end() - 1,
                                             int64_t{1}, std::multiplies<int64_t>());
  // The kernel call receives the row count as int, so reject values that do not round-trip.
  const int input_row = static_cast<int>(total_rows);
  TORCH_CHECK(static_cast<int64_t>(input_row) == total_rows,
              "the number of input rows must fit in int32, but got ", total_rows);

  std::vector<int64_t> output_shape(input_shape.begin(), input_shape.end() - 1);
  output_shape.push_back(expert_num);

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  auto output = at::empty(output_shape, weight.options());
  // Fixed 16 MiB byte buffer handed to the kernel as workspace.
  const int64_t workspace_size = 16 * 1024 * 1024;
  auto workspace = at::empty({workspace_size}, weight.options().dtype(at::kByte));
  auto queue = torch_mlu::getCurMLUStream();
  auto input_dtype = getCnnlDataType(input.scalar_type());
  TMO_KERNEL_CHECK_FATAL(tmo::invokeCastGating(
      queue, input.data_ptr(), weight.data_ptr(), output.data_ptr(), input_row, expert_num,
      hidden_size, input_dtype, workspace.data_ptr(), workspace_size));
  return output;
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
124
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_combine_result.cpp
Normal file
124
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_combine_result.cpp
Normal file
@@ -0,0 +1,124 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "kernels/moe/combine_result.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
/**
 * Torch entry point for the MoE combine-result kernel.
 *
 * Validates the arguments, allocates an output tensor of shape
 * [input.size(0) / topk, hidden_size] with the options of `input`, and launches
 * invokeMoeCombineResultKernel on the current MLU stream.
 *
 * @param input             2-D tensor [num_tokens, hidden_size]; float32/float16/bfloat16.
 * @param reduce_weight     2-D float32 tensor whose numel equals input.size(0); its second
 *                          dimension is used as topk.
 * @param gather_ids        1-D int32 tensor with input.size(0) elements.
 * @param residual          Optional 2-D tensor [input.size(0) / topk, hidden_size] with the
 *                          same dtype as input.
 * @param cusum_token_count Optional 1-D int32 tensor; when given, the expert count is taken
 *                          as size(0) - 1.
 * @param start_expert_id   Must be >= 0.
 * @param expert_size       Forwarded to the kernel; also used as the expert count when
 *                          cusum_token_count is absent.
 * @param bias              Must be empty: bias is rejected up front as unsupported.
 * @return The newly allocated output tensor.
 */
at::Tensor moe_combine_result(const at::Tensor &input,
                              const at::Tensor &reduce_weight,
                              const at::Tensor &gather_ids,
                              const c10::optional<at::Tensor> &residual,
                              const c10::optional<at::Tensor> &cusum_token_count,
                              int64_t start_expert_id,
                              int64_t expert_size,
                              const c10::optional<at::Tensor> &bias) {
  // check device and dtype
  checkTensorSameAttr<TensorAttr::DEVICE>(input, reduce_weight, residual, bias, cusum_token_count,
                                          gather_ids);

  // check contiguous
  CHECK_TENSOR_CONTIGUOUS(input)
  CHECK_TENSOR_CONTIGUOUS(reduce_weight)
  CHECK_TENSOR_CONTIGUOUS(gather_ids)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(residual)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(bias)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(cusum_token_count)

  // do not support bias
  TORCH_CHECK(!bias.has_value(), "bias is not supported.");

  // check input
  TORCH_CHECK(input.dim() == 2, "input should be with shape [num_tokens, hidden_size].");
  TORCH_CHECK(input.dtype() == torch::kFloat32 || input.dtype() == torch::kFloat16 ||
                  input.dtype() == torch::kBFloat16,
              "input dtype should be float32/float16/bfloat16.");

  int num_tokens = input.size(0);
  int hidden_size = input.size(1);

  // check reduce weight
  TORCH_CHECK(reduce_weight.dim() == 2, "reduce_weight should be with shape [num_token, topk].");
  TORCH_CHECK(reduce_weight.dtype() == torch::kFloat32,
              "reduce_weight should be with dtype float32.");
  int topk = reduce_weight.size(1);
  // topk is the divisor below; an empty [0, 0] reduce_weight would otherwise pass the
  // numel check and divide by zero.
  TORCH_CHECK(topk > 0, "reduce_weight.size(1) (topk) should be larger than 0.");
  TORCH_CHECK(reduce_weight.numel() == num_tokens,
              "reduce_weight.numel() should equal to input.size(0).");
  int num_token = num_tokens / topk;

  // check gather_ids
  TORCH_CHECK(gather_ids.dim() == 1, "gather_ids should be with shape [num_tokens].");
  TORCH_CHECK(gather_ids.dtype() == torch::kInt32, "gather_ids should be with dtype int32.");
  TORCH_CHECK(gather_ids.size(0) == num_tokens,
              "gather_ids.numel() should equal to input.size(0).");

  if (residual.has_value()) {
    TORCH_CHECK(residual.value().dim() == 2,
                "residual should be with shape [num_token, hidden_size].");
    TORCH_CHECK(residual.value().dtype() == input.dtype(),
                "residual should have same dtype with input.");
    // The residual has one row per combined token, i.e. input.size(0) / topk rows.
    TORCH_CHECK(residual.value().size(0) == num_token,
                "residual.size(0) should equal to input.size(0) / topk.");
    TORCH_CHECK(residual.value().size(1) == hidden_size,
                "residual.size(1) should equal to input.size(1).");
  }

  // check bias and cusum_token_count
  bool has_bias = bias.has_value();
  bool has_cusum_token_count = cusum_token_count.has_value();
  TORCH_CHECK((!has_bias || (has_bias && has_cusum_token_count)),
              "if bias is not None, cusum_token_count should not be None.");

  int num_expert = -1;
  if (has_bias) {
    TORCH_CHECK(bias.value().dim() == 2, "bias should be with shape [num_expert, hidden_size].");
    TORCH_CHECK(bias.value().dtype() == input.dtype(), "bias should have same dtype with input.");
    TORCH_CHECK(bias.value().size(1) == hidden_size, "bias.size(1) should equal to input.size(1).");
    num_expert = bias.value().size(0);
  }

  if (has_cusum_token_count) {
    TORCH_CHECK(cusum_token_count.value().dim() == 1,
                "cusum_token_count should be with shape [num_expert + 1].");
    TORCH_CHECK(cusum_token_count.value().dtype() == torch::kInt32,
                "cusum_token_count should be with dtype int32.");
    if (num_expert > 0) {
      // num_expert was taken from bias: the cumulative-sum length has to agree with it.
      TORCH_CHECK(cusum_token_count.value().size(0) == num_expert + 1,
                  "cusum_token_count should be with shape [num_expert + 1].");
    } else {
      num_expert = cusum_token_count.value().size(0) - 1;
    }
  }

  // check expert
  TORCH_CHECK(start_expert_id >= 0, "start_expert_id shoule be larger or equal to 0.");
  TORCH_CHECK(num_expert == -1 || num_expert >= (start_expert_id + expert_size),
              "num_expert shape shoule be larger or equal to start_expert_id + expert_size.");
  if (num_expert == -1) {
    // Neither bias nor cusum_token_count supplied an expert count.
    num_expert = expert_size;
  }

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  auto queue = torch_mlu::getCurMLUStream();
  auto output_shape = std::vector<int64_t>({num_token, hidden_size});
  auto output = at::empty(output_shape, input.options());
  auto dtype = getCnnlDataType(input.scalar_type());

  invokeMoeCombineResultKernel(queue, output.data_ptr(), input.data_ptr(), getAtTensorPtr(bias),
                               getAtTensorPtr(residual), (float *)reduce_weight.data_ptr(),
                               (int *)getAtTensorPtr(cusum_token_count),
                               (int *)getAtTensorPtr(gather_ids), num_token, topk, num_expert,
                               hidden_size, start_expert_id, expert_size, dtype);
  return output;
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
60
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_expand_input.cpp
Normal file
60
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_expand_input.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/moe/expand_input.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
/**
 * Torch entry point for the MoE expand-input kernel.
 *
 * Validates the arguments, allocates an output of shape
 * [gather_idx.size(0), input.size(1)] with the options of `input`, and launches
 * invokeMoeExpandInputKernel on the current MLU stream.
 *
 * @param input             2-D MLU tensor [token_num, hidden_size]; token_num must be > 0.
 * @param gather_idx        1-D int32 MLU tensor whose length is a multiple of token_num;
 *                          the ratio is used as topk.
 * @param cusum_token_count Optional int32 MLU tensor; when given, expert_num is taken as
 *                          size(0) - 1 and the expert window is range-checked against it.
 * @param start_expert_id   First expert of the window (validated only with cusum_token_count).
 * @param expert_size       Width of the expert window (validated only with cusum_token_count).
 * @return The newly allocated expanded tensor.
 */
at::Tensor moe_expand_input(const torch::Tensor &input,
                            const torch::Tensor &gather_idx,
                            const c10::optional<torch::Tensor> &cusum_token_count,
                            int64_t start_expert_id,
                            int64_t expert_size) {
  TORCH_CHECK(input.dim() == 2, "input dim must be equal to 2.");
  TORCH_CHECK(gather_idx.dim() == 1, "gather_idx dim must be equal to 1.");
  // check device and dtype
  TORCH_CHECK(isMlu(input), "input tensor must on mlu.");
  TORCH_CHECK(isMlu(gather_idx), "gather_idx must on mlu.");
  TORCH_CHECK(gather_idx.dtype() == torch::kInt32, "data type of gather_idx must be int32.");

  const int64_t expand_token_num = gather_idx.size(0);
  const int64_t token_num = input.size(0);
  const int64_t hidden_size = input.size(1);
  // token_num is the divisor of the modulo/division below; reject an empty input
  // explicitly instead of hitting an integer division by zero.
  TORCH_CHECK(token_num > 0, "token_num must be greater than 0.");
  TORCH_CHECK(expand_token_num % token_num == 0, "expand_token_num % token_num == 0.");
  const int64_t topk = expand_token_num / token_num;

  int64_t expert_num = 0;
  if (cusum_token_count.has_value()) {
    const auto &cusum = cusum_token_count.value();
    TORCH_CHECK(isMlu(cusum), "cusum_token_count must on mlu.");
    TORCH_CHECK(cusum.dtype() == torch::kInt32, "data type of cusum_token_count must be int32.");
    expert_num = cusum.size(0) - 1;
    TORCH_CHECK(start_expert_id >= 0 && start_expert_id < expert_num,
                "start_expert_id >=0 && start_expert_id < expert_num.");
    TORCH_CHECK((start_expert_id + expert_size) <= expert_num,
                "start_expert_id + expert_size <= expert_num.");
  }

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  std::vector<int64_t> output_shape = {expand_token_num, hidden_size};
  auto output = torch::empty(output_shape, input.options());

  auto dtype = getCnnlDataType(input.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();

  tmo::invokeMoeExpandInputKernel(queue, getAtTensorPtr(output), getAtTensorPtr(input),
                                  (int *)getAtTensorPtr(gather_idx),
                                  (int *)getAtTensorPtr(cusum_token_count), token_num, hidden_size,
                                  topk, dtype, expert_num, start_expert_id, expert_size);
  return output;
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
44
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_gen_idx.cpp
Normal file
44
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_gen_idx.cpp
Normal file
@@ -0,0 +1,44 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/moe/gen_idx.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
/**
 * Torch entry point for the MoE index-generation kernel.
 *
 * Allocates the four int32 result tensors plus a scratch buffer (all with the options of
 * `expert_id`) and launches invokeMoeGenIdxKernel on the current MLU stream.
 *
 * @param expert_id  2-D contiguous int32 MLU tensor [token_num, topk].
 * @param expert_num Number of experts; must be > 0.
 * @return {expand_idx [token_num * topk], combine_idx [token_num * topk],
 *          token_count [expert_num], cusum_token_count [expert_num + 1]}.
 */
std::vector<at::Tensor> moe_gen_idx(const torch::Tensor &expert_id, int64_t expert_num) {
  TORCH_CHECK(expert_id.dim() == 2, "expert_id dim must be equal to 2.");
  const int64_t token_num = expert_id.size(0);
  const int64_t topk = expert_id.size(1);

  // check device and dtype
  TORCH_CHECK(isMlu(expert_id), "expert_id must on mlu.");
  TORCH_CHECK(expert_id.dtype() == torch::kInt32, "data type of expert_id must be int32.");
  // Only the raw data pointer and the two sizes are handed to the kernel, so a strided
  // view cannot be described to it.
  TORCH_CHECK(expert_id.is_contiguous(), "expert_id must be contiguous.");
  TORCH_CHECK(expert_num > 0, "expert_num must be greater than 0.");

  const torch_mlu::mlu::MLUGuard device_guard(expert_id.device());
  const auto options = expert_id.options();
  const int64_t expanded_num = token_num * topk;
  auto expand_idx = at::empty({expanded_num}, options);
  auto combine_idx = at::empty({expanded_num}, options);
  auto token_count = at::empty({expert_num}, options);
  auto cusum_token_count = at::empty({expert_num + 1}, options);
  // Scratch space handed to the kernel as its workspace argument.
  auto gen_idx_workspace = at::empty({expert_num + 1 + expanded_num}, options);

  auto queue = torch_mlu::getCurMLUStream();
  tmo::invokeMoeGenIdxKernel(
      queue, (int *)getAtTensorPtr(expand_idx), (int *)getAtTensorPtr(combine_idx),
      (int *)getAtTensorPtr(token_count), (int *)getAtTensorPtr(cusum_token_count),
      getAtTensorPtr(gen_idx_workspace), getAtTensorPtr(expert_id), token_num, expert_num, topk);
  std::vector<at::Tensor> output = {expand_idx, combine_idx, token_count, cusum_token_count};
  return output;
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
83
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_softmax_topk.cpp
Normal file
83
torch_mlu_ops-v1.3.2/csrc/torch_api/moe_softmax_topk.cpp
Normal file
@@ -0,0 +1,83 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/moe/softmax_topk.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
/**
 * Torch entry point for the MoE softmax + top-k kernel.
 *
 * Validates the arguments, allocates `reduce_weight` (float32) and `expert_id` (int32) with
 * the shape of `input` except that the last dimension is `topk`, and launches
 * invokeMoeSoftmaxTopkKernel on the current MLU stream.
 *
 * @param input            Contiguous tensor with at least 2 dims; the last dim is num_expert.
 * @param topk             Must satisfy 0 < topk <= num_expert.
 * @param num_expert_group When > 1, grouped selection is validated: no mask allowed,
 *                         num_expert divisible by it, and topk bounded by the selected groups.
 * @param topk_group       Number of groups kept; validated only when num_expert_group > 1.
 * @param normalize        When true, `normed_by` selects normalize_mode 1 ("topk_logit")
 *                         or 2 ("softmax_logit"); otherwise mode 0 is passed.
 * @param mask             Optional contiguous tensor with the same rank and dtype as input,
 *                         matching its last two dims, and with all leading dims equal to 1.
 * @param normed_by        "topk_logit" or "softmax_logit" (checked only when normalize).
 * @return {reduce_weight, expert_id}.
 */
std::vector<at::Tensor> moe_softmax_topk(const at::Tensor &input,  // (num_token, num_expert)
                                         int64_t topk,
                                         int64_t num_expert_group,
                                         int64_t topk_group,
                                         bool normalize,
                                         const c10::optional<at::Tensor> &mask,
                                         const std::string &normed_by) {
  CHECK_TENSOR_CONTIGUOUS(input)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(mask)
  int normalize_mode = 0;
  if (normalize) {
    TORCH_CHECK(normed_by == "topk_logit" || normed_by == "softmax_logit",
                "normed_by must be 'topk_logit' or 'softmax_logit'");
    if (normed_by == "topk_logit") {
      normalize_mode = 1;
    } else if (normed_by == "softmax_logit") {
      normalize_mode = 2;
    }
  }
  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2");
  int num_expert = input.size(-1);
  // Validate topk before num_expert is used as a divisor: the check implies
  // num_expert >= 1, so an input with an empty last dimension is rejected cleanly.
  TORCH_CHECK(topk > 0 && topk <= num_expert, "topk > 0 && topk <= num_expert");
  int num_token = (int)(input.numel() / num_expert);
  int num_mask = 1;
  auto input_shape = input.sizes();

  bool has_mask = mask.has_value();
  if (has_mask) {
    const auto &mask_tensor = mask.value();
    TORCH_CHECK(mask_tensor.dim() == input.dim(), "the dim of mask should be the same as input");
    TORCH_CHECK(mask_tensor.size(-1) == num_expert,
                "the last dim of mask should be the same as the last dim of input");
    TORCH_CHECK(mask_tensor.size(-2) == input.size(-2),
                "the penultimate dim of mask should be the same as the penultimate dim of input");
    // Product of every dimension above the lowest two, computed without a division so
    // that a zero-sized penultimate dimension cannot trigger a divide-by-zero.
    int64_t high_dim = 1;
    for (int64_t axis = 0; axis + 2 < mask_tensor.dim(); ++axis) {
      high_dim *= mask_tensor.size(axis);
    }
    TORCH_CHECK(high_dim == 1, "the product of all but the lower two dimensions of mask is 1");
    num_mask = mask_tensor.numel() / num_expert;
    TORCH_CHECK(mask_tensor.dtype() == input.dtype(),
                "the dtype of mask should be the same as input");
  }
  if (num_expert_group > 1) {
    TORCH_CHECK(has_mask == false, "if num_expert_group > 1, mask should be None");
    TORCH_CHECK(num_expert % num_expert_group == 0, "num_expert % num_expert_group == 0");
    TORCH_CHECK(topk_group > 0 && topk_group <= num_expert_group,
                "topk_group > 0 && topk_group <= num_expert_group");
    TORCH_CHECK(topk <= (num_expert / num_expert_group) * topk_group,
                "topk <= (num_expert / num_expert_group) * topk_group");
  }
  std::vector<int64_t> out_shape(input_shape.begin(), input_shape.end());
  out_shape.back() = topk;

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  auto tensor_options = input.options();
  auto reduce_weight = at::empty(out_shape, tensor_options.dtype(torch::kFloat));
  auto expert_id = at::empty(out_shape, tensor_options.dtype(torch::kInt32));

  auto queue = torch_mlu::getCurMLUStream();
  auto input_dtype = getCnnlDataType(input.scalar_type());
  TMO_KERNEL_CHECK_FATAL(invokeMoeSoftmaxTopkKernel(
      queue, (float *)reduce_weight.data_ptr(), (int *)expert_id.data_ptr(), input.data_ptr(),
      getAtTensorPtr(mask), num_token, num_expert, num_mask, topk, num_expert_group, topk_group,
      input_dtype, normalize_mode));
  return {reduce_weight, expert_id};
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,177 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/offline_quant_to_linear_cache.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
void offline_quant_to_linear_cache(
|
||||
const at::Tensor &key, // [total_seqlen, num_heads, head_size]
|
||||
// or [bs, max_context_len, num_heads, head_size]
|
||||
const c10::optional<at::Tensor> &value,
|
||||
at::Tensor &key_cache, // [max_bs, num_heads, cache_memory_len, head_size]
|
||||
const c10::optional<at::Tensor> &value_cache,
|
||||
const at::Tensor &key_cache_scale, // [num_heads, cache_memory_len] or [num_heads, head_size]
|
||||
const c10::optional<at::Tensor> &value_cache_scale,
|
||||
const at::Tensor &context_lengths,
|
||||
const int64_t max_context_len,
|
||||
const int64_t quant_mode, // 0:per_channel, others:per_head
|
||||
const bool packed,
|
||||
const c10::optional<at::Tensor> &context_seq_offset,
|
||||
const c10::optional<at::Tensor> &cache_bs_id,
|
||||
const c10::optional<at::Tensor> &cache_seqlen_offset) {
|
||||
/****************************************check key***************************************/
|
||||
// check dtype
|
||||
TORCH_CHECK(key.scalar_type() == torch::kFloat32 || key.scalar_type() == torch::kFloat16 ||
|
||||
key.scalar_type() == torch::kBFloat16,
|
||||
"key type need be torch::kFloat32, torch::kFloat16 or torch::kBFloat16");
|
||||
TORCH_CHECK(key_cache.scalar_type() == torch::kInt8, "key_cache type need be torch::kInt8");
|
||||
TORCH_CHECK(key_cache_scale.scalar_type() == torch::kFloat32,
|
||||
"key_cache_scale type need be torch::kFloat32");
|
||||
|
||||
// check shape
|
||||
const int batch_size = packed ? context_lengths.size(0) - 1 : context_lengths.size(0);
|
||||
const int max_bs = key_cache.size(0);
|
||||
const int num_heads = key_cache.size(1);
|
||||
const int cache_mem_len = key_cache.size(2);
|
||||
const int head_size = key_cache.size(3);
|
||||
const int total_seqlen = packed ? key.size(0) : batch_size * key.size(-2);
|
||||
|
||||
TORCH_CHECK(key_cache.size(0) >= batch_size,
|
||||
"first dim of key_cache must be great than or equal to batch_size");
|
||||
if (packed) {
|
||||
CHECK_SHAPE(key, total_seqlen, num_heads, head_size);
|
||||
CHECK_SHAPE(context_lengths, batch_size + 1);
|
||||
} else {
|
||||
CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size);
|
||||
CHECK_SHAPE(context_lengths, batch_size);
|
||||
}
|
||||
if (quant_mode == 0) {
|
||||
CHECK_SHAPE(key_cache_scale, num_heads, head_size);
|
||||
} else {
|
||||
CHECK_SHAPE(key_cache_scale, num_heads, cache_mem_len);
|
||||
}
|
||||
|
||||
// check contiguous
|
||||
TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous.");
|
||||
TORCH_CHECK(key_cache_scale.is_contiguous(), "key_cache_scale tensor must be contiguous.");
|
||||
|
||||
// check device
|
||||
const int64_t device_id = key.get_device();
|
||||
TORCH_CHECK(key_cache.get_device() == device_id,
|
||||
"key_cache tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", key_cache.get_device());
|
||||
TORCH_CHECK(key_cache_scale.get_device() == device_id,
|
||||
"key_cache_scale tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", key_cache_scale.get_device());
|
||||
|
||||
/***************************************check value***************************************/
|
||||
if (value.has_value() || value_cache.has_value() || value_cache_scale.has_value()) {
|
||||
TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_scale.has_value(),
|
||||
"value_cache, value_cache_scale, value must all have value.")
|
||||
}
|
||||
if (value_cache.has_value()) {
|
||||
TORCH_CHECK(value_cache.value().scalar_type() == torch::kInt8,
|
||||
"value_cache type need be torch::kInt8");
|
||||
TORCH_CHECK(value_cache_scale.value().scalar_type() == torch::kFloat32,
|
||||
"value_cache_scale type need be torch::kFloat32");
|
||||
TORCH_CHECK(value.value().scalar_type() == key.scalar_type(),
|
||||
"value type need be same with key");
|
||||
if (packed) {
|
||||
CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size);
|
||||
} else {
|
||||
CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size);
|
||||
}
|
||||
if (quant_mode == 0) {
|
||||
CHECK_SHAPE(value_cache_scale.value(), num_heads, head_size);
|
||||
} else {
|
||||
CHECK_SHAPE(value_cache_scale.value(), num_heads, cache_mem_len);
|
||||
}
|
||||
CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size);
|
||||
|
||||
for (int i = 0; i < key.dim(); i++) {
|
||||
TORCH_CHECK(value.value().stride(i) == key.stride(i),
|
||||
"key and value must have same stride along axi ", i);
|
||||
}
|
||||
TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous.");
|
||||
TORCH_CHECK(value_cache_scale.value().is_contiguous(),
|
||||
"value_cache_scale tensor must be contiguous.");
|
||||
TORCH_CHECK(value_cache.value().get_device() == device_id,
|
||||
"value_cache tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", value_cache.value().get_device());
|
||||
TORCH_CHECK(value_cache_scale.value().get_device() == device_id,
|
||||
"value_cache_scale tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", value_cache_scale.value().get_device());
|
||||
}
|
||||
|
||||
TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32,
|
||||
"context_lengths type need be torch::kInt32");
|
||||
/*********************************check optional tensor***********************************/
|
||||
const void *context_seq_offset_ptr = nullptr;
|
||||
if (context_seq_offset.has_value()) {
|
||||
CHECK_SHAPE(context_seq_offset.value(), batch_size);
|
||||
TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32,
|
||||
"context_seq_offset type need be torch::kInt32");
|
||||
TORCH_CHECK(context_seq_offset.value().get_device() == device_id,
|
||||
"context_seq_offset tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", context_seq_offset.value().get_device());
|
||||
context_seq_offset_ptr = context_seq_offset.value().data_ptr();
|
||||
}
|
||||
|
||||
const void *cache_bs_id_ptr = nullptr;
|
||||
if (cache_bs_id.has_value()) {
|
||||
CHECK_SHAPE(cache_bs_id.value(), batch_size);
|
||||
TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32,
|
||||
"cache_bs_id type need be torch::kInt32");
|
||||
TORCH_CHECK(cache_bs_id.value().get_device() == device_id,
|
||||
"cache_bs_id tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", cache_bs_id.value().get_device());
|
||||
cache_bs_id_ptr = cache_bs_id.value().data_ptr();
|
||||
}
|
||||
|
||||
const void *cache_seqlen_offset_ptr = nullptr;
|
||||
if (cache_seqlen_offset.has_value()) {
|
||||
CHECK_SHAPE(cache_seqlen_offset.value(), batch_size);
|
||||
TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32,
|
||||
"cache_seqlen_offset type need be torch::kInt32");
|
||||
TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id,
|
||||
"cache_seqlen_offset tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", cache_seqlen_offset.value().get_device());
|
||||
cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr();
|
||||
}
|
||||
|
||||
const size_t context_bs_stride = packed ? 0 : key.stride(0);
|
||||
const size_t context_head_stride = packed ? key.stride(1) : key.stride(2);
|
||||
const size_t context_seq_stride = packed ? key.stride(0) : key.stride(1);
|
||||
const size_t cache_bs_stride = key_cache.stride(0);
|
||||
const size_t cache_head_stride = key_cache.stride(1);
|
||||
const size_t cache_seq_stride = key_cache.stride(2);
|
||||
const size_t cache_scale_head_stride = key_cache_scale.stride(0);
|
||||
|
||||
const torch_mlu::mlu::MLUGuard device_guard(key.device());
|
||||
auto data_dtype = getCnnlDataType(key.scalar_type());
|
||||
auto queue = torch_mlu::getCurMLUStream();
|
||||
|
||||
// run forward
|
||||
TMO_KERNEL_CHECK_FATAL(invokeOfflineQuantToLinearCache(
|
||||
queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache),
|
||||
getAtTensorPtr(key_cache_scale), getAtTensorPtr(value_cache_scale), cache_bs_id_ptr,
|
||||
cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr,
|
||||
getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size,
|
||||
(int)max_context_len, cache_mem_len, context_bs_stride, context_head_stride,
|
||||
context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride,
|
||||
cache_scale_head_stride, packed, quant_mode));
|
||||
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,89 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/offline_quant_to_paged_cache.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
void offline_quant_to_paged_cache(
|
||||
const torch::Tensor &k, // [num_tokens, num_heads, head_size]
|
||||
const c10::optional<torch::Tensor> &v, // [num_tokens, num_heads, head_size]
|
||||
const torch::Tensor &k_cache_scale, // [num_heads, head_size]
|
||||
const c10::optional<torch::Tensor> &v_cache_scale, // [num_heads, head_size]
|
||||
const torch::Tensor &slot_mapping,
|
||||
torch::Tensor &k_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||
const c10::optional<torch::Tensor>
|
||||
&v_cache) { // [num_blocks, num_heads, block_size, head_size]
|
||||
// 1. check device and tensor type
|
||||
TORCH_CHECK(slot_mapping.dtype() == torch::kInt32, "slot_mapping type need be torch::kInt32");
|
||||
checkTensorSameAttr<TensorAttr::DEVICE>(k, v, k_cache_scale, v_cache_scale, slot_mapping);
|
||||
|
||||
// 2. check dim and shape
|
||||
TORCH_CHECK(k.dim() == 3, "dim of k must be 3");
|
||||
TORCH_CHECK(k_cache.dim() == 4, "dim of k_cache must be 4");
|
||||
TORCH_CHECK(k_cache_scale.dim() == 2, "dim of k_cache_scale must be 2");
|
||||
TORCH_CHECK(
|
||||
v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value(),
|
||||
"v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().");
|
||||
if (v.has_value()) {
|
||||
TORCH_CHECK(v.value().dim() == 3, "dim of v must be 3");
|
||||
TORCH_CHECK(v_cache.value().dim() == 4, "dim of v_cache must be 4");
|
||||
TORCH_CHECK(v_cache_scale.value().dim() == 2, "dim of v_cache_scale must be 2");
|
||||
}
|
||||
TORCH_CHECK(slot_mapping.dim() == 1, "dim of slot_mapping must be 1");
|
||||
|
||||
const int num_tokens = k.size(0);
|
||||
const int num_heads = k.size(1);
|
||||
const int head_size = k.size(2);
|
||||
const int block_size = k_cache.size(2);
|
||||
const int num_blocks = k_cache.size(0);
|
||||
|
||||
CHECK_SHAPE(k, num_tokens, num_heads, head_size);
|
||||
CHECK_SHAPE(k_cache_scale, num_heads, head_size);
|
||||
CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_size);
|
||||
if (v.has_value()) {
|
||||
CHECK_SHAPE(v.value(), num_tokens, num_heads, head_size);
|
||||
CHECK_SHAPE(v_cache_scale.value(), num_heads, head_size);
|
||||
CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_size);
|
||||
}
|
||||
CHECK_SHAPE(slot_mapping, num_tokens);
|
||||
|
||||
// 3. check strides
|
||||
TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous.");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous.");
|
||||
TORCH_CHECK(k.stride(-2) == head_size, "k second dim must be contiguous.");
|
||||
if (v.has_value()) {
|
||||
TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous.");
|
||||
TORCH_CHECK(v.value().stride(-2) == head_size, "v second dim must be contiguous.");
|
||||
}
|
||||
|
||||
int64_t kv_cache_range =
|
||||
(int64_t)num_blocks * num_heads * block_size * head_size * k_cache.element_size();
|
||||
TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G.");
|
||||
|
||||
const torch_mlu::mlu::MLUGuard device_guard(k.device());
|
||||
auto queue = torch_mlu::getCurMLUStream();
|
||||
cnnlDataType_t dtype = getCnnlDataType(k.scalar_type());
|
||||
int32_t key_stride0 = static_cast<int32_t>(k.stride(0));
|
||||
int32_t value_stride0 = 1;
|
||||
if (v.has_value()) {
|
||||
value_stride0 = static_cast<int32_t>(v.value().stride(0));
|
||||
}
|
||||
TMO_KERNEL_CHECK_FATAL(tmo::invokeOfflineQuantToPagedCache(
|
||||
queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache),
|
||||
getAtTensorPtr(v_cache), getAtTensorPtr(k_cache_scale), getAtTensorPtr(v_cache_scale),
|
||||
getAtTensorPtr(slot_mapping), key_stride0, value_stride0, num_tokens, num_heads, num_blocks,
|
||||
block_size, head_size));
|
||||
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
423
torch_mlu_ops-v1.3.2/csrc/torch_api/op_theory.h
Normal file
423
torch_mlu_ops-v1.3.2/csrc/torch_api/op_theory.h
Normal file
@@ -0,0 +1,423 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_TORCH_API_OP_THEORY_H_
|
||||
#define CSRC_TORCH_API_OP_THEORY_H_
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// base class
|
||||
struct OpTheory {
|
||||
virtual size_t getTheoryCalc(void) const = 0;
|
||||
virtual size_t getTheoryIO(void) const = 0;
|
||||
virtual cnnlDataType_t getCalcDtype(void) const = 0;
|
||||
virtual const std::string &getOpName(void) const = 0;
|
||||
};
|
||||
|
||||
struct MatmulTheory : public OpTheory {
|
||||
int m_, k_, n_;
|
||||
bool has_res_;
|
||||
cnnlDataType_t calc_type_, data_type_;
|
||||
std::string name_;
|
||||
MatmulTheory(int m,
|
||||
int k,
|
||||
int n,
|
||||
bool has_res,
|
||||
cnnlDataType_t calc_type,
|
||||
cnnlDataType_t data_type)
|
||||
: OpTheory(),
|
||||
m_(m),
|
||||
k_(k),
|
||||
n_(n),
|
||||
has_res_(has_res),
|
||||
calc_type_(calc_type),
|
||||
data_type_(data_type),
|
||||
name_("matmul") {}
|
||||
// lt: m * n * k, ct: add_bias + act
|
||||
size_t getTheoryCalc(void) const override {
|
||||
size_t size_calc = (size_t)m_ * n_ * k_;
|
||||
return size_calc;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t calc_dw, data_dw;
|
||||
cnnlGetSizeOfDataType(calc_type_, &calc_dw);
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
return calc_dw * (m_ * k_ + n_ * k_) + data_dw * m_ * n_ * (1 + has_res_);
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override { return calc_type_; }
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
struct QuantMatmulTheory : public OpTheory {
|
||||
int m_, k_, n_, quant_bit_size_;
|
||||
bool has_res_;
|
||||
cnnlDataType_t data_type_;
|
||||
std::string quant_algo_, name_;
|
||||
QuantMatmulTheory(int m,
|
||||
int k,
|
||||
int n,
|
||||
int quant_bit_size,
|
||||
bool has_res,
|
||||
cnnlDataType_t data_type,
|
||||
std::string quant_algo)
|
||||
: OpTheory(),
|
||||
m_(m),
|
||||
k_(k),
|
||||
n_(n),
|
||||
quant_bit_size_(quant_bit_size),
|
||||
has_res_(has_res),
|
||||
data_type_(data_type),
|
||||
quant_algo_(quant_algo),
|
||||
name_("quant_matmul") {}
|
||||
size_t getTheoryCalc(void) const override {
|
||||
size_t size_calc = (size_t)m_ * n_ * k_;
|
||||
return size_calc;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t io_size = 0, data_dw = 0;
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
if (quant_algo_ == "weight_only") {
|
||||
io_size = (data_dw * m_ * k_ + m_ * n_ * (1 + has_res_)) + n_ * k_ * (quant_bit_size_ / 8.0);
|
||||
} else {
|
||||
io_size = data_dw * m_ * n_ * (1 + has_res_) + (m_ * k_ + n_ * k_) * (quant_bit_size_ / 8.0);
|
||||
}
|
||||
return io_size;
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override {
|
||||
if (quant_algo_ == "weight_only")
|
||||
return data_type_;
|
||||
else
|
||||
return CNNL_DTYPE_INT8;
|
||||
}
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
struct GroupGemmTheory : public OpTheory {
|
||||
int total_m_, experts_num_, k_, n_, quant_bit_size_;
|
||||
bool has_res_;
|
||||
cnnlDataType_t calc_type_, data_type_;
|
||||
std::string name_;
|
||||
GroupGemmTheory(int total_m,
|
||||
int experts_num, // The max number of processed experts.
|
||||
int k,
|
||||
int n,
|
||||
bool has_res,
|
||||
cnnlDataType_t calc_type,
|
||||
cnnlDataType_t data_type)
|
||||
: OpTheory(),
|
||||
total_m_(total_m),
|
||||
experts_num_(experts_num),
|
||||
k_(k),
|
||||
n_(n),
|
||||
has_res_(has_res),
|
||||
calc_type_(calc_type),
|
||||
data_type_(data_type),
|
||||
name_("group_gemm") {} // smooth_quant or float-point
|
||||
size_t getTheoryCalc(void) const override {
|
||||
size_t size_calc = (size_t)total_m_ * n_ * k_;
|
||||
return size_calc;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t calc_dw, data_dw;
|
||||
cnnlGetSizeOfDataType(calc_type_, &calc_dw);
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
// assume load max_num weights
|
||||
size_t size_io = calc_dw * ((size_t)total_m_ * k_ + (size_t)experts_num_ * n_ * k_) +
|
||||
data_dw * total_m_ * n_ * (1 + has_res_);
|
||||
return size_io;
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override { return calc_type_; }
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
struct FlashAttnTheory : public OpTheory {
|
||||
int batch_, total_q_, total_k_, head_num_q_, head_num_kv_, head_size_qk_, head_size_v_;
|
||||
bool is_causal_;
|
||||
cnnlDataType_t data_type_;
|
||||
std::string name_;
|
||||
FlashAttnTheory(int batch,
|
||||
int total_q,
|
||||
int total_k,
|
||||
int head_num_q,
|
||||
int head_num_kv,
|
||||
int head_size_qk,
|
||||
int head_size_v,
|
||||
bool is_causal,
|
||||
cnnlDataType_t data_type)
|
||||
: OpTheory(),
|
||||
batch_(batch),
|
||||
total_q_(total_q),
|
||||
total_k_(total_k),
|
||||
head_num_q_(head_num_q),
|
||||
head_num_kv_(head_num_kv),
|
||||
head_size_qk_(head_size_qk),
|
||||
head_size_v_(head_size_v),
|
||||
is_causal_(is_causal),
|
||||
data_type_(data_type),
|
||||
name_("flash_attention") {}
|
||||
size_t getTheoryCalc(void) const override {
|
||||
// q * k + qk * v
|
||||
size_t seq_q = total_q_ / batch_;
|
||||
size_t seq_kv = total_k_ / batch_;
|
||||
if (is_causal_) seq_kv = seq_kv - 0.5 * seq_q;
|
||||
size_t size_lt = (size_t)batch_ * head_num_q_ * seq_q * seq_kv * (head_size_qk_ + head_size_v_);
|
||||
return size_lt;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t data_dw;
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
size_t size_io = data_dw * (head_num_q_ * total_q_ + head_num_kv_ * total_k_) *
|
||||
(head_size_qk_ + head_size_v_);
|
||||
return size_io;
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
struct SingleQueryCachedKvAttnTheory : public OpTheory {
|
||||
int batch_, seq_q_, head_num_q_, head_num_kv_, num_block_, block_size_, head_size_qk_,
|
||||
head_size_v_, kv_cache_quant_bit_size_;
|
||||
cnnlQuantizeLayout_t cache_quant_layout_;
|
||||
cnnlDataType_t data_type_;
|
||||
std::string name_;
|
||||
SingleQueryCachedKvAttnTheory(int batch,
|
||||
int seq_q,
|
||||
int head_num_q,
|
||||
int head_num_kv,
|
||||
int num_block,
|
||||
int block_size,
|
||||
int head_size_qk,
|
||||
int head_size_v,
|
||||
int kv_cache_quant_bit_size,
|
||||
cnnlQuantizeLayout_t cache_quant_layout,
|
||||
cnnlDataType_t data_type)
|
||||
: OpTheory(),
|
||||
batch_(batch),
|
||||
seq_q_(seq_q),
|
||||
head_num_q_(head_num_q),
|
||||
head_num_kv_(head_num_kv),
|
||||
num_block_(num_block),
|
||||
block_size_(block_size),
|
||||
head_size_qk_(head_size_qk),
|
||||
head_size_v_(head_size_v),
|
||||
kv_cache_quant_bit_size_(kv_cache_quant_bit_size),
|
||||
cache_quant_layout_(cache_quant_layout),
|
||||
data_type_(data_type),
|
||||
name_("single_query_cached_kv_attn") {}
|
||||
size_t getTheoryCalc(void) const override {
|
||||
// q * k + qk * v
|
||||
size_t size_lt =
|
||||
(size_t)seq_q_ * head_num_q_ * num_block_ * block_size_ * (head_size_qk_ + head_size_v_);
|
||||
return size_lt;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t data_dw = 0, dw_float = 0, size_scale = 0;
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float);
|
||||
size_t cache_dw = kv_cache_quant_bit_size_;
|
||||
if (kv_cache_quant_bit_size_ == -1) {
|
||||
cache_dw = data_dw * 8;
|
||||
} else {
|
||||
if (cache_quant_layout_ == CNNL_QUANTIZE_PER_CHANNEL) {
|
||||
size_scale = dw_float * head_num_kv_ * (head_size_qk_ + head_size_v_);
|
||||
} else {
|
||||
size_scale = dw_float * num_block_ * head_num_kv_ * block_size_ * 2;
|
||||
}
|
||||
}
|
||||
size_t size_io =
|
||||
size_scale + data_dw * batch_ * seq_q_ * head_num_q_ * (head_size_qk_ + head_size_v_) +
|
||||
(cache_dw / 8) * head_num_kv_ * num_block_ * block_size_ * (head_size_qk_ + head_size_v_);
|
||||
return size_io;
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
struct FusedNormTheory : public OpTheory {
|
||||
int num_token_, hidden_size_;
|
||||
bool has_res_, has_bias_, quant_out_, dynamic_quant_, store_output_before_norm_;
|
||||
cnnlDataType_t data_type_;
|
||||
std::string norm_mode_, name_;
|
||||
FusedNormTheory(int num_token,
|
||||
int hidden_size,
|
||||
bool has_res,
|
||||
bool has_bias,
|
||||
bool quant_out,
|
||||
bool dynamic_quant,
|
||||
bool store_output_before_norm,
|
||||
cnnlDataType_t data_type,
|
||||
std::string norm_mode)
|
||||
: OpTheory(),
|
||||
num_token_(num_token),
|
||||
hidden_size_(hidden_size),
|
||||
has_res_(has_res),
|
||||
has_bias_(has_bias),
|
||||
quant_out_(quant_out),
|
||||
dynamic_quant_(dynamic_quant),
|
||||
store_output_before_norm_(store_output_before_norm),
|
||||
data_type_(data_type),
|
||||
norm_mode_(norm_mode),
|
||||
name_("fused_norm") {}
|
||||
size_t getTheoryCalc(void) const override {
|
||||
size_t size_ct = 0;
|
||||
if (norm_mode_ == "layernorm") {
|
||||
size_ct = 7; // avg + sub + square + sum + cycle_mul + mul gamma + add beta,
|
||||
} else if (norm_mode_ == "rmsnorm") {
|
||||
size_ct = 4; // square + avg + cycle_mul + mul gamma
|
||||
}
|
||||
size_ct += (has_res_ + has_bias_);
|
||||
if (dynamic_quant_)
|
||||
size_ct += 4;
|
||||
else if (quant_out_)
|
||||
size_ct += 1;
|
||||
return size_ct * num_token_ * hidden_size_;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t data_dw, dw_int8, dw_float;
|
||||
cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float);
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
cnnlGetSizeOfDataType(CNNL_DTYPE_INT8, &dw_int8);
|
||||
size_t size_io = data_dw * num_token_ * hidden_size_;
|
||||
if (has_res_) size_io += data_dw * num_token_ * hidden_size_;
|
||||
if (has_bias_) size_io += data_dw * hidden_size_;
|
||||
size_io += hidden_size_ * data_dw * (1 + (norm_mode_ == "layernorm")); // gamma&beta
|
||||
if (quant_out_) {
|
||||
size_io += dw_int8 * num_token_ * hidden_size_; // output
|
||||
size_io += dw_float * (hidden_size_ + dynamic_quant_ ? num_token_ : 0); // scale
|
||||
} else {
|
||||
size_io += data_dw * num_token_ * hidden_size_;
|
||||
}
|
||||
if (store_output_before_norm_) size_io += data_dw * num_token_ * hidden_size_;
|
||||
return size_io;
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
struct SmoothQuantTheory : public OpTheory {
|
||||
int m_, ci_;
|
||||
cnnlDataType_t data_type_;
|
||||
bool dynamic_quant_;
|
||||
std::string name_;
|
||||
SmoothQuantTheory(int m, int ci, cnnlDataType_t data_type, bool dynamic_quant)
|
||||
: OpTheory(),
|
||||
m_(m),
|
||||
ci_(ci),
|
||||
data_type_(data_type),
|
||||
dynamic_quant_(dynamic_quant),
|
||||
name_("smooth_quant_online") {}
|
||||
size_t getTheoryCalc(void) const override {
|
||||
if (!dynamic_quant_) {
|
||||
return 3 * (size_t)m_ * ci_; // convert + mul scale + convert
|
||||
}
|
||||
size_t calc_scale = m_; // 128 / maxvalue
|
||||
size_t size_calc = (2 /*max + min*/ + 2 /*mul scale*/ + 2 /*convert*/) * (size_t)m_ * ci_;
|
||||
return calc_scale + size_calc;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t data_dw, dw_int8, dw_float;
|
||||
cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float);
|
||||
cnnlGetSizeOfDataType(CNNL_DTYPE_INT8, &dw_int8);
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
size_t scale_size = dw_float * (ci_ + dynamic_quant_ ? m_ : 0);
|
||||
return scale_size + (data_dw + dw_int8) * (m_ * ci_);
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
struct GroupAddBiasActiveTheory : public OpTheory {
|
||||
int expert_size_, num_expand_token_, ci_, inner_size_;
|
||||
bool has_gate_, is_addbias_;
|
||||
cnnlDataType_t data_type_;
|
||||
std::string act_mode_, name_;
|
||||
GroupAddBiasActiveTheory(int expert_size,
|
||||
int num_expand_token,
|
||||
int inner_size,
|
||||
bool has_gate,
|
||||
bool is_addbias,
|
||||
cnnlDataType_t data_type,
|
||||
std::string act_mode)
|
||||
: OpTheory(),
|
||||
expert_size_(expert_size),
|
||||
num_expand_token_(num_expand_token),
|
||||
inner_size_(inner_size),
|
||||
has_gate_(has_gate),
|
||||
is_addbias_(is_addbias),
|
||||
data_type_(data_type),
|
||||
act_mode_(act_mode),
|
||||
name_("group_add_bias_active") {}
|
||||
size_t getTheoryCalc(void) const override {
|
||||
size_t size_vec = 0, act_size = 0;
|
||||
if (is_addbias_) size_vec += (size_t)num_expand_token_ * inner_size_ * (1 + has_gate_);
|
||||
if (act_mode_ == "silu") { // silu = sigmoid + mul
|
||||
act_size += 2 * (size_t)num_expand_token_ * inner_size_;
|
||||
} else { // gelu = 1 vector_compution
|
||||
act_size += (size_t)num_expand_token_ * inner_size_;
|
||||
}
|
||||
if (has_gate_) size_vec += (size_t)num_expand_token_ * inner_size_;
|
||||
return size_vec + act_size;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t data_dw = 0;
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
size_t bias_io = is_addbias_ ? expert_size_ * ci_ * (1 + has_gate_) : 0;
|
||||
return data_dw * (bias_io + num_expand_token_ * inner_size_ * (2 + has_gate_));
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
struct MoeCombineResultTheory : public OpTheory {
|
||||
int num_token_, topk_, ci_, expert_size_;
|
||||
bool has_bias_, has_res_;
|
||||
cnnlDataType_t data_type_;
|
||||
std::string name_;
|
||||
MoeCombineResultTheory(int num_token,
|
||||
int topk,
|
||||
int ci,
|
||||
int expert_size,
|
||||
bool has_bias,
|
||||
bool has_res,
|
||||
cnnlDataType_t data_type)
|
||||
: OpTheory(),
|
||||
num_token_(num_token),
|
||||
topk_(topk),
|
||||
ci_(ci),
|
||||
expert_size_(expert_size),
|
||||
has_bias_(has_bias),
|
||||
has_res_(has_res),
|
||||
data_type_(data_type),
|
||||
name_("moe_combine_result") {}
|
||||
size_t getTheoryCalc(void) const override {
|
||||
size_t size_calc =
|
||||
(size_t)((has_bias_ + 3 /*cvt+fma+cvt*/) * topk_ + has_res_) * num_token_ * ci_;
|
||||
return size_calc;
|
||||
}
|
||||
size_t getTheoryIO(void) const override {
|
||||
size_t data_dw = 0;
|
||||
cnnlGetSizeOfDataType(data_type_, &data_dw);
|
||||
size_t bias_io = has_bias_ ? expert_size_ * ci_ : 0;
|
||||
size_t res_io = has_res_ ? num_token_ * ci_ : 0;
|
||||
return data_dw * (bias_io + res_io + num_token_ * topk_ * ci_ /*input*/ +
|
||||
num_token_ * ci_ /*output*/ + num_token_ * topk_ /*reduce_weight*/);
|
||||
}
|
||||
cnnlDataType_t getCalcDtype(void) const override { return data_type_; }
|
||||
const std::string &getOpName(void) const override { return name_; }
|
||||
};
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_TORCH_API_OP_THEORY_H_
|
||||
25
torch_mlu_ops-v1.3.2/csrc/torch_api/preload.cpp
Normal file
25
torch_mlu_ops-v1.3.2/csrc/torch_api/preload.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/preload.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
// Launches the preload kernel for `weight` on the current MLU stream of the
// tensor's device. `size` is forwarded to invokePreload unchanged.
void preload(const torch::Tensor &weight, const int64_t size) {
  // Make the tensor's device current for the duration of the launch.
  const torch_mlu::mlu::MLUGuard device_guard(weight.device());
  auto stream = torch_mlu::getCurMLUStream();

  // Total byte size of the tensor: bytes per element times element count.
  const auto weight_bytes = weight.element_size() * weight.numel();
  invokePreload(stream, weight.data_ptr(), weight_bytes, size);
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
115
torch_mlu_ops-v1.3.2/csrc/torch_api/quant_matmul.cpp
Normal file
115
torch_mlu_ops-v1.3.2/csrc/torch_api/quant_matmul.cpp
Normal file
@@ -0,0 +1,115 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Quantized matmul wrapper around cnnlLLMQuantMatmul.
//
// Validates the quantization settings, picks or allocates the output tensor,
// builds tensor descriptors for every (optional) operand, queries and
// allocates the workspace, then launches the kernel between cnpxPush/cnpxPop.
//
// The output is `d` when given; otherwise a new [a.size(0), b.size(0)] tensor
// whose dtype is `data_type` when given, else the dtype of `a_tensor`.
at::Tensor quant_matmul(const at::Tensor &a_tensor,
                        const c10::optional<at::Tensor> &a_scale,
                        const c10::optional<at::Tensor> &a_zero,
                        const at::Tensor &b_tensor,
                        const c10::optional<at::Tensor> &b_scale,
                        const c10::optional<at::Tensor> &b_zero,
                        const c10::optional<at::Tensor> &bias,
                        const c10::optional<at::Tensor> &c_tensor,
                        const c10::optional<at::Tensor> &c_scale,
                        const c10::optional<at::Tensor> &c_zero,
                        const c10::optional<at::Tensor> &gemm_output_scale,
                        const c10::optional<at::Tensor> &gemm_output_zero,
                        const c10::optional<std::string> &data_type,
                        const c10::optional<at::Tensor> &d,
                        const std::string &quant_algo,
                        const std::string &a_quant_layout,
                        const std::string &b_quant_layout,
                        int64_t quant_bit_size,
                        const std::string &act_mode,
                        bool use_hp_active,
                        double act_coef,
                        double alpha,
                        double beta,
                        bool trans_a,
                        bool trans_b) {
  // Argument validation: only weight-only or smooth-quant, 4 or 8 bits, and
  // the single supported transpose combination.
  TORCH_CHECK(quant_algo == "weight_only" || quant_algo == "smooth_quant", "illegal quant mode");
  TORCH_CHECK(quant_bit_size == 4 || quant_bit_size == 8, "illegal quant bit size");
  TORCH_CHECK(trans_a == false && trans_b == true, "illegal trans mode");

  // Select or allocate the output tensor.
  const auto base_options = a_tensor.options();
  const std::vector<int64_t> out_shape = {a_tensor.sizes()[0], b_tensor.sizes()[0]};
  at::Tensor output_tensor;
  if (d.has_value()) {
    output_tensor = d.value();
  } else {
    auto out_options = base_options;
    if (data_type.has_value()) {
      const auto dtype = data_type.value();
      TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16",
                  "data type must be 'float', 'half' or 'bfloat16'");
      out_options = base_options.dtype(str2TorchDtype(dtype));
    }
    output_tensor = at::empty(out_shape, out_options);
  }

  // Tensor descriptors. Absent optionals are passed as empty tensors so the
  // indices below stay fixed:
  //   0..2  a, a_scale, a_zero      3..5  b, b_scale, b_zero      6  bias
  //   7..9  c, c_scale, c_zero      10    output
  //   11    gemm_output_scale       12    gemm_output_zero
  const torch::Tensor none;
  auto descs = createTensorDescs(
      {a_tensor, a_scale.value_or(none), a_zero.value_or(none), b_tensor, b_scale.value_or(none),
       b_zero.value_or(none), bias.value_or(none), c_tensor.value_or(none),
       c_scale.value_or(none), c_zero.value_or(none), output_tensor,
       gemm_output_scale.value_or(none), gemm_output_zero.value_or(none)});

  // Operator descriptor; the compute dtype follows the output tensor.
  auto compute_dtype = getCnnlDataType(output_tensor.scalar_type());
  auto op_desc = tmo::op_desc::QuantMatmulDesc(quant_algo, a_quant_layout, b_quant_layout,
                                               quant_bit_size, compute_dtype, act_mode,
                                               use_hp_active, act_coef, trans_a, trans_b);

  auto handle = torch_mlu::getCurrentHandle();

  // Workspace: query the required size, then allocate it as a byte tensor.
  size_t workspace_size = 0;
  CNNL_CHECK_FATAL(cnnlGetLLMQuantMatmulWorkspaceSize(handle, op_desc, descs[0].get(),
                                                      descs[3].get(), descs[7].get(),
                                                      descs[10].get(), &workspace_size));
  auto workspace =
      at::empty({static_cast<int64_t>(workspace_size)}, a_tensor.options().dtype(at::kByte));

  // Pack {tensor, scale, zero} triples for each operand of the kernel call.
  const cnnlTensorDescriptor_t a_descs[] = {descs[0].get(), descs[1].get(), descs[2].get()};
  const cnnlTensorDescriptor_t b_descs[] = {descs[3].get(), descs[4].get(), descs[5].get()};
  const cnnlTensorDescriptor_t c_descs[] = {descs[7].get(), descs[8].get(), descs[9].get()};
  const cnnlTensorDescriptor_t d_descs[] = {descs[10].get(), nullptr, nullptr};
  const void *a_tensors[] = {getAtTensorPtr(a_tensor), getAtTensorPtr(a_scale),
                             getAtTensorPtr(a_zero)};
  const void *b_tensors[] = {getAtTensorPtr(b_tensor), getAtTensorPtr(b_scale),
                             getAtTensorPtr(b_zero)};
  const void *c_tensors[] = {getAtTensorPtr(c_tensor), getAtTensorPtr(c_scale),
                             getAtTensorPtr(c_zero)};
  void *d_tensors[] = {getAtTensorPtr(output_tensor), nullptr, nullptr};
  float alpha_arr[1] = {static_cast<float>(alpha)};
  float beta_arr[1] = {static_cast<float>(beta)};

  // Register the theoretical cost of this launch around the kernel call.
  const bool has_res = c_tensor.has_value();
  QuantMatmulTheory obj(a_tensor.size(0), a_tensor.size(1), b_tensor.size(0), quant_bit_size,
                        has_res, compute_dtype, quant_algo);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlLLMQuantMatmul(
      handle, op_desc, nullptr, alpha_arr, a_descs, a_tensors, b_descs, b_tensors, nullptr,
      beta_arr, c_descs, c_tensors, descs[6].get(), getAtTensorPtr(bias), descs[11].get(),
      getAtTensorPtr(gemm_output_scale), descs[12].get(), getAtTensorPtr(gemm_output_zero),
      getAtTensorPtr(workspace), workspace_size, d_descs, d_tensors));
  cnpxPop();
  return output_tensor;
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
208
torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_linear_cache.cpp
Normal file
208
torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_linear_cache.cpp
Normal file
@@ -0,0 +1,208 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/quant_to_linear_cache.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
void quant_to_linear_cache(
|
||||
const at::Tensor &key, // [total_seqlen, num_heads, head_size]
|
||||
const c10::optional<at::Tensor> &value, // or [bs, max_context_len, num_heads, head_size]
|
||||
at::Tensor &key_cache, // [max_bs, num_heads, cache_memory_len, head_size]
|
||||
const c10::optional<at::Tensor>
|
||||
&value_cache, // [max_bs, num_heads, cache_memory_len, head_size]
|
||||
at::Tensor &key_cache_scale, // [max_bs, num_heads, cache_memory_len]
|
||||
const c10::optional<at::Tensor> &value_cache_scale,
|
||||
const at::Tensor &context_lengths,
|
||||
const int64_t max_context_len,
|
||||
bool packed,
|
||||
const c10::optional<at::Tensor> &context_seq_offset,
|
||||
const c10::optional<at::Tensor> &cache_bs_id,
|
||||
const c10::optional<at::Tensor> &cache_seqlen_offset,
|
||||
const int64_t quant_bit) {
|
||||
// check device and dtype
|
||||
TORCH_CHECK(key.scalar_type() == torch::kFloat32 || key.scalar_type() == torch::kFloat16 ||
|
||||
key.scalar_type() == torch::kBFloat16,
|
||||
"key type need be torch::kFloat32, torch::kFloat16 or torch::kBFloat16");
|
||||
TORCH_CHECK(key_cache.scalar_type() == torch::kInt8, "key_cache type need be torch::kInt8");
|
||||
TORCH_CHECK(key_cache_scale.scalar_type() == torch::kFloat32,
|
||||
"key_cache_scale type need be torch::kFloat32");
|
||||
if (value.has_value() || value_cache.has_value() || value_cache_scale.has_value()) {
|
||||
TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_scale.has_value(),
|
||||
"value_cache, value_cache_scale, value must all have value.")
|
||||
}
|
||||
if (value_cache.has_value()) {
|
||||
TORCH_CHECK(value_cache.value().scalar_type() == torch::kInt8,
|
||||
"value_cache type need be torch::kInt8");
|
||||
TORCH_CHECK(value_cache_scale.value().scalar_type() == torch::kFloat32,
|
||||
"value_cache_scale type need be torch::kFloat32");
|
||||
}
|
||||
if (value.has_value()) {
|
||||
TORCH_CHECK(value.value().scalar_type() == key.scalar_type(),
|
||||
"value type need be same with key");
|
||||
}
|
||||
TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32,
|
||||
"context_lengths type need be torch::kInt32");
|
||||
const int batch_size = packed ? context_lengths.size(0) - 1 : context_lengths.size(0);
|
||||
|
||||
const int max_bs = key_cache.size(0);
|
||||
const int num_heads = key_cache.size(1);
|
||||
const int cache_mem_len = key_cache.size(2);
|
||||
const int head_size = key.size(-1);
|
||||
// check quantize param
|
||||
TORCH_CHECK(quant_bit == 8 || quant_bit == 4, "quant_bit must be 8 or 4");
|
||||
if (key_cache_scale.dim() == 4) {
|
||||
TORCH_CHECK(key_cache_scale.size(-1) > 0 && head_size % key_cache_scale.size(-1) == 0,
|
||||
"key_cache_scale.size(-1) must be legal");
|
||||
}
|
||||
int group_size = key_cache_scale.dim() == 3 ? head_size : head_size / key_cache_scale.size(-1);
|
||||
|
||||
const int total_seqlen = packed ? key.size(0) : batch_size * key.size(-2);
|
||||
|
||||
const size_t context_bs_stride = packed ? 0 : key.stride(0);
|
||||
const size_t context_head_stride = packed ? key.stride(1) : key.stride(2);
|
||||
const size_t context_seq_stride = packed ? key.stride(0) : key.stride(1);
|
||||
const size_t cache_bs_stride = key_cache.stride(0);
|
||||
const size_t cache_head_stride = key_cache.stride(1);
|
||||
const size_t key_cache_seq_stride = key_cache.stride(2);
|
||||
size_t value_cache_seq_stride = head_size;
|
||||
const size_t cache_scale_bs_stride = key_cache_scale.stride(0);
|
||||
const size_t cache_scale_head_stride = key_cache_scale.stride(1);
|
||||
|
||||
if (key_cache_scale.dim() == 3) {
|
||||
CHECK_SHAPE(key_cache_scale, max_bs, num_heads, cache_mem_len);
|
||||
} else {
|
||||
CHECK_SHAPE(key_cache_scale, max_bs, num_heads, cache_mem_len, head_size / group_size);
|
||||
}
|
||||
if (quant_bit == 8) {
|
||||
TORCH_CHECK(key_cache.size(-1) == head_size, "last dim of key_cache should be head_size");
|
||||
} else {
|
||||
TORCH_CHECK(key_cache.size(-1) == head_size / 2, "last dim of key_cache should be head_size");
|
||||
}
|
||||
if (value_cache.has_value()) {
|
||||
if (quant_bit == 8) {
|
||||
CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size);
|
||||
} else if (quant_bit == 4) {
|
||||
CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len / 2, head_size);
|
||||
}
|
||||
if (key_cache_scale.dim() == 3) {
|
||||
CHECK_SHAPE(value_cache_scale.value(), max_bs, num_heads, cache_mem_len);
|
||||
} else {
|
||||
CHECK_SHAPE(value_cache_scale.value(), max_bs, num_heads, cache_mem_len,
|
||||
head_size / group_size);
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t device_id = key.get_device();
|
||||
TORCH_CHECK(key_cache.get_device() == device_id,
|
||||
"key_cache tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", key_cache.get_device());
|
||||
TORCH_CHECK(key_cache_scale.get_device() == device_id,
|
||||
"key_cache_scale tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", key_cache_scale.get_device());
|
||||
|
||||
bool has_context_seq_offset = context_seq_offset.has_value();
|
||||
const void *context_seq_offset_ptr = nullptr;
|
||||
if (has_context_seq_offset) {
|
||||
CHECK_SHAPE(context_seq_offset.value(), batch_size);
|
||||
TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32,
|
||||
"context_seq_offset type need be torch::kInt32");
|
||||
TORCH_CHECK(context_seq_offset.value().get_device() == device_id,
|
||||
"context_seq_offset tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", context_seq_offset.value().get_device());
|
||||
context_seq_offset_ptr = context_seq_offset.value().data_ptr();
|
||||
}
|
||||
|
||||
bool has_cache_bs_id = cache_bs_id.has_value();
|
||||
const void *cache_bs_id_ptr = nullptr;
|
||||
if (has_cache_bs_id) {
|
||||
CHECK_SHAPE(cache_bs_id.value(), batch_size);
|
||||
TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32,
|
||||
"cache_bs_id type need be torch::kInt32");
|
||||
TORCH_CHECK(cache_bs_id.value().get_device() == device_id,
|
||||
"cache_bs_id tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", cache_bs_id.value().get_device());
|
||||
cache_bs_id_ptr = cache_bs_id.value().data_ptr();
|
||||
}
|
||||
|
||||
bool has_cache_seqlen_offset = cache_seqlen_offset.has_value();
|
||||
const void *cache_seqlen_offset_ptr = nullptr;
|
||||
if (has_cache_seqlen_offset) {
|
||||
CHECK_SHAPE(cache_seqlen_offset.value(), batch_size);
|
||||
TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32,
|
||||
"cache_seqlen_offset type need be torch::kInt32");
|
||||
TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id,
|
||||
"cache_seqlen_offset tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", cache_seqlen_offset.value().get_device());
|
||||
cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr();
|
||||
}
|
||||
|
||||
// check contiguous
|
||||
TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous.");
|
||||
TORCH_CHECK(key_cache_scale.is_contiguous(), "key_cache_scale tensor must be contiguous.");
|
||||
if (value_cache.has_value()) {
|
||||
TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous.");
|
||||
TORCH_CHECK(value_cache.value().get_device() == device_id,
|
||||
"value_cache tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", value_cache.value().get_device());
|
||||
TORCH_CHECK(value_cache_scale.value().is_contiguous(),
|
||||
"value_cache_scale tensor must be contiguous.");
|
||||
TORCH_CHECK(value_cache_scale.value().get_device() == device_id,
|
||||
"value_cache_scale tensor device index is not same, original index: ", device_id,
|
||||
"now index is: ", value_cache_scale.value().get_device());
|
||||
}
|
||||
|
||||
// check shape
|
||||
if (packed) {
|
||||
CHECK_SHAPE(key, total_seqlen, num_heads, head_size);
|
||||
if (value.has_value()) {
|
||||
CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size);
|
||||
}
|
||||
CHECK_SHAPE(context_lengths, batch_size + 1);
|
||||
} else {
|
||||
CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size);
|
||||
if (value.has_value()) {
|
||||
CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size);
|
||||
}
|
||||
CHECK_SHAPE(context_lengths, batch_size);
|
||||
}
|
||||
|
||||
if (value.has_value()) {
|
||||
for (int i = 0; i < key.dim(); i++) {
|
||||
TORCH_CHECK(value.value().stride(i) == key.stride(i),
|
||||
"key and value must have same stride along axi ", i);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(key_cache.size(0) >= batch_size,
|
||||
"first dim of key_cache must be great than or equal to batch_size");
|
||||
TORCH_CHECK(key_cache.dim() == 4, "dim of key_cache must be 4.");
|
||||
|
||||
const torch_mlu::mlu::MLUGuard device_guard(key.device());
|
||||
auto data_dtype = getCnnlDataType(key.scalar_type());
|
||||
auto queue = torch_mlu::getCurMLUStream();
|
||||
|
||||
// run forward
|
||||
invokeQuantToLinearCache(
|
||||
queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache),
|
||||
getAtTensorPtr(key_cache_scale), getAtTensorPtr(value_cache_scale), cache_bs_id_ptr,
|
||||
cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr,
|
||||
getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size,
|
||||
(int)max_context_len, cache_mem_len, context_bs_stride, context_head_stride,
|
||||
context_seq_stride, cache_bs_stride, cache_head_stride, key_cache_seq_stride,
|
||||
value_cache_seq_stride, cache_scale_bs_stride, cache_scale_head_stride, packed,
|
||||
(int)quant_bit, group_size);
|
||||
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
81
torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_paged_cache.cpp
Normal file
81
torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_paged_cache.cpp
Normal file
@@ -0,0 +1,81 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/quant_to_paged_cache.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
void quant_to_paged_cache(
|
||||
const torch::Tensor &k, // [num_tokens, num_heads, head_dim]
|
||||
const c10::optional<torch::Tensor> &v, // [num_tokens, num_heads, head_dim]
|
||||
torch::Tensor &k_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
const c10::optional<torch::Tensor> &v_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
torch::Tensor &k_cache_scale, // [num_blocks, num_heads, block_size]
|
||||
const c10::optional<torch::Tensor> &v_cache_scale, // [num_blocks, num_heads, block_size]
|
||||
const torch::Tensor &slot_mapping) {
|
||||
// 1. check device and tensor type
|
||||
TORCH_CHECK(slot_mapping.dtype() == torch::kInt32, "slot_mapping type need be torch::kInt32");
|
||||
TORCH_CHECK(slot_mapping.get_device() == k.get_device(),
|
||||
"Tensor device index is not same, original index: ", k.get_device(),
|
||||
"now index is: ", slot_mapping.get_device());
|
||||
|
||||
// 2. check shape
|
||||
const int num_tokens = k.size(0);
|
||||
const int num_heads = k.size(1);
|
||||
const int head_dim = k.size(2);
|
||||
const int block_size = k_cache.size(2);
|
||||
const int head_size = k_cache.size(3);
|
||||
const int num_blocks = k_cache.size(0);
|
||||
|
||||
CHECK_SHAPE(k, num_tokens, num_heads, head_dim);
|
||||
CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_dim);
|
||||
CHECK_SHAPE(k_cache_scale, num_blocks, num_heads, block_size);
|
||||
TORCH_CHECK(v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value(),
|
||||
"v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().")
|
||||
if (v.has_value()) {
|
||||
CHECK_SHAPE(v.value(), num_tokens, num_heads, head_dim);
|
||||
CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_dim);
|
||||
CHECK_SHAPE(v_cache_scale.value(), num_blocks, num_heads, block_size);
|
||||
}
|
||||
CHECK_SHAPE(slot_mapping, num_tokens);
|
||||
|
||||
// 3. check strides
|
||||
TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous.");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous.");
|
||||
TORCH_CHECK(k.stride(-2) == head_dim, "k second dim must be contiguous.");
|
||||
if (v.has_value()) {
|
||||
TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous.");
|
||||
TORCH_CHECK(v.value().stride(-2) == head_dim, "v second dim must be contiguous.");
|
||||
}
|
||||
|
||||
int64_t kv_cache_range =
|
||||
(int64_t)num_blocks * num_heads * block_size * head_dim * k_cache.element_size();
|
||||
TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G.");
|
||||
|
||||
const torch_mlu::mlu::MLUGuard device_guard(k.device());
|
||||
auto queue = torch_mlu::getCurMLUStream();
|
||||
cnnlDataType_t dtype = getCnnlDataType(k.scalar_type());
|
||||
int32_t key_stride0 = static_cast<int32_t>(k.stride(0));
|
||||
int32_t value_stride0 = 1;
|
||||
if (v.has_value()) {
|
||||
value_stride0 = static_cast<int32_t>(v.value().stride(0));
|
||||
}
|
||||
|
||||
TMO_KERNEL_CHECK_FATAL(tmo::invokeQuantToPagedCache(
|
||||
queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache),
|
||||
getAtTensorPtr(v_cache), getAtTensorPtr(k_cache_scale), getAtTensorPtr(v_cache_scale),
|
||||
getAtTensorPtr(slot_mapping), key_stride0, value_stride0, num_tokens, num_heads, num_blocks,
|
||||
block_size, head_size));
|
||||
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
136
torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_linear_cache.cpp
Normal file
136
torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_linear_cache.cpp
Normal file
@@ -0,0 +1,136 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/reshape_linear_cache.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
// Validates the key / value tensors and the linear key / value caches, then launches
// invokeReshapeLinearCache on the current MLU stream.
//
// packed == true : key is [total_seqlen, num_heads, head_size] and context_lengths has
//                  batch_size + 1 entries.
// packed == false: key is [batch_size, seq, num_heads, head_size] and context_lengths has
//                  batch_size entries.
// key_cache (and value_cache, when given) is [max_bs, num_heads, cache_mem_len, head_size].
// context_seq_offset, cache_bs_id and cache_seqlen_offset are optional int32 tensors of shape
// [batch_size] on the same device as key; a null pointer is passed to the kernel when absent.
//
// Raises a TORCH_CHECK error when any dtype / device / shape / stride requirement is violated.
void reshape_linear_cache(const at::Tensor &key,
                          const c10::optional<at::Tensor> &value,
                          at::Tensor &key_cache,
                          const c10::optional<at::Tensor> &value_cache,
                          const at::Tensor &context_lengths,
                          const int64_t max_context_len,
                          bool packed,
                          const c10::optional<at::Tensor> &context_seq_offset,
                          const c10::optional<at::Tensor> &cache_bs_id,
                          const c10::optional<at::Tensor> &cache_seqlen_offset) {
  // check device and dtype
  checkTensorSameAttr<TensorAttr::ALL>(key, key_cache, value, value_cache);

  // The rank of key_cache must be verified before its sizes / strides are indexed below;
  // otherwise a lower-rank tensor fails inside size() and this message is never shown.
  TORCH_CHECK(key_cache.dim() == 4, "dim of key_cache must be 4.");

  const int batch_size = packed ? context_lengths.size(0) - 1 : context_lengths.size(0);

  const int max_bs = key_cache.size(0);
  const int num_heads = key_cache.size(1);
  const int cache_mem_len = key_cache.size(2);
  const int head_size = key_cache.size(3);

  const int total_seqlen = packed ? key.size(0) : batch_size * key.size(1);

  // In packed mode there is no batch axis, so the batch stride is reported as 0.
  const int context_bs_stride = packed ? 0 : key.stride(0);
  const int context_head_stride = packed ? key.stride(1) : key.stride(2);
  const int context_seq_stride = packed ? key.stride(0) : key.stride(1);
  const int cache_bs_stride = key_cache.stride(0);
  const int cache_head_stride = key_cache.stride(1);
  const int cache_seq_stride = key_cache.stride(2);

  TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32,
              "context_lengths type need be torch::kInt32");

  if (value_cache.has_value()) {
    CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size);
  }

  const int64_t device_id = key.get_device();
  bool has_context_seq_offset = context_seq_offset.has_value();
  const void *context_seq_offset_ptr = nullptr;
  if (has_context_seq_offset) {
    CHECK_SHAPE(context_seq_offset.value(), batch_size);
    TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32,
                "context_seq_offset type need be torch::kInt32");
    TORCH_CHECK(context_seq_offset.value().get_device() == device_id,
                "context_seq_offset tensor device index is not same, original index: ", device_id,
                "now index is: ", context_seq_offset.value().get_device());
    context_seq_offset_ptr = context_seq_offset.value().data_ptr();
  }

  bool has_cache_bs_id = cache_bs_id.has_value();
  const void *cache_bs_id_ptr = nullptr;
  if (has_cache_bs_id) {
    CHECK_SHAPE(cache_bs_id.value(), batch_size);
    TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32,
                "cache_bs_id type need be torch::kInt32");
    TORCH_CHECK(cache_bs_id.value().get_device() == device_id,
                "cache_bs_id tensor device index is not same, original index: ", device_id,
                "now index is: ", cache_bs_id.value().get_device());
    cache_bs_id_ptr = cache_bs_id.value().data_ptr();
  }

  bool has_cache_seqlen_offset = cache_seqlen_offset.has_value();
  const void *cache_seqlen_offset_ptr = nullptr;
  if (has_cache_seqlen_offset) {
    CHECK_SHAPE(cache_seqlen_offset.value(), batch_size);
    TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32,
                "cache_seqlen_offset type need be torch::kInt32");
    TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id,
                "cache_seqlen_offset tensor device index is not same, original index: ", device_id,
                "now index is: ", cache_seqlen_offset.value().get_device());
    cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr();
  }

  // check contiguous
  TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous.");
  if (value_cache.has_value()) {
    TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous.");
  }

  // check shape
  if (packed) {
    CHECK_SHAPE(key, total_seqlen, num_heads, head_size);
    if (value.has_value()) {
      CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size);
    }
    CHECK_SHAPE(context_lengths, batch_size + 1);
  } else {
    CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size);
    if (value.has_value()) {
      CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size);
    }
    CHECK_SHAPE(context_lengths, batch_size);
  }

  // A single set of context strides (taken from key) is passed to the kernel, so value has to
  // be laid out exactly like key.
  if (value.has_value()) {
    for (int i = 0; i < key.dim(); i++) {
      TORCH_CHECK(value.value().stride(i) == key.stride(i),
                  "key and value must have same stride along axis ", i);
    }
  }

  TORCH_CHECK(key_cache.size(0) >= batch_size,
              "first dim of key_cache must be great than or equal to batch_size");

  const torch_mlu::mlu::MLUGuard device_guard(key.device());
  auto data_dtype = getCnnlDataType(key.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();

  // run forward
  TMO_KERNEL_CHECK_FATAL(invokeReshapeLinearCache(
      queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache), cache_bs_id_ptr,
      cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr,
      getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size,
      (int)max_context_len, cache_mem_len, context_bs_stride, context_head_stride,
      context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride, packed));
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
80
torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_paged_cache.cpp
Normal file
80
torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_paged_cache.cpp
Normal file
@@ -0,0 +1,80 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/reshape_paged_cache.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
void reshape_paged_cache(
|
||||
const torch::Tensor &k, // [num_tokens, num_heads, head_dim]
|
||||
const c10::optional<torch::Tensor> &v, // [num_tokens, num_heads, head_dim]
|
||||
torch::Tensor &k_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
const c10::optional<torch::Tensor> &v_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
const torch::Tensor &slot_mapping) {
|
||||
// 1. check device and tensor type
|
||||
checkTensorSameAttr<TensorAttr::ALL>(k, v, k_cache, v_cache);
|
||||
TORCH_CHECK(slot_mapping.dtype() == torch::kInt32 || slot_mapping.dtype() == torch::kLong,
|
||||
"slot_mapping type need be torch::kInt32 or torch::Long");
|
||||
TORCH_CHECK(slot_mapping.get_device() == k.get_device(),
|
||||
"Tensor device index is not same, original index: ", k.get_device(),
|
||||
"now index is: ", slot_mapping.get_device());
|
||||
|
||||
// 2. check shape
|
||||
const int num_tokens = k.size(0);
|
||||
const int num_heads = k.size(1);
|
||||
const int head_dim = k.size(2);
|
||||
const int block_size = k_cache.size(2);
|
||||
const int head_size = k_cache.size(3);
|
||||
const int num_blocks = k_cache.size(0);
|
||||
|
||||
CHECK_SHAPE(k, num_tokens, num_heads, head_dim);
|
||||
CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_dim);
|
||||
TORCH_CHECK(v.has_value() == v_cache.has_value(), "v.has_value() == v_cache.has_value().")
|
||||
if (v.has_value()) {
|
||||
CHECK_SHAPE(v.value(), num_tokens, num_heads, head_dim);
|
||||
CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_dim);
|
||||
}
|
||||
CHECK_SHAPE(slot_mapping, num_tokens);
|
||||
|
||||
// 3. check strides
|
||||
TORCH_CHECK(k_cache.is_contiguous(), "k_cache need be contiguous.");
|
||||
TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous.");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous.");
|
||||
TORCH_CHECK(k.stride(-2) == head_dim, "k second dim must be contiguous.");
|
||||
if (v.has_value()) {
|
||||
TORCH_CHECK(v_cache.value().is_contiguous(), "v_cache need be contiguous.");
|
||||
TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous.");
|
||||
TORCH_CHECK(v.value().stride(-2) == head_dim, "v second dim must be contiguous.");
|
||||
}
|
||||
|
||||
// check large tensor
|
||||
int64_t kv_cache_range =
|
||||
(int64_t)num_blocks * num_heads * block_size * head_dim * k_cache.element_size();
|
||||
TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G.");
|
||||
|
||||
const torch_mlu::mlu::MLUGuard device_guard(k.device());
|
||||
auto queue = torch_mlu::getCurMLUStream();
|
||||
cnnlDataType_t dtype = getCnnlDataType(k.scalar_type());
|
||||
int32_t key_stride0 = static_cast<int32_t>(k.stride(0));
|
||||
int32_t value_stride0 = 1;
|
||||
if (v.has_value()) {
|
||||
value_stride0 = static_cast<int32_t>(v.value().stride(0));
|
||||
}
|
||||
|
||||
TMO_KERNEL_CHECK_FATAL(tmo::invokeReshapePagedCache(
|
||||
queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache),
|
||||
getAtTensorPtr(v_cache), getAtTensorPtr(slot_mapping), key_stride0, value_stride0, num_tokens,
|
||||
num_heads, num_blocks, block_size, head_size));
|
||||
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,165 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "torch_ops_api.h"
|
||||
#include "utils.h"
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// q_ori, out Tensor shape is [batch, seq_len, head_num, head_size]
|
||||
// if (pagedattn) k_cache, v_cache shape is [num_blocks, k_head_num, block_size, head_size]
|
||||
// else [max_bs, head_num, max_seq_len, head_size]
|
||||
// if(pagedattn) k_cache_quant_scale shape is [num_blocks, k_head_num, block_size] else [max_bs,
|
||||
// head_num, max_seq_len]
|
||||
// if (pagedattn) block_tables shape is [bs, max_nblock_per_seq] else [bs, 1]
|
||||
// context_lens: true seqkv_len of each batch
|
||||
void single_query_cached_kv_attn(
|
||||
const torch::Tensor &q_ori,
|
||||
const torch::Tensor &k_cache,
|
||||
const torch::Tensor &v_cache,
|
||||
const torch::Tensor &output,
|
||||
const torch::Tensor &block_tables,
|
||||
const torch::Tensor &context_lens, // [batch]
|
||||
const c10::optional<torch::Tensor> &output_lse,
|
||||
const c10::optional<torch::Tensor> &k_cache_quant_scale,
|
||||
const c10::optional<torch::Tensor> &v_cache_quant_scale,
|
||||
const c10::optional<torch::Tensor> &alibi_slopes, // [bs, head_num] or [head_num]
|
||||
int64_t max_context_len,
|
||||
int64_t windows_size_left,
|
||||
int64_t windows_size_right,
|
||||
double softmax_scale,
|
||||
bool return_lse,
|
||||
int64_t kv_cache_quant_bit_size) {
|
||||
// Check tensor type and tensor device.
|
||||
cnnlQuantizeLayout_t cache_quant_layout = CNNL_QUANTIZE_NONE;
|
||||
bool is_kv_quant = k_cache_quant_scale.has_value();
|
||||
bool has_alibi = alibi_slopes.has_value();
|
||||
|
||||
// Check Tensor Shape
|
||||
int batch = q_ori.size(0);
|
||||
int seq_q = q_ori.size(1);
|
||||
int head_num = q_ori.size(2);
|
||||
int qk_head_size = q_ori.size(3);
|
||||
int k_head_num = k_cache.size(1);
|
||||
int v_head_size = v_cache.size(-1);
|
||||
int max_num_block_per_seq = block_tables.size(1);
|
||||
int num_blocks = k_cache.size(0);
|
||||
int block_size = k_cache.size(2);
|
||||
TORCH_CHECK(windows_size_right < 0, "only support windows_size_right < 0 currently.");
|
||||
TORCH_CHECK(head_num % k_head_num == 0, "num_heads need be mutiple of num_kv_heads.");
|
||||
TORCH_CHECK(
|
||||
kv_cache_quant_bit_size == 4 || kv_cache_quant_bit_size == 8 || kv_cache_quant_bit_size == -1,
|
||||
"illegal quant bit size, only support 4, 8 or -1.");
|
||||
CHECK_SHAPE(block_tables, batch, max_num_block_per_seq);
|
||||
if (kv_cache_quant_bit_size == 4) {
|
||||
int v_cache_len = PAD_UP_DIV(block_size, 2);
|
||||
CHECK_SHAPE(k_cache, num_blocks, k_head_num, block_size, qk_head_size / 2);
|
||||
CHECK_SHAPE(v_cache, num_blocks, k_head_num, v_cache_len, v_head_size);
|
||||
} else {
|
||||
CHECK_SHAPE(k_cache, num_blocks, k_head_num, block_size, qk_head_size);
|
||||
CHECK_SHAPE(v_cache, num_blocks, k_head_num, block_size, v_head_size);
|
||||
}
|
||||
TORCH_CHECK(q_ori.stride(-1) == 1 && q_ori.stride(-2) == qk_head_size,
|
||||
"q last two dim need be contiguous.");
|
||||
TORCH_CHECK(k_cache.is_contiguous() && v_cache.is_contiguous(),
|
||||
"k_cache and v_cache need be contiguous.");
|
||||
TORCH_CHECK(isMlu(q_ori), "q_ori need be mlu tensor.");
|
||||
|
||||
checkTensorSameAttr<TensorAttr::DEVICE>(q_ori, k_cache, v_cache, output, block_tables,
|
||||
context_lens, k_cache_quant_scale, v_cache_quant_scale,
|
||||
output_lse, alibi_slopes);
|
||||
if (is_kv_quant) {
|
||||
TORCH_CHECK(k_cache_quant_scale.value().dim() == 2 || k_cache_quant_scale.value().dim() == 3 ||
|
||||
k_cache_quant_scale.value().dim() == 4,
|
||||
"k_cache_quant_scale must be 2d or 3d or 4d.");
|
||||
TORCH_CHECK(k_cache_quant_scale.value().dim() == v_cache_quant_scale.value().dim(),
|
||||
"the dim of k_cache_quant_scale and v_cache_quant_scale must be euqal.");
|
||||
if (k_cache_quant_scale.value().dim() == 2) {
|
||||
CHECK_SHAPE(k_cache_quant_scale.value(), k_head_num, qk_head_size);
|
||||
CHECK_SHAPE(v_cache_quant_scale.value(), k_head_num, v_head_size);
|
||||
cache_quant_layout = CNNL_QUANTIZE_PER_CHANNEL;
|
||||
} else if (k_cache_quant_scale.value().dim() == 3) {
|
||||
CHECK_SHAPE(k_cache_quant_scale.value(), num_blocks, k_head_num, block_size);
|
||||
CHECK_SHAPE(v_cache_quant_scale.value(), num_blocks, k_head_num, block_size);
|
||||
cache_quant_layout = CNNL_QUANTIZE_PER_TOKEN;
|
||||
} else {
|
||||
CHECK_SHAPE(k_cache_quant_scale.value(), num_blocks, k_head_num, block_size, 1);
|
||||
CHECK_SHAPE(v_cache_quant_scale.value(), num_blocks, k_head_num, block_size, 1);
|
||||
cache_quant_layout = CNNL_QUANTIZE_PER_TOKEN;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(
|
||||
block_tables.scalar_type() == torch::kInt32 || block_tables.scalar_type() == torch::kLong,
|
||||
"block_tables type need be torch::kInt32 or torch::kLong.");
|
||||
// Check context_lens
|
||||
TORCH_CHECK(context_lens.dtype() == torch::kInt32, "context_lens type need be torch::kInt32.");
|
||||
TORCH_CHECK(context_lens.is_contiguous(), "context_lens need be contiguous.");
|
||||
|
||||
if (has_alibi) {
|
||||
CHECK_SHAPE(alibi_slopes.value(), batch, head_num);
|
||||
}
|
||||
if (return_lse) {
|
||||
TORCH_CHECK(seq_q == 1, "return lse only support seq_q = 1 currently.");
|
||||
CHECK_SHAPE(output_lse.value(), batch, head_num, seq_q);
|
||||
}
|
||||
if (windows_size_left >= 0) {
|
||||
windows_size_left = windows_size_left + 1; // would be removed next version
|
||||
}
|
||||
|
||||
// Convert torch tensor to tensor descs
|
||||
auto descs = createTensorDescs(
|
||||
{q_ori, k_cache, v_cache, k_cache_quant_scale.value_or(at::Tensor()),
|
||||
v_cache_quant_scale.value_or(at::Tensor()), context_lens, block_tables,
|
||||
alibi_slopes.value_or(at::Tensor()), output, output_lse.value_or(at::Tensor())});
|
||||
|
||||
if (kv_cache_quant_bit_size == 4) {
|
||||
cnnlDataType_t data_type = CNNL_DTYPE_INT4X2;
|
||||
// k_cache
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[1].get(), CNNL_LAYOUT_ARRAY, data_type,
|
||||
k_cache.sizes().size(), k_cache.sizes().data()));
|
||||
// v_cache
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[2].get(), CNNL_LAYOUT_ARRAY, data_type,
|
||||
v_cache.sizes().size(), v_cache.sizes().data()));
|
||||
}
|
||||
|
||||
// Get current handle.
|
||||
const torch_mlu::mlu::MLUGuard device_guard(q_ori.device());
|
||||
cnnlHandle_t handle = torch_mlu::getCurrentHandle();
|
||||
// Get workspace size and malloc workspace.
|
||||
size_t workspace_size = 0;
|
||||
cnnlSingleQueryCachedKVAttnDescriptor_t att_desc;
|
||||
CNNL_CHECK_FATAL(cnnlGetSingleQueryCachedKVAttnWorkspaceSize_v2(
|
||||
handle, att_desc, descs[0].get(), descs[1].get(), descs[2].get(), (int)max_context_len,
|
||||
&workspace_size));
|
||||
auto workspace =
|
||||
at::empty({static_cast<int64_t>(workspace_size)}, q_ori.options().dtype(at::kByte));
|
||||
|
||||
cnnlDataType_t data_dtype = getCnnlDataType(q_ori.scalar_type());
|
||||
SingleQueryCachedKvAttnTheory obj(batch, seq_q, head_num, k_head_num, num_blocks, block_size,
|
||||
qk_head_size, v_head_size, kv_cache_quant_bit_size,
|
||||
cache_quant_layout, data_dtype);
|
||||
cnpxPush(obj);
|
||||
// call cnnl extra op.
|
||||
CNNL_CHECK_FATAL(cnnlSingleQueryCachedKVAttn_v3(
|
||||
handle, att_desc, descs[0].get(), getAtTensorPtr(q_ori), descs[1].get(),
|
||||
getAtTensorPtr(k_cache), descs[2].get(), getAtTensorPtr(v_cache), descs[3].get(),
|
||||
getAtTensorPtr(k_cache_quant_scale), descs[4].get(), getAtTensorPtr(v_cache_quant_scale),
|
||||
descs[5].get(), getAtTensorPtr(context_lens), descs[6].get(), getAtTensorPtr(block_tables),
|
||||
descs[7].get(), getAtTensorPtr(alibi_slopes), cache_quant_layout, (int)max_context_len,
|
||||
windows_size_left, windows_size_right, softmax_scale, return_lse, getAtTensorPtr(workspace),
|
||||
workspace_size, descs[8].get(), getAtTensorPtr(output), descs[9].get(),
|
||||
getAtTensorPtr(output_lse)));
|
||||
cnpxPop();
|
||||
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
112
torch_mlu_ops-v1.3.2/csrc/torch_api/smooth_quant.cpp
Normal file
112
torch_mlu_ops-v1.3.2/csrc/torch_api/smooth_quant.cpp
Normal file
@@ -0,0 +1,112 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "torch_ops_api.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Validates the arguments and calls cnnlSmoothQuantOnline_v4 on the current MLU handle.
//
// input / output   : at least 2-D; all leading dims are flattened into one before the call.
// input_scale      : must be contiguous.
// output_scale     : required (numel > 0) when dynamic_quant is true, must be empty otherwise.
// input_zero, token_count, gather_index, gather_index_start_position: optional, contiguous.
// quant_mode       : "per_token" or "per_tensor"; dynamic_quant requires "per_token".
//
// Raises a TORCH_CHECK error when a requirement is violated.
void smooth_quant(const at::Tensor &input,
                  const at::Tensor &input_scale,
                  const at::Tensor &output,
                  const at::Tensor &output_scale,
                  const c10::optional<at::Tensor> &input_zero,
                  const c10::optional<at::Tensor> &token_count,
                  const c10::optional<at::Tensor> &gather_index,
                  const c10::optional<at::Tensor> &gather_index_start_position,
                  const std::string &quant_mode,
                  const bool &dynamic_quant) {
  bool is_pertoken = quant_mode == "per_token";
  CHECK_TENSOR_CONTIGUOUS(input_scale)
  if (output_scale.dim() > 0) {
    CHECK_TENSOR_CONTIGUOUS(output_scale)
  }
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(input_zero)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(token_count)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_index)
  CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_index_start_position)

  TORCH_CHECK(isMlu(input), "input must be mlu tensor.");
  checkTensorSameAttr<TensorAttr::DEVICE>(input, input_scale, input_zero, token_count, gather_index,
                                          gather_index_start_position);

  TORCH_CHECK(quant_mode == "per_token" || quant_mode == "per_tensor",
              "quant_mode must be 'per_token' or per_tensor.")
  if (dynamic_quant) {
    TORCH_CHECK(is_pertoken, "only support per_token if dynamic_quant == true")
    TORCH_CHECK(output_scale.numel() > 0, "invalid output_scale if dynamic_quant = true")
  } else {
    TORCH_CHECK(output_scale.numel() <= 0, "not support output_scale if dynamic_quant = false")
  }
  if (gather_index.has_value() || token_count.has_value()) {
    TORCH_CHECK(input.dim() == 2, "input.dim() == 2 if has gather_index or token_count")
  }
  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2")
  TORCH_CHECK(output.dim() >= 2, "output.dim() >= 2")
  // With gather_index the output may hold a different number of rows, so the shape equality
  // checks are only applied without it.
  if (!gather_index.has_value()) {
    TORCH_CHECK(output.sizes() == input.sizes(), "input and output must have the same shape")
  }
  if (output_scale.numel() > 0) {
    if (!gather_index.has_value()) {
      TORCH_CHECK(std::vector<int64_t>(output_scale.sizes().begin(), output_scale.sizes().end()) ==
                      std::vector<int64_t>(input.sizes().begin(), input.sizes().end() - 1),
                  "output_scale_shape must be equal to input_shape[0:-1]")
    }
  }
  if (gather_index_start_position.has_value()) {
    TORCH_CHECK(gather_index.has_value(),
                "gather_index must exist if gather_index_start_position has value")
    TORCH_CHECK(gather_index.value().dim() == 1, "gather_index.dim() == 1")
  }

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  // flatten input
  // flatten() returns a view only when the strides allow it; a changed data pointer means a
  // copy was made, i.e. the leading dims are not densely laid out.
  at::Tensor input_flat = input.flatten(0, input.dim() - 2);
  TORCH_CHECK(input_flat.data_ptr() == input.data_ptr(), "check the input strides")
  // flatten output
  at::Tensor output_flat = output.flatten(0, output.dim() - 2);
  TORCH_CHECK(output_flat.data_ptr() == output.data_ptr(), "check the output stride")
  // flatten output_scale
  at::Tensor output_scale_flat;  // stays undefined when dynamic_quant is false
  if (dynamic_quant) {
    output_scale_flat = output_scale.flatten(0, -1);
  }

  // Convert torch tensor to tensor descriptor
  auto descs = createTensorDescs(
      {input_flat, input_scale, input_zero.value_or(at::Tensor()),
       token_count.value_or(at::Tensor()), gather_index.value_or(at::Tensor()),
       gather_index_start_position.value_or(at::Tensor()), output_flat, output_scale_flat});

  auto input_shape = input.sizes();
  int64_t in_channel = input_shape.back();
  // Product of all leading dims. Accumulate in 64 bits: an int seed / std::multiplies<int>
  // would truncate large row counts.
  int64_t m = std::accumulate(input_shape.begin(), input_shape.end() - 1, int64_t{1},
                              std::multiplies<int64_t>());
  SmoothQuantTheory obj(m, in_channel, getCnnlDataType(input.scalar_type()), dynamic_quant);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4(
      torch_mlu::getCurrentHandle(), nullptr /*smooth_quant_online_desc*/, descs[0].get(),
      getAtTensorPtr(input_flat),                                  /*input*/
      descs[1].get(), getAtTensorPtr(input_scale),                 /*input_scale*/
      descs[2].get(), getAtTensorPtr(input_zero),                  /*input_zero*/
      descs[3].get(), getAtTensorPtr(token_count),                 /*token_count*/
      descs[4].get(), getAtTensorPtr(gather_index),                /*gather_index*/
      descs[5].get(), getAtTensorPtr(gather_index_start_position), /*gather_index_start_position*/
      descs[6].get(), getAtTensorPtr(output_flat),                 /*output*/
      descs[7].get(), getAtTensorPtr(output_scale_flat),           /*output_scale*/
      nullptr, nullptr,                                            /*output_zero*/
      nullptr, 0 /*workspace*/));
  cnpxPop();
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
65
torch_mlu_ops-v1.3.2/csrc/torch_api/swap_blocks.cpp
Normal file
65
torch_mlu_ops-v1.3.2/csrc/torch_api/swap_blocks.cpp
Normal file
@@ -0,0 +1,65 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/swap_blocks.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
/**
 * Copies whole blocks between `src` and `dst` as directed by `block_mapping_dict`.
 *
 * Dimension 0 of both tensors is used as the block index and one block spans
 * `src[0].numel()` elements. Accepted device pairings are MLU->MLU (same device index),
 * MLU->CPU and CPU->MLU; anything else raises.
 *
 * @param dst                 Tensor whose storage receives the copied blocks.
 * @param src                 Tensor the blocks are read from; must match `dst` in dtype and
 *                            in elements per block.
 * @param block_mapping_dict  Pairs of block indices, forwarded as a std::map to
 *                            invokeSwapBlocksKernel.
 *                            NOTE(review): presumably key = block in `src`, value = block in
 *                            `dst` -- confirm against the kernel.
 *
 * Raises c10::Error (via TORCH_CHECK) on any validation failure.
 *
 * NOTE(review): raw data pointers are handed to the kernel, so block i is presumably expected
 * at byte offset i * block_size_in_bytes; the memory layout itself is not verified here.
 */
void swap_blocks(torch::Tensor &dst,
                 const torch::Tensor &src,
                 const c10::Dict<int64_t, int64_t> &block_mapping_dict) {
  // Indexing src[0] / dst[0] below needs at least one block on each side; report that
  // explicitly instead of surfacing a generic indexing error.
  TORCH_CHECK(src.dim() >= 1 && src.size(0) > 0, "src must contain at least one block.");
  TORCH_CHECK(dst.dim() >= 1 && dst.size(0) > 0, "dst must contain at least one block.");

  // check shape
  const int64_t block_numel = src[0].numel();
  TORCH_CHECK(block_numel == dst[0].numel(), "the block_size of src and dst are not the same.");

  // check dtype
  TORCH_CHECK(src.dtype() == dst.dtype(), "the data type of src and dst are not the same.");
  TORCH_CHECK(src.dtype() == torch::kInt8 || src.dtype() == torch::kUInt8 ||
                  src.dtype() == torch::kInt16 || src.dtype() == torch::kInt32 ||
                  src.dtype() == torch::kLong || src.dtype() == torch::kFloat16 ||
                  src.dtype() == torch::kFloat32 || src.dtype() == torch::kBFloat16,
              "data type only supports torch::kInt8, torch::kUInt8, torch::kInt16, torch::kInt32, "
              "torch::kLong, torch::kFloat16, torch::kFloat32 and torch::kBFloat16");

  // check block indices while converting the dict into the std::map the kernel expects.
  // The indices end up in raw pointer arithmetic, so reject values that cannot be valid.
  // NOTE(review): the upper bound uses the larger of the two block counts because the
  // key/value direction is not visible from this file; tighten it to per-tensor bounds once
  // the direction is confirmed against invokeSwapBlocksKernel.
  const int64_t src_blocks = src.size(0);
  const int64_t dst_blocks = dst.size(0);
  const int64_t index_limit = src_blocks > dst_blocks ? src_blocks : dst_blocks;
  std::map<int64_t, int64_t> block_mapping;
  for (const auto &item : block_mapping_dict) {
    const int64_t key_idx = item.key();
    const int64_t value_idx = item.value();
    TORCH_CHECK(key_idx >= 0 && value_idx >= 0,
                "block_mapping contains a negative block index: ", key_idx, " -> ", value_idx);
    TORCH_CHECK(key_idx < index_limit && value_idx < index_limit,
                "block_mapping index out of range: ", key_idx, " -> ", value_idx, " (src has ",
                src_blocks, " blocks, dst has ", dst_blocks, " blocks)");
    block_mapping[key_idx] = value_idx;
  }

  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  // Device whose context must be current for the copy; switched to dst for host->device.
  torch::Device mlu_device = src_device;
  // Initialised so the variable is never read indeterminate; every accepted branch below
  // assigns it and the rejecting branch throws.
  cnrtMemTransDir_t memcpy_type = cnrtMemcpyDevToDev;
  const bool src_is_mlu = isMlu(src);
  const bool dst_is_mlu = isMlu(dst);

  if (src_is_mlu && dst_is_mlu) {
    TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same MLU");
    memcpy_type = cnrtMemcpyDevToDev;
  } else if (src_is_mlu && dst_device.is_cpu()) {
    memcpy_type = cnrtMemcpyDevToHost;
  } else if (src_device.is_cpu() && dst_is_mlu) {
    memcpy_type = cnrtMemcpyHostToDev;
    mlu_device = dst_device;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  const torch_mlu::mlu::MLUGuard device_guard(mlu_device);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  const int64_t block_size_in_bytes = src.element_size() * block_numel;
  TMO_KERNEL_CHECK_FATAL(invokeSwapBlocksKernel(handle, dst.data_ptr(), src.data_ptr(),
                                                block_size_in_bytes, memcpy_type, block_mapping));
}
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
472
torch_mlu_ops-v1.3.2/csrc/torch_api/torch_ops_api.h
Normal file
472
torch_mlu_ops-v1.3.2/csrc/torch_api/torch_ops_api.h
Normal file
@@ -0,0 +1,472 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_TORCH_API_TORCH_OPS_API_H_
|
||||
#define CSRC_TORCH_API_TORCH_OPS_API_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
#include "op_theory.h"
|
||||
#include "ops/kernel_api.h"
|
||||
#include "torch/extension.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
void fused_layernorm(const at::Tensor &input,
|
||||
const at::Tensor &out,
|
||||
const c10::optional<at::Tensor> &residual,
|
||||
const c10::optional<at::Tensor> &gamma,
|
||||
const c10::optional<at::Tensor> &beta,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const c10::optional<at::Tensor> &quant_scale,
|
||||
const c10::optional<at::Tensor> &residual_out,
|
||||
const c10::optional<at::Tensor> &smooth_quant_scale,
|
||||
const std::string &norm_mode,
|
||||
double eps,
|
||||
bool store_output_before_norm,
|
||||
bool dynamic_quant);
|
||||
|
||||
std::vector<at::Tensor> attention_project(const at::Tensor &input,
|
||||
const at::Tensor &q_weight,
|
||||
const c10::optional<at::Tensor> &q_bias,
|
||||
const c10::optional<at::Tensor> &k_weight,
|
||||
const c10::optional<at::Tensor> &k_bias,
|
||||
const c10::optional<at::Tensor> &v_weight,
|
||||
const c10::optional<at::Tensor> &v_bias,
|
||||
const c10::optional<at::Tensor> &norm_weight,
|
||||
const c10::optional<at::Tensor> &norm_bias,
|
||||
const c10::optional<at::Tensor> &residual,
|
||||
const std::string &out_layout,
|
||||
int64_t head_size,
|
||||
double eps,
|
||||
double alpha,
|
||||
double beta,
|
||||
bool norm_out);
|
||||
|
||||
at::Tensor ffn(const at::Tensor &input,
|
||||
const at::Tensor &up_fc_weight,
|
||||
const c10::optional<at::Tensor> &up_fc_bias,
|
||||
const at::Tensor &down_proj_weight,
|
||||
const c10::optional<at::Tensor> &down_proj_bias,
|
||||
const c10::optional<at::Tensor> &gate_up_proj_weight,
|
||||
const c10::optional<at::Tensor> &gate_up_proj_bias,
|
||||
const c10::optional<at::Tensor> &layernorm_weight,
|
||||
const c10::optional<at::Tensor> &layernorm_bias,
|
||||
const std::string &act_mode,
|
||||
const std::string &residual_is,
|
||||
double eps,
|
||||
double alpha,
|
||||
double beta);
|
||||
|
||||
void flash_attention(const at::Tensor &q,
|
||||
const at::Tensor &k,
|
||||
const at::Tensor &v,
|
||||
const at::Tensor &out,
|
||||
const c10::optional<at::Tensor> &output_lse,
|
||||
const c10::optional<at::Tensor> &cu_seq_lens_q,
|
||||
const c10::optional<at::Tensor> &cu_seq_lens_kv,
|
||||
const c10::optional<at::Tensor> &alibi_slope,
|
||||
const c10::optional<at::Tensor> &attn_bias,
|
||||
const c10::optional<at::Tensor> &k_cache_quant_scale,
|
||||
const c10::optional<at::Tensor> &v_cache_quant_scale,
|
||||
const c10::optional<at::Tensor> &block_tables,
|
||||
const int64_t max_seq_len_q,
|
||||
const int64_t max_seq_len_kv,
|
||||
const double softmax_scale,
|
||||
const bool is_causal,
|
||||
const int64_t window_size_left,
|
||||
const int64_t window_size_right,
|
||||
const std::string &compute_dtype,
|
||||
bool return_lse);
|
||||
|
||||
// Single-query attention over a cached K/V store; results are written into `output` (and
// `output_lse` when supplied) rather than returned.
// The torch_mlu_ops::single_query_cached_kv_attn schema in torch_register_function.cpp marks
// `output` as Tensor(a!) and the LSE tensor as Tensor(b!)?.
// NOTE(review): that schema names two arguments differently from this prototype --
// `out_lse` (here `output_lse`) and `max_contxt_len` (here `max_context_len`, schema spelling
// looks like a typo). Keyword callers must use the schema names; renaming them is an
// interface change.
void single_query_cached_kv_attn(
|
||||
const torch::Tensor &q_ori,
|
||||
const torch::Tensor &k_cache,
|
||||
const torch::Tensor &v_cache,
|
||||
const torch::Tensor &output,
|
||||
const torch::Tensor &block_tables,
|
||||
const torch::Tensor &context_lens, // [batch]
|
||||
const c10::optional<torch::Tensor> &output_lse,
|
||||
const c10::optional<torch::Tensor> &k_cache_quant_scale,
|
||||
const c10::optional<torch::Tensor> &v_cache_quant_scale,
|
||||
const c10::optional<torch::Tensor> &alibi_slopes, // [bs, head_num] or [head_num]
|
||||
int64_t max_context_len,
|
||||
int64_t windows_size_left,
|
||||
int64_t windows_size_right,
|
||||
double softmax_scale,
|
||||
bool return_lse,
|
||||
int64_t kv_cache_quant_bit_size);
|
||||
|
||||
void reshape_linear_cache(const at::Tensor &key,
|
||||
const c10::optional<at::Tensor> &value,
|
||||
at::Tensor &key_cache,
|
||||
const c10::optional<at::Tensor> &value_cache,
|
||||
const at::Tensor &context_lengths,
|
||||
const int64_t max_context_len,
|
||||
bool packed,
|
||||
const c10::optional<at::Tensor> &context_seq_offset,
|
||||
const c10::optional<at::Tensor> &cache_bs_id,
|
||||
const c10::optional<at::Tensor> &cache_seqlen_offset);
|
||||
|
||||
void reshape_paged_cache(const torch::Tensor &k,
|
||||
const c10::optional<torch::Tensor> &v,
|
||||
torch::Tensor &k_cache,
|
||||
const c10::optional<torch::Tensor> &v_cache,
|
||||
const torch::Tensor &slot_mapping);
|
||||
|
||||
void quant_to_paged_cache(const torch::Tensor &k,
|
||||
const c10::optional<torch::Tensor> &v,
|
||||
torch::Tensor &k_cache,
|
||||
const c10::optional<torch::Tensor> &v_cache,
|
||||
torch::Tensor &k_cache_scale,
|
||||
const c10::optional<torch::Tensor> &v_cache_scale,
|
||||
const torch::Tensor &slot_mapping);
|
||||
|
||||
void offline_quant_to_paged_cache(const torch::Tensor &k,
|
||||
const c10::optional<torch::Tensor> &v,
|
||||
const torch::Tensor &k_cache_scale,
|
||||
const c10::optional<torch::Tensor> &v_cache_scale,
|
||||
const torch::Tensor &slot_mapping,
|
||||
torch::Tensor &k_cache,
|
||||
const c10::optional<torch::Tensor> &v_cache);
|
||||
|
||||
void quant_to_linear_cache(const at::Tensor &key,
|
||||
const c10::optional<at::Tensor> &value,
|
||||
at::Tensor &key_cache,
|
||||
const c10::optional<at::Tensor> &value_cache,
|
||||
at::Tensor &key_cache_scale,
|
||||
const c10::optional<at::Tensor> &value_cache_scale,
|
||||
const at::Tensor &context_lengths,
|
||||
const int64_t max_context_len,
|
||||
bool packed,
|
||||
const c10::optional<at::Tensor> &context_seq_offset,
|
||||
const c10::optional<at::Tensor> &cache_bs_id,
|
||||
const c10::optional<at::Tensor> &cache_seqlen_offset,
|
||||
const int64_t quant_bit);
|
||||
|
||||
// Writes `key` / `value` into `key_cache` / `value_cache` using caller-supplied scales.
// The torch_mlu_ops::offline_quant_to_linear_cache schema in torch_register_function.cpp
// marks key_cache, value_cache, key_cache_scale and value_cache_scale as mutable
// (Tensor(a!) .. Tensor(d!)?).
// NOTE(review): `key_cache_scale` is taken by const reference here (unlike
// quant_to_linear_cache, which takes it by non-const reference), yet the schema annotates it
// as written in place. Presumably the scales are read-only inputs for the offline variant --
// confirm against the implementation and align the schema annotation or this signature.
void offline_quant_to_linear_cache(const at::Tensor &key,
|
||||
const c10::optional<at::Tensor> &value,
|
||||
at::Tensor &key_cache,
|
||||
const c10::optional<at::Tensor> &value_cache,
|
||||
const at::Tensor &key_cache_scale,
|
||||
const c10::optional<at::Tensor> &value_cache_scale,
|
||||
const at::Tensor &context_lengths,
|
||||
const int64_t max_context_len,
|
||||
const int64_t quant_mode,
|
||||
const bool packed,
|
||||
const c10::optional<at::Tensor> &context_seq_offset,
|
||||
const c10::optional<at::Tensor> &cache_bs_id,
|
||||
const c10::optional<at::Tensor> &cache_seqlen_offset);
|
||||
|
||||
void apply_rotary(const torch::Tensor &input,
|
||||
const torch::Tensor &sin_cache,
|
||||
const torch::Tensor &cos_cache,
|
||||
const c10::optional<torch::Tensor> &position_ids,
|
||||
const c10::optional<torch::Tensor> &cu_seqlens,
|
||||
bool interleaved,
|
||||
bool discrete,
|
||||
bool dynamic_ntk,
|
||||
int64_t max_seqlen);
|
||||
|
||||
// Copies blocks between `src` and `dst`, indexed along dimension 0 (defined in
// swap_blocks.cpp). The definition requires both tensors to share dtype and per-block
// element count, and accepts MLU->MLU (same device index), MLU->CPU and CPU->MLU pairings
// only; other combinations raise via TORCH_CHECK.
// `block_mapping` is called `block_mapping_dict` in the definition, where it is copied into a
// std::map and forwarded to invokeSwapBlocksKernel.
void swap_blocks(torch::Tensor &dst,
|
||||
const torch::Tensor &src,
|
||||
const c10::Dict<int64_t, int64_t> &block_mapping);
|
||||
|
||||
void copy_blocks(const std::vector<torch::Tensor> &k_caches,
|
||||
const std::vector<torch::Tensor> &v_caches,
|
||||
const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping);
|
||||
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> copy_blocks_out_of_place(
|
||||
const std::vector<torch::Tensor> &k_caches,
|
||||
const std::vector<torch::Tensor> &v_caches,
|
||||
const c10::Dict<int64_t, c10::List<int64_t>> &block_mapping);
|
||||
|
||||
// Fused mixture-of-experts entry point; returns the result tensor.
// NOTE(review): `block_n` and `cncl_comm` default to 0 for C++ callers only. The
// torch_mlu_ops::fused_moe schema in torch_register_function.cpp declares them as plain
// `int block_n, int cncl_comm` without defaults, so callers going through the registered op
// must pass both explicitly.
at::Tensor fused_moe(const at::Tensor &hidden_states,
|
||||
const at::Tensor &gating_output,
|
||||
const at::Tensor &w1,
|
||||
const at::Tensor &w2,
|
||||
const c10::optional<at::Tensor> &bias1,
|
||||
const c10::optional<at::Tensor> &bias2,
|
||||
const c10::optional<at::Tensor> &residual,
|
||||
const c10::optional<at::Tensor> &input_smooth,
|
||||
const c10::optional<at::Tensor> &act_smooth,
|
||||
const c10::optional<at::Tensor> &w1_scale,
|
||||
const c10::optional<at::Tensor> &w2_scale,
|
||||
const c10::optional<at::List<int64_t>> &w1_quant_flag,
|
||||
const c10::optional<at::List<int64_t>> &w2_quant_flag,
|
||||
const int64_t topk,
|
||||
const bool renormalize,
|
||||
const bool gated,
|
||||
const std::string &act_mode,
|
||||
const int64_t start_expert_id,
|
||||
const int64_t block_n = 0,
|
||||
const int64_t cncl_comm = 0);
|
||||
|
||||
// d = (a * b + bias) * alpha + c * beta
|
||||
// d = (((a * a_scale * b * b_scale) * gemm_output_scale + bias) * alpha + c * beta) * d_scale
|
||||
at::Tensor quant_matmul(
|
||||
const at::Tensor &a_tensor, // input
|
||||
const c10::optional<at::Tensor> &a_scale, // input scale, smooth quant
|
||||
const c10::optional<at::Tensor> &a_zero, // input zero
|
||||
const at::Tensor &b_tensor, // weight
|
||||
const c10::optional<at::Tensor> &b_scale, // weight scale, smooth quant
|
||||
const c10::optional<at::Tensor> &b_zero, // weight zero
|
||||
const c10::optional<at::Tensor> &bias, // bias
|
||||
const c10::optional<at::Tensor> &c_tensor, // residual
|
||||
const c10::optional<at::Tensor> &c_scale, // residual scale
|
||||
const c10::optional<at::Tensor> &c_zero, // residual zero
|
||||
const c10::optional<at::Tensor> &gemm_output_scale, // for dequant, weight only
|
||||
const c10::optional<at::Tensor> &gemm_output_zero, // for dequant, preserve
|
||||
const c10::optional<std::string> &data_type, // output data type
|
||||
const c10::optional<at::Tensor> &d, // output
|
||||
const std::string &quant_algo, // quant type
|
||||
const std::string &a_quant_layout, // input quant mode
|
||||
const std::string &b_quant_layout, // weight quant mode
|
||||
int64_t quant_bit_size = 8, // weight quant bit size
|
||||
const std::string &act_mode = "none", // act mode
|
||||
bool use_hp_active = false, // for active, whether uses high precision mode
|
||||
double act_coef = 1.f, // active_coef
|
||||
double alpha = 1.f, // for quant matmul
|
||||
double beta = 1.f, // for residual
|
||||
bool trans_a = false, // whether transpose input
|
||||
bool trans_b = true); // whether transpose weight
|
||||
|
||||
at::Tensor quant_matmul_allreduce(const int64_t cncl_comm,
|
||||
const at::Tensor &a_tensor,
|
||||
const c10::optional<at::Tensor> &a_scale,
|
||||
const c10::optional<at::Tensor> &a_zero,
|
||||
const at::Tensor &b_tensor,
|
||||
const c10::optional<at::Tensor> &b_scale,
|
||||
const c10::optional<at::Tensor> &b_zero,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const c10::optional<at::Tensor> &c_tensor,
|
||||
const c10::optional<at::Tensor> &c_scale,
|
||||
const c10::optional<at::Tensor> &c_zero,
|
||||
const c10::optional<at::Tensor> &gemm_output_scale,
|
||||
const c10::optional<at::Tensor> &gemm_output_zero,
|
||||
const c10::optional<std::string> &data_type,
|
||||
const c10::optional<at::Tensor> &d,
|
||||
const std::string &quant_algo,
|
||||
const std::string &a_quant_layout,
|
||||
const std::string &b_quant_layout,
|
||||
int64_t quant_bit_size,
|
||||
double alpha,
|
||||
double beta,
|
||||
bool trans_a,
|
||||
bool trans_b,
|
||||
const int64_t block_m);
|
||||
|
||||
// Applies the activation selected by `act_mode` to `input`; the registered schema marks
// `output` as Tensor(a!), i.e. it is written in place (defined in active.cpp).
//   act_mode    must be "silu", "gelu", "quick_gelu" or "swish" (TORCH_CHECK in the
//               definition). "gelu" selects CNNL_ACTIVATION_GELU, every other accepted value
//               selects CNNL_ACTIVATION_SWISH.
//   active_coef is overridden inside the definition: 1.702 for "quick_gelu" and 1.0 for
//               "silu", so the caller's value is ignored for those two modes.
//   input       must have at least 2 dimensions (TORCH_CHECK in the definition).
void active(const torch::Tensor &input,
|
||||
const torch::Tensor &output,
|
||||
const c10::optional<torch::Tensor> &bias,
|
||||
const c10::optional<torch::Tensor> &cusum_token_count,
|
||||
const std::string &act_mode,
|
||||
bool is_gated,
|
||||
int64_t start_expert_id,
|
||||
int64_t expert_size,
|
||||
double active_coef = 1.0);
|
||||
|
||||
// Smooth-quant op. The implementation forwards input, input_scale, input_zero, token_count,
// gather_index and gather_index_start_position to cnnlSmoothQuantOnline_v4 and receives the
// results in `output` and `output_scale` (both marked mutable, Tensor(a!) / Tensor(b!), in
// the registered schema). No output zero-point is produced: nullptr is passed for it, and no
// workspace is supplied.
// NOTE(review): `dynamic_quant` is a bool passed by const reference; by-value is the usual
// form, but changing it must be done together with the definition.
void smooth_quant(const at::Tensor &input,
|
||||
const at::Tensor &input_scale,
|
||||
const at::Tensor &output,
|
||||
const at::Tensor &output_scale,
|
||||
const c10::optional<at::Tensor> &input_zero,
|
||||
const c10::optional<at::Tensor> &token_count,
|
||||
const c10::optional<at::Tensor> &gather_index,
|
||||
const c10::optional<at::Tensor> &gather_index_start_position,
|
||||
const std::string &quant_mode,
|
||||
const bool &dynamic_quant);
|
||||
|
||||
at::Tensor matmul(const at::Tensor &a,
|
||||
const at::Tensor &b,
|
||||
const c10::optional<at::Tensor> &d,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const c10::optional<at::Tensor> &c,
|
||||
const c10::optional<std::string> &dtype,
|
||||
const std::string &act_mode,
|
||||
double alpha,
|
||||
double beta,
|
||||
bool fast_act,
|
||||
bool approximate,
|
||||
double a_scale,
|
||||
double b_scale,
|
||||
bool trans_a,
|
||||
bool trans_b);
|
||||
|
||||
at::Tensor matmul_allreduce(const int64_t cncl_comm,
|
||||
const at::Tensor &a,
|
||||
const at::Tensor &b,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const c10::optional<at::Tensor> &c,
|
||||
const c10::optional<at::Tensor> &d,
|
||||
const double alpha,
|
||||
const double beta,
|
||||
const int64_t block_m);
|
||||
|
||||
at::Tensor group_gemm(const at::Tensor &a_tensor,
|
||||
const at::Tensor &b_tensor,
|
||||
const at::Tensor &m_list,
|
||||
const c10::optional<at::Tensor> &gather_idx,
|
||||
const c10::optional<at::Tensor> &c_tensor,
|
||||
const c10::optional<at::Tensor> &alpha,
|
||||
const c10::optional<at::Tensor> &beta,
|
||||
const c10::optional<at::Tensor> &a_scale,
|
||||
const c10::optional<at::Tensor> &b_scale,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const c10::optional<std::string> &data_type,
|
||||
const c10::optional<at::List<int64_t>> &quant_flag,
|
||||
const c10::optional<at::Tensor> &b_offset,
|
||||
const int64_t max_m /*max m in m_list*/);
|
||||
|
||||
// NOTE(review): the torch_mlu_ops::preload schema in torch_register_function.cpp declares
// `weight` as Tensor(a!) (mutable alias) while this prototype takes it by const reference;
// confirm which one reflects the implementation.
void preload(const torch::Tensor &weight, const int64_t size);
|
||||
|
||||
at::Tensor group_gemm_combine_result_allreduce(const int64_t cncl_comm,
|
||||
const at::Tensor &a_tensor,
|
||||
const at::Tensor &b_tensor,
|
||||
const at::Tensor &m_list,
|
||||
const at::Tensor &combine_idx,
|
||||
const at::Tensor &combine_weight,
|
||||
const c10::optional<at::Tensor> &c_tensor,
|
||||
const c10::optional<at::Tensor> &alpha,
|
||||
const c10::optional<at::Tensor> &beta,
|
||||
const c10::optional<at::Tensor> &a_scale,
|
||||
const c10::optional<at::Tensor> &b_scale,
|
||||
const c10::optional<std::string> &data_type,
|
||||
const int64_t num_token,
|
||||
const int64_t topk,
|
||||
const int64_t block_n);
|
||||
|
||||
at::Tensor flash_attn_sq_mm_allreduce(const int64_t cncl_comm,
|
||||
const at::Tensor &q,
|
||||
const at::Tensor &k,
|
||||
const at::Tensor &v,
|
||||
const c10::optional<at::Tensor> &cu_seq_lens_q,
|
||||
const c10::optional<at::Tensor> &cu_seq_lens_kv,
|
||||
const c10::optional<at::Tensor> &alibi_slope,
|
||||
const c10::optional<at::Tensor> &attn_bias,
|
||||
const at::Tensor &smooth,
|
||||
const at::Tensor &weight,
|
||||
const at::Tensor &weight_scale,
|
||||
const c10::optional<at::Tensor> &bias,
|
||||
const int64_t max_seq_len_q,
|
||||
const int64_t max_seq_len_kv,
|
||||
const double softmax_scale,
|
||||
const bool is_causal,
|
||||
const int64_t window_size_left,
|
||||
const int64_t window_size_right,
|
||||
const std::string &compute_dtype,
|
||||
const int64_t block_seq);
|
||||
|
||||
std::vector<at::Tensor> moe_softmax_topk(const at::Tensor &input,
|
||||
int64_t topk,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
bool normalize,
|
||||
const c10::optional<at::Tensor> &mask,
|
||||
const std::string &normed_by);
|
||||
|
||||
at::Tensor moe_expand_input(const torch::Tensor &input,
|
||||
const torch::Tensor &gather_idx,
|
||||
const c10::optional<torch::Tensor> &cusum_token_count,
|
||||
int64_t start_expert_id,
|
||||
int64_t expert_size);
|
||||
|
||||
std::vector<at::Tensor> moe_gen_idx(const torch::Tensor &expert_id, int64_t expert_num);
|
||||
|
||||
at::Tensor moe_combine_result(const at::Tensor &input,
|
||||
const at::Tensor &reduce_weight,
|
||||
const at::Tensor &gather_ids,
|
||||
const c10::optional<at::Tensor> &residual,
|
||||
const c10::optional<at::Tensor> &cusum_token_count,
|
||||
int64_t start_expert_id,
|
||||
int64_t expert_size,
|
||||
const c10::optional<at::Tensor> &bias);
|
||||
|
||||
void fused_rope(at::Tensor &qkv,
|
||||
at::Tensor &key_cache_hp,
|
||||
at::Tensor &value_cache_hp,
|
||||
const c10::optional<at::Tensor> &key_cache_lp,
|
||||
const c10::optional<at::Tensor> &value_cache_lp,
|
||||
const at::Tensor &sin_table,
|
||||
const at::Tensor &cos_table,
|
||||
const at::Tensor &position_ids,
|
||||
const at::Tensor &gamma,
|
||||
const at::Tensor &beta,
|
||||
const c10::optional<at::Tensor> &key_scale_hp,
|
||||
const c10::optional<at::Tensor> &value_scale_hp,
|
||||
const c10::optional<at::Tensor> &key_scale_lp,
|
||||
const c10::optional<at::Tensor> &value_scale_lp,
|
||||
const c10::optional<at::Tensor> &cache_bs_id_hp,
|
||||
const c10::optional<at::Tensor> &cache_seq_offsets_hp,
|
||||
const c10::optional<at::Tensor> &cache_bs_id_lp,
|
||||
const c10::optional<at::Tensor> &cache_seq_offsets_lp,
|
||||
const c10::optional<at::Tensor> &slot_mapping_hp,
|
||||
const c10::optional<at::Tensor> &slot_mapping_lp,
|
||||
const double eps);
|
||||
|
||||
void batch_matmul(const at::Tensor &a,
|
||||
const at::Tensor &b,
|
||||
const at::Tensor &c,
|
||||
double alpha,
|
||||
double beta,
|
||||
double a_scale,
|
||||
double b_scale,
|
||||
bool trans_a,
|
||||
bool trans_b);
|
||||
|
||||
at::Tensor moe_cast_gating(const at::Tensor &input, const at::Tensor &weight);
|
||||
|
||||
void update_out_and_lse(at::Tensor &out,
|
||||
at::Tensor &lse,
|
||||
const at::Tensor &block_out,
|
||||
const at::Tensor &block_lse,
|
||||
const c10::optional<at::Tensor> &seq_offsets,
|
||||
const c10::optional<at::Tensor> &cu_seqs,
|
||||
const c10::optional<at::Tensor> &block_cu_seqs);
|
||||
|
||||
// Called with the OpTheory describing an op immediately before its kernel is launched (e.g.
// around cnnlSmoothQuantOnline_v4 in smooth_quant); every call is paired with cnpxPop().
// Presumably opens a CNPX profiler range for the op -- confirm in the implementation.
void cnpxPush(const OpTheory &op);
|
||||
|
||||
// Counterpart of cnpxPush(); called right after the kernel launch that followed the push.
// Presumably closes the most recently opened CNPX profiler range -- confirm in the
// implementation.
void cnpxPop(void);
|
||||
|
||||
void dequant_from_linear_cache(at::Tensor &key,
|
||||
const c10::optional<at::Tensor> &value,
|
||||
const at::Tensor &key_cache,
|
||||
const c10::optional<at::Tensor> &value_cache,
|
||||
const at::Tensor &key_quant_scale,
|
||||
const c10::optional<at::Tensor> &value_quant_scale,
|
||||
const at::Tensor &context_lengths,
|
||||
const int64_t max_context_len,
|
||||
const c10::optional<at::Tensor> &context_seq_offset,
|
||||
const c10::optional<at::Tensor> &cache_bs_id,
|
||||
const c10::optional<at::Tensor> &cache_seq_offset,
|
||||
const int64_t quant_mode,
|
||||
const int64_t quant_bit);
|
||||
|
||||
void dequant_from_paged_cache(at::Tensor &key,
|
||||
const c10::optional<at::Tensor> &value,
|
||||
const at::Tensor &key_cache,
|
||||
const c10::optional<at::Tensor> &value_cache,
|
||||
const at::Tensor &key_cache_quant_scale,
|
||||
const c10::optional<at::Tensor> &value_cache_quant_scale,
|
||||
const at::Tensor &context_lengths,
|
||||
int64_t max_context_len,
|
||||
const c10::optional<at::Tensor> &context_seq_offset,
|
||||
const at::Tensor &block_tables,
|
||||
int64_t quant_mode,
|
||||
int64_t quant_bit);
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_TORCH_API_TORCH_OPS_API_H_
|
||||
263
torch_mlu_ops-v1.3.2/csrc/torch_api/torch_register_function.cpp
Normal file
263
torch_mlu_ops-v1.3.2/csrc/torch_api/torch_register_function.cpp
Normal file
@@ -0,0 +1,263 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include <torch/library.h>
|
||||
#include "glue_ops.h"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(torch_mlu_ops, m) {
|
||||
// torch 2.1.0 does not support impl_abstract_pystub, enable it when torch 2.3.0 is released
|
||||
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3
|
||||
m.impl_abstract_pystub("torch_mlu_ops.abstract");
|
||||
#endif
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::attention_project(Tensor input, Tensor q_weight, Tensor? q_bias, Tensor? "
|
||||
"k_weight, "
|
||||
"Tensor? k_bias, Tensor? v_weight, Tensor? v_bias, Tensor? norm_weight, Tensor? norm_bias, "
|
||||
"Tensor? residual, str out_layout, int head_size, float eps, float alpha, float beta, bool "
|
||||
"norm_out) -> (Tensor[])"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::ffn(Tensor input, Tensor up_fc_weight, Tensor? up_fc_bias, Tensor "
|
||||
"down_proj_weight, "
|
||||
"Tensor? down_proj_bias, Tensor? gate_up_proj_weight, Tensor? gate_up_proj_bias, Tensor? "
|
||||
"layernorm_weight, Tensor? layernorm_bias, str act_mode, str residual_is, float eps, float "
|
||||
"alpha, float beta) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::flash_attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out, Tensor(b!)? "
|
||||
"out_lse, Tensor? cu_seq_lens_q, Tensor? cu_seq_lens_kv, "
|
||||
"Tensor? alibi_slope, Tensor? attn_bias, Tensor? k_quant_scale, Tensor? v_quant_scale, "
|
||||
"Tensor? block_tables, int max_seq_len_q, int "
|
||||
"max_seq_len_kv, float softmax_scale, bool is_causal, int window_size_left, int "
|
||||
"window_size_right, str compute_dtype, bool return_lse) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::single_query_cached_kv_attn(Tensor q_ori, Tensor k_cache, Tensor "
|
||||
"v_cache, Tensor(a!) output, Tensor block_tables, Tensor context_lens, "
|
||||
"Tensor(b!)? out_lse, Tensor? k_cache_quant_scale, Tensor? v_cache_quant_scale, "
|
||||
"Tensor? alibi_slopes, int max_contxt_len, int windows_size_left, int "
|
||||
"windows_size_right, float softmax_scale, bool return_lse, int kv_cache_quant_bit_size) -> "
|
||||
"()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::apply_rotary(Tensor(a!) input, Tensor sin_cache, Tensor cos_cache, "
|
||||
"Tensor? position_ids, Tensor? cu_seqlens, bool interleaved, bool "
|
||||
"discrete, bool dynamic_ntk, int max_seqlen) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::reshape_linear_cache(Tensor key, Tensor? value, Tensor(a!) key_cache, "
|
||||
"Tensor(b!)? "
|
||||
"value_cache, Tensor context_lengths, int max_context_len, bool packed, Tensor? "
|
||||
"context_seq_offset, Tensor? cache_bs_id, Tensor? cache_seqlen_offset) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::reshape_paged_cache(Tensor k, Tensor? v, Tensor(a!) k_cache, "
|
||||
"Tensor(b!)? v_cache, Tensor slot_mapping) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::quant_to_paged_cache(Tensor k, Tensor? v, Tensor(a!) k_cache, Tensor(b!)? "
|
||||
"v_cache, "
|
||||
"Tensor(c!) k_cache_scale, Tensor(d!)? v_cache_scale, Tensor slot_mapping) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::offline_quant_to_paged_cache(Tensor k, Tensor? v, Tensor k_cache_scale,"
|
||||
"Tensor? v_cache_scale, Tensor slot_mapping, Tensor(a!) k_cache, Tensor(b!)? v_cache)-> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::quant_to_linear_cache(Tensor key, Tensor? value, Tensor(a!) key_cache, "
|
||||
"Tensor(b!)? "
|
||||
"value_cache, Tensor(c!) key_cache_scale, Tensor(d!)? value_cache_scale, Tensor "
|
||||
"context_lengths, "
|
||||
"int max_context_len, bool packed, Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? "
|
||||
"cache_seqlen_offset, int quant_bit=8) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::offline_quant_to_linear_cache(Tensor key, Tensor? value, Tensor(a!) "
|
||||
"key_cache,"
|
||||
" Tensor(b!)? value_cache, Tensor(c!) key_cache_scale, Tensor(d!)? value_cache_scale, "
|
||||
"Tensor context_lengths, int max_context_len, int quant_mode, bool packed, "
|
||||
"Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? "
|
||||
"cache_seqlen_offset) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::quant_matmul(Tensor a_tensor, Tensor? a_scale, Tensor? a_zero, Tensor "
|
||||
"b_tensor, "
|
||||
"Tensor? b_scale, Tensor? b_zero, Tensor? bias, Tensor? c_tensor, Tensor? c_scale, Tensor? "
|
||||
"c_zero, Tensor? gemm_output_scale, Tensor? gemm_output_zero, str? data_type, Tensor? d, "
|
||||
"str "
|
||||
"quant_algo, str a_quant_layout, str b_quant_layout, int quant_bit_size=8, str "
|
||||
"act_mode='none', bool use_hp_active=False, float act_coef=1.0, float alpha=1.0, float "
|
||||
"beta=1.0, bool trans_a=False, bool trans_b=True) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::quant_matmul_allreduce(int cncl_comm, Tensor a_tensor, Tensor? a_scale, "
|
||||
"Tensor? a_zero, Tensor b_tensor, Tensor? b_scale, Tensor? b_zero, "
|
||||
"Tensor? bias, Tensor? c_tensor, Tensor? c_scale, Tensor? "
|
||||
"c_zero, Tensor? gemm_output_scale, Tensor? gemm_output_zero, str? data_type, Tensor? d, "
|
||||
"str quant_algo, str a_quant_layout, str b_quant_layout, int quant_bit_size=8, "
|
||||
"float alpha=1.0, float beta=1.0, bool trans_a=False, bool trans_b=True, int block_m=0) -> "
|
||||
"Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::active(Tensor input, Tensor(a!) output, Tensor? bias, "
|
||||
"Tensor? cusum_token_count, str act_mode, bool is_gated, int start_expert_id, "
|
||||
"int expert_size, float active_coef=1.0) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::smooth_quant(Tensor input, Tensor input_scale, Tensor(a!) output, "
|
||||
"Tensor(b!) output_scale, "
|
||||
"Tensor? input_zero, Tensor? token_count, Tensor? gather_index, "
|
||||
"Tensor? gather_index_start_position, str quant_mode, bool dynamic_quant) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::fused_layernorm(Tensor input, Tensor(a!) out, "
|
||||
"Tensor? residual, Tensor? gamma, Tensor? beta, Tensor? bias, Tensor? "
|
||||
"quant_scale, Tensor(b!)? residual_out, Tensor? smooth_quant_scale, str norm_mode, "
|
||||
"float eps, bool store_output_before_norm, bool dynamic_quant) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::fused_moe(Tensor hidden_states, Tensor gating_output, Tensor "
|
||||
"w1, Tensor w2, Tensor? bias1, Tensor? bias2, Tensor? residual, Tensor? "
|
||||
"input_smooth, Tensor? act_smooth, Tensor? w1_scale, Tensor? w2_scale, "
|
||||
"int[]? w1_quant_flag, int[]? w2_quant_flag, "
|
||||
"int topk, bool renormalize, bool gated, str act_mode, int start_expert_id, "
|
||||
"int block_n, int cncl_comm) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::matmul(Tensor a, Tensor b, Tensor? d, Tensor? bias, Tensor? c, str? dtype,"
|
||||
" str act_mode, float alpha, float beta, bool fast_act, bool approximate, float a_scale, "
|
||||
"float "
|
||||
"b_scale, bool trans_a, bool trans_b) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::batch_matmul(Tensor a, Tensor b, Tensor(a!) c,"
|
||||
"float alpha, float beta, float a_scale, float b_scale, bool trans_a, bool trans_b) -> ()"));
|
||||
m.def(
|
||||
TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::matmul_allreduce(int cncl_comm, Tensor a, Tensor b, "
|
||||
"Tensor? bias, Tensor? c, Tensor? d, "
|
||||
"float alpha, float beta, int block_m) -> Tensor"));
|
||||
m.def(
|
||||
TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::swap_blocks(Tensor(a!) dst, Tensor src, Dict(int, "
|
||||
"int) block_mapping) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::copy_blocks(Tensor(a!)[] k_caches, Tensor(b!)[] v_caches, "
|
||||
"Dict(int, int[]) block_mapping) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::copy_blocks_out_of_place(Tensor[] k_caches, Tensor[] v_caches, "
|
||||
"Dict(int, int[]) block_mapping) -> (Tensor[], Tensor[])"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::group_gemm(Tensor a, Tensor b, Tensor m_list, Tensor? idx, Tensor? c, "
|
||||
"Tensor? alpha, Tensor? beta, Tensor? a_scale, Tensor? b_scale, Tensor? bias,"
|
||||
"str? data_type, int[]? quant_flag, Tensor? b_offset, int max_m) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::preload(Tensor(a!) weight, int size) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::group_gemm_combine_result_allreduce(int cncl_comm, Tensor a, Tensor b, "
|
||||
"Tensor m_list, Tensor combine_idx, Tensor combine_weight, Tensor? c, "
|
||||
"Tensor? alpha, Tensor? beta, Tensor? a_scale, Tensor? b_scale, "
|
||||
"str? data_type, int num_token, int topk, int block_n) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::flash_attn_sq_mm_allreduce(int cncl_comm, Tensor q, Tensor k, "
|
||||
"Tensor v, Tensor? cu_seq_lens_q, Tensor? cu_seq_lens_k, Tensor? alibi_slope, "
|
||||
"Tensor? attn_bias, Tensor smooth, Tensor weight, Tensor weight_scale, "
|
||||
"Tensor? bias, int max_seq_len_q, int max_seq_len_kv, float softmax_scale, bool is_causal, "
|
||||
"int window_size_left, int window_size_right, "
|
||||
"str compute_dtype, int block_seq) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::moe_softmax_topk(Tensor input, int topk, int num_expert_group, "
|
||||
"int topk_group, bool normalize, Tensor? mask, str normed_by) -> Tensor[]"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::moe_expand_input(Tensor input, Tensor gather_idx, Tensor ? "
|
||||
"cusum_token_count, int start_expert_id, int expert_size) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::moe_gen_idx(Tensor expert_id, int expert_num) -> Tensor[]"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::moe_combine_result(Tensor input, Tensor reduce_weight,"
|
||||
"Tensor gather_ids, Tensor? residual,"
|
||||
"Tensor? cusum_token_count, int start_expert_id, int expert_size, Tensor? bias) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::fused_rope(Tensor(a!) input, Tensor(b!) k_cache_hp, Tensor(c!) v_cache_hp, "
|
||||
"Tensor(d!)? k_cache_lp, Tensor(e!)? v_cache_lp, Tensor sin_table, Tensor cos_table, "
|
||||
"Tensor position_ids, Tensor gamma, Tensor beta, Tensor? k_scale_hp, Tensor? v_scale_hp, "
|
||||
"Tensor(f!)? k_scale_lp, Tensor(g!)? v_scale_lp, Tensor? cache_bs_id_hp, "
|
||||
"Tensor? cache_seq_offsets_hp, Tensor? cache_bs_id_lp, Tensor? cache_seq_offsets_lp, "
|
||||
"Tensor? slot_mapping_hp, Tensor? slot_mapping_lp, float eps) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::moe_cast_gating(Tensor input, Tensor weight) -> Tensor"));
|
||||
m.def(
|
||||
TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::update_out_and_lse(Tensor(a!) out, Tensor(b!) lse, "
|
||||
"Tensor block_out, Tensor block_lse,"
|
||||
"Tensor? seq_offsets, Tensor? cu_seqs, Tensor? block_cu_seqs) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::dequant_from_linear_cache(Tensor(a!) key, Tensor(b!)? value, "
|
||||
"Tensor key_cache, Tensor? value_cache, Tensor key_cache_quant_scale, "
|
||||
"Tensor? value_cache_quant_scale, Tensor context_lengths, int max_context_len, "
|
||||
"Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? cache_seq_offset, "
|
||||
"int quant_mode=0, int quant_bit=8) -> ()"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"torch_mlu_ops::dequant_from_paged_cache(Tensor(a!) key, Tensor(b!)? value, "
|
||||
"Tensor key_cache, Tensor? value_cache, Tensor key_cache_quant_scale, "
|
||||
"Tensor? value_cache_quant_scale, Tensor context_lengths, int max_context_len, "
|
||||
"Tensor? context_seq_offset, Tensor block_tables, int quant_mode=0, int quant_bit=8) -> ()"));
|
||||
}
|
||||
|
||||
// Binds each torch_mlu_ops operator to the tmo::torch_api function of the same name for the
// PrivateUse1 dispatch key.
TORCH_LIBRARY_IMPL(torch_mlu_ops, PrivateUse1, m) {
// Expands to the m.impl(...) call that registers tmo::torch_api::<op> as the kernel of
// "torch_mlu_ops::<op>". The operator name and the C++ function name are identical for every
// entry below, so one token is enough to spell both.
#define TMO_REGISTER_IMPL(op) \
  m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::" #op), TORCH_FN(tmo::torch_api::op))

  TMO_REGISTER_IMPL(attention_project);
  TMO_REGISTER_IMPL(ffn);
  TMO_REGISTER_IMPL(flash_attention);
  TMO_REGISTER_IMPL(single_query_cached_kv_attn);
  TMO_REGISTER_IMPL(apply_rotary);
  TMO_REGISTER_IMPL(reshape_linear_cache);
  TMO_REGISTER_IMPL(reshape_paged_cache);
  TMO_REGISTER_IMPL(quant_to_paged_cache);
  TMO_REGISTER_IMPL(offline_quant_to_paged_cache);
  TMO_REGISTER_IMPL(quant_to_linear_cache);
  TMO_REGISTER_IMPL(offline_quant_to_linear_cache);
  TMO_REGISTER_IMPL(quant_matmul);
  TMO_REGISTER_IMPL(quant_matmul_allreduce);
  TMO_REGISTER_IMPL(active);
  TMO_REGISTER_IMPL(smooth_quant);
  TMO_REGISTER_IMPL(fused_layernorm);
  TMO_REGISTER_IMPL(fused_moe);
  TMO_REGISTER_IMPL(matmul);
  TMO_REGISTER_IMPL(batch_matmul);
  TMO_REGISTER_IMPL(matmul_allreduce);
  TMO_REGISTER_IMPL(swap_blocks);
  TMO_REGISTER_IMPL(copy_blocks);
  TMO_REGISTER_IMPL(copy_blocks_out_of_place);
  TMO_REGISTER_IMPL(group_gemm);
  TMO_REGISTER_IMPL(preload);
  TMO_REGISTER_IMPL(group_gemm_combine_result_allreduce);
  TMO_REGISTER_IMPL(flash_attn_sq_mm_allreduce);
  TMO_REGISTER_IMPL(moe_softmax_topk);
  TMO_REGISTER_IMPL(moe_expand_input);
  TMO_REGISTER_IMPL(moe_gen_idx);
  TMO_REGISTER_IMPL(moe_combine_result);
  TMO_REGISTER_IMPL(moe_cast_gating);
  TMO_REGISTER_IMPL(fused_rope);
  TMO_REGISTER_IMPL(update_out_and_lse);
  TMO_REGISTER_IMPL(dequant_from_linear_cache);
  TMO_REGISTER_IMPL(dequant_from_paged_cache);

#undef TMO_REGISTER_IMPL
}
|
||||
|
||||
// Registers copy_blocks__functionalization_glue as the Functionalize-key implementation of
// torch_mlu_ops::copy_blocks.
// NOTE: per the original author, copy_blocks__functionalization_glue is no longer needed as of
// torch 2.5.
|
||||
TORCH_LIBRARY_IMPL(torch_mlu_ops, Functionalize, m) {
|
||||
m.impl("torch_mlu_ops::copy_blocks", copy_blocks__functionalization_glue);
|
||||
}
|
||||
115
torch_mlu_ops-v1.3.2/csrc/torch_api/update_out_and_lse.cpp
Normal file
115
torch_mlu_ops-v1.3.2/csrc/torch_api/update_out_and_lse.cpp
Normal file
@@ -0,0 +1,115 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/update_out_and_lse.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Torch entry point of the update_out_and_lse op: validates the arguments and forwards raw
// pointers, sizes and strides to invokeUpdateOutAndLse on the current MLU stream.
//
// Layouts enforced by the checks below:
//   - lse / block_lse: 3-D. lse is read as (batch, head_num, max_seq_len); dim 2 of block_lse is
//     block_seq_len, which must not exceed max_seq_len.
//   - out / block_out, 4-D: (batch, max_seq_len, head_num, head_size) and
//     (batch, block_seq_len, head_num, head_size).
//   - out / block_out, 3-D ("packed"): (N, head_num, head_size), N taken from each tensor itself.
//
// Params:
//   out, lse            destination tensors; must share dtype with block_out / block_lse.
//   block_out,block_lse per-block inputs.
//   seq_offsets         optional contiguous int32 tensor of shape (batch).
//   cu_seqs             optional contiguous int32 tensor of shape (batch + 1).
//   block_cu_seqs       optional contiguous int32 tensor of shape (batch + 1).
//
// Raises c10::Error (via TORCH_CHECK) when any validation fails.
void update_out_and_lse(at::Tensor &out,
                        at::Tensor &lse,
                        const at::Tensor &block_out,
                        const at::Tensor &block_lse,
                        const c10::optional<at::Tensor> &seq_offsets,
                        const c10::optional<at::Tensor> &cu_seqs,
                        const c10::optional<at::Tensor> &block_cu_seqs) {
  const torch_mlu::mlu::MLUGuard device_guard(out.device());
  auto queue = torch_mlu::getCurMLUStream();
  checkTensorSameAttr<TensorAttr::DEVICE>(out, lse, block_out, block_lse, seq_offsets, cu_seqs,
                                          block_cu_seqs);
  checkTensorSameAttr<TensorAttr::DTYPE>(seq_offsets, cu_seqs, block_cu_seqs);
  checkTensorSameAttr<TensorAttr::DTYPE>(out, block_out);
  checkTensorSameAttr<TensorAttr::DTYPE>(lse, block_lse);

  TORCH_CHECK(out.dim() == 3 || out.dim() == 4, "the dim of out must be 3 or 4");
  TORCH_CHECK(block_out.dim() == 3 || block_out.dim() == 4, "the dim of block_out must be 3 or 4");
  TORCH_CHECK(lse.dim() == 3, "the dim of lse must be 3");
  TORCH_CHECK(block_lse.dim() == 3, "the dim of block_lse must be 3");

  // A 3-D `out` selects the packed layout.
  const bool packed = out.dim() == 3;

  const auto data_type = getCnnlDataType(out.scalar_type());

  // The kernel interface takes 32-bit sizes, so the 64-bit tensor sizes are narrowed here.
  const int32_t batch = static_cast<int32_t>(lse.size(0));
  const int32_t head_num = static_cast<int32_t>(lse.size(1));
  const int32_t max_seq_len = static_cast<int32_t>(lse.size(2));
  const int32_t block_seq_len = static_cast<int32_t>(block_lse.size(2));
  const int32_t head_size = static_cast<int32_t>(out.size(-1));

  TORCH_CHECK(block_seq_len <= max_seq_len,
              "the max block_seq_len of block_lse can not be greater than the max_seq_len of lse. "
              "block_seq_len: ",
              block_seq_len, " max_seq_len: ", max_seq_len);
  TORCH_CHECK(data_type == CNNL_DTYPE_FLOAT || data_type == CNNL_DTYPE_BFLOAT16 ||
                  data_type == CNNL_DTYPE_HALF,
              "the data_type of out only support float, bfloat16 and half");
  // lse / block_lse are handed to the kernel as float*; reject dtypes that do not map to float
  // (validated the same way as `out` above). block_lse is covered by the same-dtype check.
  TORCH_CHECK(getCnnlDataType(lse.scalar_type()) == CNNL_DTYPE_FLOAT,
              "the data_type of lse and block_lse only support float");

  int64_t bs_stride{0}, seq_stride{0}, head_stride{0};
  int64_t block_bs_stride{0}, block_seq_stride{0}, block_head_stride{0};

  if (packed) {
    seq_stride = out.stride(0);
    head_stride = out.stride(1);
    block_seq_stride = block_out.stride(0);
    block_head_stride = block_out.stride(1);

    // With exactly one position per sequence the leading dimension also acts as the batch
    // dimension, so the batch strides equal the sequence strides; otherwise they stay 0.
    if (max_seq_len == block_seq_len && block_seq_len == 1) {
      bs_stride = seq_stride;
      block_bs_stride = block_seq_stride;
    }

    CHECK_SHAPE(out, out.size(0), head_num, head_size);
    CHECK_SHAPE(block_out, block_out.size(0), head_num, head_size);
  } else {
    bs_stride = out.stride(0);
    seq_stride = out.stride(1);
    head_stride = out.stride(2);
    block_bs_stride = block_out.stride(0);
    block_seq_stride = block_out.stride(1);
    block_head_stride = block_out.stride(2);
    CHECK_SHAPE(out, batch, max_seq_len, head_num, head_size);
    CHECK_SHAPE(block_out, batch, block_seq_len, head_num, head_size);
  }

  TORCH_CHECK(head_size <= 1024, "the head_size can not be greater than 1024.");

  if (seq_offsets.has_value()) {
    TORCH_CHECK(seq_offsets.value().is_contiguous(), "seq_offsets must be contiguous");
    TORCH_CHECK(seq_offsets.value().dtype() == torch::kInt32,
                "the dtype of seq_offsets must be int32");
    TORCH_CHECK(seq_offsets.value().dim() == 1, "the dim of seq_offsets must be 1");
    CHECK_SHAPE(seq_offsets.value(), batch);
  }

  // cu_seqs / block_cu_seqs are reinterpreted as int32_t* below. The same-dtype check above only
  // ties them to seq_offsets when that tensor is present, so validate each one explicitly.
  if (cu_seqs.has_value()) {
    TORCH_CHECK(cu_seqs.value().is_contiguous(), "cu_seqs must be contiguous");
    TORCH_CHECK(cu_seqs.value().dtype() == torch::kInt32, "the dtype of cu_seqs must be int32");
    TORCH_CHECK(cu_seqs.value().dim() == 1, "the dim of cu_seqs must be 1");
    CHECK_SHAPE(cu_seqs.value(), batch + 1);
  }
  if (block_cu_seqs.has_value()) {
    TORCH_CHECK(block_cu_seqs.value().is_contiguous(), "block_cu_seqs must be contiguous");
    TORCH_CHECK(block_cu_seqs.value().dtype() == torch::kInt32,
                "the dtype of block_cu_seqs must be int32");
    TORCH_CHECK(block_cu_seqs.value().dim() == 1, "the dim of block_cu_seqs must be 1");
    CHECK_SHAPE(block_cu_seqs.value(), batch + 1);
  }

  invokeUpdateOutAndLse(queue, getAtTensorPtr(out), static_cast<float *>(getAtTensorPtr(lse)),
                        getAtTensorPtr(block_out),
                        static_cast<float *>(getAtTensorPtr(block_lse)),
                        static_cast<int32_t *>(getAtTensorPtr(seq_offsets)),
                        static_cast<int32_t *>(getAtTensorPtr(cu_seqs)),
                        static_cast<int32_t *>(getAtTensorPtr(block_cu_seqs)), batch, head_num,
                        head_size, max_seq_len, block_seq_len, bs_stride, seq_stride, head_stride,
                        block_bs_stride, block_seq_stride, block_head_stride, packed, data_type);
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
75
torch_mlu_ops-v1.3.2/csrc/torch_api/utils.cpp
Normal file
75
torch_mlu_ops-v1.3.2/csrc/torch_api/utils.cpp
Normal file
@@ -0,0 +1,75 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "utils.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// X-macro table pairing each CNNL data type with the at::ScalarType constant it is used for.
// ENTRY is invoked once per row as ENTRY(cnnl_dtype, scalar_type); the 64-bit scalar types are
// deliberately absent from this table.
#define CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT(ENTRY) \
  ENTRY(CNNL_DTYPE_FLOAT, at::kFloat)                  \
  ENTRY(CNNL_DTYPE_BFLOAT16, at::kBFloat16)            \
  ENTRY(CNNL_DTYPE_HALF, at::kHalf)                    \
  ENTRY(CNNL_DTYPE_INT32, at::kInt)                    \
  ENTRY(CNNL_DTYPE_INT8, at::kChar)                    \
  ENTRY(CNNL_DTYPE_UINT8, at::kByte)                   \
  ENTRY(CNNL_DTYPE_BOOL, at::kBool)                    \
  ENTRY(CNNL_DTYPE_INT16, at::kShort)                  \
  ENTRY(CNNL_DTYPE_COMPLEX_HALF, at::kComplexHalf)     \
  ENTRY(CNNL_DTYPE_COMPLEX_FLOAT, at::kComplexFloat)
|
||||
|
||||
// Translates an at::ScalarType into the cnnlDataType_t used for kernel calls.
// Types listed in CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT map one-to-one; the 64-bit types are
// folded onto their 32-bit counterparts (kLong -> INT32, kDouble -> FLOAT,
// kComplexDouble -> COMPLEX_FLOAT). Any other type throws std::runtime_error.
cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type) {
// Emits one `case` label per row of the table above.
#define DEFINE_CASE(cnnl_dtype, scalar_type) \
  case scalar_type: {                        \
    return cnnl_dtype;                       \
  }
  switch (data_type) {
    CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT(DEFINE_CASE)
    case at::kLong: {
      return CNNL_DTYPE_INT32;
    }
    case at::kDouble: {
      return CNNL_DTYPE_FLOAT;
    }
    case at::kComplexDouble: {
      return CNNL_DTYPE_COMPLEX_FLOAT;
    }
    default: {
      throw std::runtime_error(std::string("getCnnlDataType() not supported for ") +
                               c10::toString(data_type));
    }
  }
#undef DEFINE_CASE
}
|
||||
|
||||
// Builds one CNNL tensor descriptor per input tensor, in the same order as `tensors`.
//
// An undefined tensor yields a null TensorDesc placeholder so that indices stay aligned with the
// input list. For defined tensors the descriptor is created with CNNL_LAYOUT_ARRAY, the dtype
// returned by getCnnlDataType, and the tensor's sizes; strides are attached whenever the tensor
// has a non-empty stride list.
//
// Each descriptor is owned by its TensorDesc as soon as it is created, so it is destroyed through
// cnnlDestroyTensorDescriptor even if a later step throws.
std::vector<TensorDesc> createTensorDescs(const std::initializer_list<at::Tensor> &tensors) {
  std::vector<TensorDesc> descs;
  descs.reserve(tensors.size());
  // Iterate by const reference: copying an at::Tensor costs a reference-count update per element.
  for (const at::Tensor &tensor : tensors) {
    descs.emplace_back(TensorDesc{nullptr, cnnlDestroyTensorDescriptor});
    if (!tensor.defined()) {
      continue;
    }
    TensorDesc &owner = descs.back();
    cnnlTensorDescriptor_t desc = nullptr;
    CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&desc));
    owner.reset(desc);
    const cnnlDataType_t data_type = getCnnlDataType(tensor.scalar_type());
    if (tensor.strides().size() == 0) {
      CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(owner.get(), CNNL_LAYOUT_ARRAY, data_type,
                                                  tensor.sizes().size(), tensor.sizes().data()));
    } else {
      CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(owner.get(), CNNL_LAYOUT_ARRAY, data_type,
                                                    tensor.sizes().size(), tensor.sizes().data(),
                                                    tensor.strides().data()));
    }
  }
  return descs;
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
164
torch_mlu_ops-v1.3.2/csrc/torch_api/utils.h
Normal file
164
torch_mlu_ops-v1.3.2/csrc/torch_api/utils.h
Normal file
@@ -0,0 +1,164 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#ifndef CSRC_TORCH_API_UTILS_H_
|
||||
#define CSRC_TORCH_API_UTILS_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include "ATen/ScalarType.h"
|
||||
#include "ATen/Tensor.h"
|
||||
#include "aten/cnnl/cnnlHandle.h"
|
||||
#include "c10/util/Exception.h"
|
||||
#include "cnnl.h"
|
||||
#include "framework/core/MLUStream.h"
|
||||
#include "framework/core/caching_allocator.h"
|
||||
#include "framework/core/device.h"
|
||||
#include "framework/core/mlu_guard.h"
|
||||
#include "torch/torch.h"
|
||||
#include "torch/version.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
// Owning handle for a CNNL tensor descriptor; the deleter is a pointer to
// cnnlDestroyTensorDescriptor.
using TensorDesc = std::unique_ptr<std::remove_pointer_t<cnnlTensorDescriptor_t>,
|
||||
decltype(&cnnlDestroyTensorDescriptor)>;
|
||||
// Returns one TensorDesc per entry of `tensors` (defined in utils.cpp).
std::vector<TensorDesc> createTensorDescs(const std::initializer_list<at::Tensor> &tensors);
|
||||
// Maps an at::ScalarType to the corresponding cnnlDataType_t (defined in utils.cpp).
cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type);
|
||||
|
||||
// Reports whether `tensor` lives on an MLU device.
// Before torch 2.1 this asks the tensor directly via is_mlu(); from torch 2.1 onwards it checks
// that the tensor's device is the PrivateUse1 backend.
template <typename T>
bool isMlu(const T &tensor) {
#if (TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 1)
  return tensor.is_mlu();
#else
  return tensor.device().is_privateuseone();
#endif
}
|
||||
|
||||
// Selects which tensor attribute(s) checkTensorSameAttr compares across its arguments:
// the device id, the dtype, or both (ALL).
enum class TensorAttr { DEVICE, DTYPE, ALL };
|
||||
// Running state for the "all tensors agree" checks: holds the attribute values of the first
// tensor seen so that later tensors can be compared against them.
struct attr_t {
|
||||
// Device id of the first checked tensor; -1 means no tensor has been recorded yet.
int64_t device_id;
|
||||
// Dtype of the first checked tensor; at::ScalarType::Undefined means none recorded yet.
at::ScalarType dtype;
|
||||
};
|
||||
|
||||
// Verifies that `tensor` is on the device recorded in `device_id`.
// When `device_id` is still -1 (nothing recorded), the tensor's device id is stored and no
// comparison is made. Otherwise a mismatch raises c10::Error through TORCH_CHECK.
inline void checkDevice(int64_t &device_id, const at::Tensor &tensor) {
  const auto tensor_device_id = tensor.get_device();
  if (device_id == -1) {
    device_id = tensor_device_id;
    return;
  }
  // The separator before "now" keeps the two ids from running together in the message.
  TORCH_CHECK(tensor_device_id == device_id,
              "Tensor device id is not same, original device_id: ", device_id,
              ", now device_id is: ", tensor_device_id);
}
|
||||
|
||||
// Verifies that `tensor` has the dtype recorded in `dtype`.
// When `dtype` is still at::ScalarType::Undefined (nothing recorded), the tensor's dtype is
// stored and no comparison is made. Otherwise a mismatch raises c10::Error through TORCH_CHECK.
inline void checkDtype(at::ScalarType &dtype, const at::Tensor &tensor) {
  const auto tensor_dtype = tensor.scalar_type();
  if (dtype == at::ScalarType::Undefined) {
    dtype = tensor_dtype;
    return;
  }
  // The separator before "now" keeps the two dtype names from running together in the message.
  TORCH_CHECK(tensor_dtype == dtype, "Tensor dtype is not same. original dtype: ", dtype,
              ", now dtype is: ", tensor_dtype);
}
|
||||
|
||||
template <TensorAttr attr>
|
||||
inline void checkTensorAttr(attr_t &attr_states, const at::Tensor &tensor) {
|
||||
if (attr == TensorAttr::DEVICE) {
|
||||
checkDevice(attr_states.device_id, tensor);
|
||||
} else if (attr == TensorAttr::DTYPE) {
|
||||
checkDtype(attr_states.dtype, tensor);
|
||||
} else if (attr == TensorAttr::ALL) {
|
||||
checkDevice(attr_states.device_id, tensor);
|
||||
checkDtype(attr_states.dtype, tensor);
|
||||
}
|
||||
}
|
||||
|
||||
template <TensorAttr attr,
|
||||
typename T,
|
||||
typename = typename std::enable_if<
|
||||
std::is_same<typename std::decay<T>::type, at::Tensor>::value>::type>
|
||||
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const c10::optional<T> &tensor) {
|
||||
if (!tensor.has_value() || !tensor->defined()) return;
|
||||
auto temp_tensor = tensor.value();
|
||||
TORCH_CHECK(isMlu(temp_tensor), "Only support mlu tensor.");
|
||||
checkTensorAttr<attr>(attr_states, temp_tensor);
|
||||
}
|
||||
|
||||
template <TensorAttr attr,
|
||||
typename T,
|
||||
typename = typename std::enable_if<
|
||||
std::is_same<typename std::decay<T>::type, at::Tensor>::value>::type>
|
||||
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor) {
|
||||
if (!tensor.defined()) return;
|
||||
TORCH_CHECK(isMlu(tensor), "Only support mlu tensor.");
|
||||
checkTensorAttr<attr>(attr_states, tensor);
|
||||
}
|
||||
|
||||
// Variadic overload: checks the first tensor, then recurses on the remaining arguments.
// Each argument is expected to be an at::Tensor or a c10::optional<at::Tensor>, which is what
// the two single-argument overloads accept.
template <TensorAttr attr, typename T, typename... Args>
|
||||
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor, Args &&...args) {
|
||||
// Handle the head of the argument list with a single-argument call.
checkTensorSameWithSpecificAttr<attr>(attr_states, tensor);
|
||||
// Forward the tail; the recursion ends once a single argument is left.
checkTensorSameWithSpecificAttr<attr>(attr_states, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <TensorAttr attr, typename... Args>
|
||||
void checkTensorSameAttr(Args &&...args) {
|
||||
attr_t attr_states = {-1, at::ScalarType::Undefined};
|
||||
checkTensorSameWithSpecificAttr<attr>(attr_states, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Converts a dtype name ("float", "half", "bfloat16", "int32", "int8") to its at::ScalarType.
// Unknown names throw std::out_of_range from std::map::at.
inline at::ScalarType str2TorchDtype(const std::string &type) {
  using NameToDtype = std::map<std::string, at::ScalarType>;
  static const NameToDtype name_to_dtype = {
      {"float", torch::kFloat32},
      {"half", torch::kHalf},
      {"bfloat16", torch::kBFloat16},
      {"int32", torch::kInt32},
      {"int8", torch::kInt8},
  };
  return name_to_dtype.at(type);
}
|
||||
|
||||
// Converts an at::ScalarType (float32, half, bfloat16, int32, int8) to its dtype name.
// Unsupported types throw std::out_of_range from std::map::at.
// NOTE: the result is a mutable reference into a function-local static table, so writing through
// it changes what later calls return.
inline std::string &torchDtype2Str(const at::ScalarType type) {
  using DtypeToName = std::map<at::ScalarType, std::string>;
  static DtypeToName dtype_to_name = {
      {torch::kFloat32, "float"},
      {torch::kHalf, "half"},
      {torch::kBFloat16, "bfloat16"},
      {torch::kInt32, "int32"},
      {torch::kInt8, "int8"},
  };
  return dtype_to_name.at(type);
}
|
||||
|
||||
// Converts a dtype name ("float", "half", "bfloat16", "int32", "int8") to its cnnlDataType_t.
// Unknown names throw std::out_of_range from std::map::at.
inline cnnlDataType_t str2CnnlDtype(const std::string &type) {
  using NameToCnnlDtype = std::map<std::string, cnnlDataType_t>;
  static const NameToCnnlDtype name_to_cnnl_dtype = {
      {"float", CNNL_DTYPE_FLOAT},
      {"half", CNNL_DTYPE_HALF},
      {"bfloat16", CNNL_DTYPE_BFLOAT16},
      {"int32", CNNL_DTYPE_INT32},
      {"int8", CNNL_DTYPE_INT8},
  };
  return name_to_cnnl_dtype.at(type);
}
|
||||
|
||||
// Asserts (via TORCH_CHECK) that tensor expression `x` has exactly the listed sizes, e.g.
// CHECK_SHAPE(t, batch, head_num). The failure message quotes the expression and the expected
// dimensions as written at the call site. `x` is parenthesized so that any expression can be
// passed safely.
#define CHECK_SHAPE(x, ...)                                      \
  TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), \
              #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
// Asserts (via TORCH_CHECK) that tensor expression `x` is contiguous. `x` is parenthesized so
// that any expression can be passed safely.
#define CHECK_TENSOR_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous.")
|
||||
|
||||
// Asserts (via TORCH_CHECK) that optional tensor `x` is contiguous when it holds a value; an
// empty optional always passes. Written as a single TORCH_CHECK rather than an `if` wrapping one,
// so the macro nests in control flow exactly like CHECK_TENSOR_CONTIGUOUS does.
#define CHECK_OPTIONAL_TENSOR_CONTIGUOUS(x) \
  TORCH_CHECK(!(x).has_value() || (x).value().is_contiguous(), #x " must be contiguous.")
|
||||
|
||||
// Returns the data pointer of an optional tensor, or nullptr when there is no usable tensor.
// "No usable tensor" covers both an empty optional and an optional holding an undefined tensor,
// matching the at::Tensor overload and the skip rule used by checkTensorSameWithSpecificAttr.
inline void *getAtTensorPtr(const c10::optional<at::Tensor> &tensor) {
  if (!tensor.has_value() || !tensor->defined()) {
    return nullptr;
  }
  return tensor->data_ptr();
}
|
||||
|
||||
// Returns the data pointer of `tensor`, or nullptr when the tensor is undefined.
inline void *getAtTensorPtr(const at::Tensor &tensor) {
  if (!tensor.defined()) {
    return nullptr;
  }
  return tensor.data_ptr();
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_TORCH_API_UTILS_H_
|
||||
Reference in New Issue
Block a user