#include <cstdint>
#include <string>
#include <tuple>

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <torch/library.h>

#include "utils.h"

/*
 * How to write a meta implementation for a custom operator (meta kernel):
 *
 * Meta implementations are used for shape and dtype inference, tracing, and export.
 * They do NOT perform any real computation or allocate device memory.
 * Instead, they return empty tensors with the correct shapes, dtypes, and device types.
 *
 * Steps to write a meta implementation:
 * 1. The function signature should match the operator's schema, but only use the arguments
 *    necessary to infer output shapes and dtypes.
 * 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes.
 * 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype.
 * 4. Do NOT perform any real computation or data movement.
 * 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar.
 *
 * Example:
 * std::tuple<at::Tensor, at::Tensor> my_op_meta(
 *     at::Tensor &input, int64_t some_param) {
 *     // Infer output shape based on input and parameters
 *     auto out_shape = ...;
 *     at::Tensor out = at::empty_symint(out_shape, input.options());
 *     // Return empty tensor(s) with correct shape/dtype
 *     return {out, ...};
 * }
 *
 * See the fully worked sketch right after this comment, and the real examples below.
 */
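/*
 * A minimal worked version of the example above, with the elided pieces filled in.
 * The operator name "scatter_tokens" and its parameters are hypothetical and exist
 * only for illustration; they are not part of this extension. Every shape is derived
 * symbolically and only empty tensors are created, as the steps above require:
 *
 * std::tuple<at::Tensor, at::Tensor> scatter_tokens_meta(
 *     at::Tensor &input, int64_t block_size) {
 *     // Pad the leading (token) dim up to a multiple of block_size; symbolic ints
 *     // keep this traceable under dynamic shapes.
 *     auto padded = ((input.sym_size(0) + block_size - 1) / block_size) * block_size;
 *     // Output keeps the trailing hidden dim and the input's dtype/device.
 *     at::Tensor out = at::empty_symint({padded, input.sym_size(1)}, input.options());
 *     // Companion index tensor: same leading dim, integer dtype.
 *     at::Tensor index = at::empty_symint({padded}, input.options().dtype(at::kLong));
 *     return {out, index};
 * }
 */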
namespace vllm_ascend {
namespace meta {

// Number of int4 values packed into a single int32 (A8W4 weight layout).
const int64_t INT4_NUMS_IN_INT32 = 8;

std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
    at::Tensor &positions, at::Tensor &query, at::Tensor &key,
    int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{
    auto num_tokens = positions.sym_numel();
    auto query_hidden_size = query.sym_numel() / num_tokens;
    auto key_hidden_size = key.sym_numel() / num_tokens;

    auto num_heads = query_hidden_size / head_size;
    auto num_kv_heads = key_hidden_size / head_size;
    at::Tensor query_dst =
        at::empty_symint({num_tokens, num_heads, head_size}, query.options());
    at::Tensor key_dst =
        at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());

    return {query_dst, key_dst};
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
    at::Tensor &input,
    const int64_t org_vocab_start_index, const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index, const int64_t added_vocab_end_index)
{
    at::Tensor masked_input = at::empty_like(input);
    at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));

    return {masked_input, mask};
}

at::Tensor bgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &indices,
                            at::Tensor &y, int64_t slice_offset, int64_t slice_size)
{
    at::Tensor y_out = at::empty_like(y);
    return y_out;
}

at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices,
                            at::Tensor &seq_len, at::Tensor &y,
                            int64_t slice_offset, int64_t slice_size)
{
    at::Tensor y_out = at::empty_like(y);
    return y_out;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mla_preprocess(
    const at::Tensor &hiddenState, const at::Tensor &wdqkv, const at::Tensor &descale0,
    const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
    const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos,
    const at::Tensor &sin, const at::Tensor &wuk, const at::Tensor &kv_cache,
    const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
    const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
    const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
    const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
    c10::optional<std::string> cache_mode, c10::optional<std::string> quant_mode,
    at::Tensor &q_out0, at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
{
    // The outputs are pre-allocated by the caller, so the meta kernel simply passes
    // them through; their shapes and dtypes are already correct.
    return {q_out0, kv_cache_out0, q_out1, kv_cache_out1};
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
    const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale,
    const at::Tensor &x_scale, const at::Tensor &group_list,
    const c10::optional<at::Tensor> &bias, const c10::optional<at::Tensor> &offset)
{
    int m = x.sizes()[0];
    int n = weight.sizes()[2];
    bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt;
    if (is_a8w4) {
        // For A8W4, each int32 weight element packs eight int4 values.
        n *= INT4_NUMS_IN_INT32;
    }
    // SwiGLU halves the hidden dim: [m, n] -> [m, n/2].
    at::Tensor output = at::empty({m, n / 2}, x.options().dtype(c10::ScalarType::Char));
    at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
    at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float));
    return {output, output_scale, output_offset};
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta(
    const at::Tensor &x, const at::TensorList &weight, const at::TensorList &weight_scale,
    const at::Tensor &x_scale, const at::Tensor &group_list,
    const c10::optional<at::Tensor> &bias, const c10::optional<at::Tensor> &offset)
{
    int m = x.sizes()[0];
    int n = weight[0].sizes()[1];
    // SwiGLU halves the hidden dim: [m, n] -> [m, n/2]. Use empty tensors with the
    // input's options so the outputs land on the meta device, per the rules above.
    at::Tensor output = at::empty({m, n / 2}, x.options().dtype(c10::ScalarType::Char));
    at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
    at::Tensor output_offset = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
    return {output, output_scale, output_offset};
}

} // namespace meta
} // namespace vllm_ascend

namespace {

// Register the meta implementations of the custom kernels for symbolic tracing;
// this also allows the custom kernels to be captured into an aclgraph.
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops)
{
    // Rotary embedding meta implementation
    ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
    // Masked input and mask meta implementation
    ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
    // Bgmv expand
    ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta);
    // Sgmv expand
    ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
    // MLA preprocess
    ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
    // grouped_matmul_swiglu_quant meta implementation
    ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant);
    // Grouped matmul swiglu quant weight nz tensor list
    ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list",
             &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta);
}

} // namespace
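/*
 * A usage sketch, not part of the build: because the meta kernels above only inspect
 * shapes, dtypes, and tensor options, they can be exercised directly with tensors on
 * the "meta" device to sanity-check shape inference without an NPU. The sizes here
 * (16 tokens, 32 query heads, 8 KV heads, head_size = 128, cache of 4096 positions)
 * are arbitrary example values:
 *
 * auto opts = at::TensorOptions().device(at::kMeta).dtype(at::kHalf);
 * at::Tensor positions = at::empty({16}, opts.dtype(at::kLong));
 * at::Tensor query = at::empty({16, 32 * 128}, opts);
 * at::Tensor key = at::empty({16, 8 * 128}, opts);
 * at::Tensor cos_sin_cache = at::empty({4096, 128}, opts);
 * // head_size = 128, is_neox = true
 * auto [q_dst, k_dst] = vllm_ascend::meta::rotary_embedding_meta(
 *     positions, query, key, 128, cos_sin_cache, true);
 * // q_dst.sizes() -> [16, 32, 128]; k_dst.sizes() -> [16, 8, 128]
 */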