#include #include #include #include #include #include #include "utils.h" /* * How to write a meta implementation for a custom operator (meta kernel): * * Meta implementations are used for shape and dtype inference, tracing, and export. * They do NOT perform any real computation or allocate device memory. * Instead, they return empty tensors with the correct shapes, dtypes, and device types. * * Steps to write a meta implementation: * 1. The function signature should match the operator's schema, but only use the arguments * necessary to infer output shapes and dtypes. * 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes. * 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype. * 4. Do NOT perform any real computation or data movement. * 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar. * * Example: * std::tuple my_op_meta( * at::Tensor &input, int64_t some_param) { * // Infer output shape based on input and parameters * auto out_shape = ...; * at::Tensor out = at::empty_symint(out_shape, input.options()); * // Return empty tensor(s) with correct shape/dtype * return {out, ...}; * } * * See below for real examples. */ namespace vllm_ascend { namespace meta { std::tuple rotary_embedding_meta( at::Tensor &positions, at::Tensor &query, at::Tensor &key, int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) { auto num_tokens = positions.sym_numel(); auto query_hidden_size = query.sym_numel() / num_tokens; auto key_hidden_size = key.sym_numel() / num_tokens; auto num_heads = query_hidden_size / head_size; auto num_kv_heads = key_hidden_size / head_size; at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options()); at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options()); return {query_dst, key_dst}; } std::tuple get_masked_input_and_mask_meta( at::Tensor &input, const int64_t org_vocab_start_index, const int64_t org_vocab_end_index, const int64_t num_org_vocab_padding, const int64_t added_vocab_start_index, const int64_t added_vocab_end_index) { at::Tensor masked_input = at::empty_like(input); at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool)); return {masked_input, mask}; } at::Tensor bgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y, int64_t slice_offset, int64_t slice_size) { at::Tensor y_out = at::empty_like(y); return y_out; } at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, at::Tensor &y, int64_t slice_offset, int64_t slice_size) { at::Tensor y_out = at::empty_like(y); return y_out; } std::tuple mla_preprocess( const at::Tensor &hiddenState, const at::Tensor &wdqkv, const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq, const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping, const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0, const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1, const c10::optional &ctkv_scale, const c10::optional &q_nope_scale, c10::optional cache_mode, c10::optional quant_mode, at::Tensor &q_out0, at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1) { return {q_out0, kv_cache_out0, q_out1, kv_cache_out1}; } void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, c10::optional format_mode, c10::optional quant_mode) { return; } } // namespace meta } // namespace vllm_ascend namespace { // Register the meta implementations of the custom kernels for symbolic tracing, this will also // the custom kernel been captured into aclgraph TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { // Rotary embedding meta implementation ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); // Masked input and mask meta implementation ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta); // Bgmv expand ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta); // Sgmv expand ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta); // MLA preprocess ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess); // batch_matmul_transpose ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose); } }