From c0f0b708137838c1387b4f98eeaeda383d06f0d9 Mon Sep 17 00:00:00 2001 From: Pleaplusone Date: Mon, 11 Aug 2025 15:59:42 +0800 Subject: [PATCH] [core] Support capture custom ops into aclgraph (#2113) ### What this PR does / why we need it? Thanks to the PR https://github.com/vllm-project/vllm-ascend/pull/426 make vllm-ascend support the aclgraph inference to reduce the host overhead. However, the capability of aclgraph strongly relies on the functionality provided by `torch.compile`, which is the key feature supported in torch 2.x . Therefore, capture custom op into aclgraph is only possible when it can be recognize and captured by `torch.compile`. In this PR, we register the meta implementation of current custom ops to enable the fx graph capture. And by doing that, insert those custom ops into aclgraph become a natural thing to the ascend runtime. ### Does this PR introduce _any_ user-facing change? No user face change. ### How was this patch tested? Tested in unittest, we will integrate the `rotary_embedding` op into a small custom model and use `torch.compile` and aclgraph to capture and replay it to verify its functionality. - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/1b9902806915040ac9b3029f2ab7522ec505afc3 --------- Signed-off-by: ganyi --- csrc/torch_binding.cpp | 11 ++ csrc/torch_binding_meta.cpp | 86 +++++++++++ csrc/utils.h | 12 -- .../singlecard/ops/test_rotary_embedding.py | 146 +++++++++++++++++- vllm_ascend/meta_registration.py | 86 +++++++++++ vllm_ascend/utils.py | 4 + 6 files changed, 332 insertions(+), 13 deletions(-) create mode 100644 csrc/torch_binding_meta.cpp create mode 100644 vllm_ascend/meta_registration.py diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index f2a0d1f..8bdc4b5 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -27,6 +27,17 @@ namespace vllm_ascend { +AscendType get_dtype_from_torch(at::ScalarType scalarType) +{ + if (scalarType == at::ScalarType::Float) { + return AscendType::FP32; + } else if (scalarType == at::ScalarType::BFloat16) { + return AscendType::BF16; + } else { + return AscendType::FP16; + } +} + std::tuple rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key, int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) { diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp new file mode 100644 index 0000000..1f9464c --- /dev/null +++ b/csrc/torch_binding_meta.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include +#include +#include "utils.h" +/* + * How to write a meta implementation for a custom operator (meta kernel): + * + * Meta implementations are used for shape and dtype inference, tracing, and export. + * They do NOT perform any real computation or allocate device memory. + * Instead, they return empty tensors with the correct shapes, dtypes, and device types. + * + * Steps to write a meta implementation: + * 1. The function signature should match the operator's schema, but only use the arguments + * necessary to infer output shapes and dtypes. + * 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes. + * 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype. + * 4. Do NOT perform any real computation or data movement. + * 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar. + * + * Example: + * std::tuple my_op_meta( + * at::Tensor &input, int64_t some_param) { + * // Infer output shape based on input and parameters + * auto out_shape = ...; + * at::Tensor out = at::empty_symint(out_shape, input.options()); + * // Return empty tensor(s) with correct shape/dtype + * return {out, ...}; + * } + * + * See below for real examples. + */ + +namespace vllm_ascend { +namespace meta { + +std::tuple rotary_embedding_meta( + at::Tensor &positions, + at::Tensor &query, + at::Tensor &key, + int64_t head_size, + at::Tensor &cos_sin_cache, + bool is_neox) { + auto num_tokens = positions.sym_numel(); + auto query_hidden_size = query.sym_numel() / num_tokens; + auto key_hidden_size = key.sym_numel() / num_tokens; + + auto num_heads = query_hidden_size / head_size; + auto num_kv_heads = key_hidden_size / head_size; + at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options()); + at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options()); + + return {query_dst, key_dst}; +} + +std::tuple get_masked_input_and_mask_meta( + at::Tensor &input, + const int64_t org_vocab_start_index, + const int64_t org_vocab_end_index, + const int64_t num_org_vocab_padding, + const int64_t added_vocab_start_index, + const int64_t added_vocab_end_index) { + + at::Tensor masked_input = at::empty_like(input); + at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool)); + + return {masked_input, mask}; +} + + +} // namespace meta +} // namespace vllm_ascend + +namespace { + // Register the meta implementations of the custom kernels for symbolic tracing, this will also + // the custom kernel been captured into aclgraph + TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) { + // Rotary embedding meta implementation + ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); + // Masked input and mask meta implementation + ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta); + +} +} \ No newline at end of file diff --git a/csrc/utils.h b/csrc/utils.h index e94ad2d..74481e1 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -29,15 +29,3 @@ } -namespace vllm_ascend { -AscendType get_dtype_from_torch(at::ScalarType scalarType) -{ - if (scalarType == at::ScalarType::Float) { - return AscendType::FP32; - } else if (scalarType == at::ScalarType::BFloat16) { - return AscendType::BF16; - } else { - return AscendType::FP16; - } -} -} // namespace vllm_ascend diff --git a/tests/e2e/singlecard/ops/test_rotary_embedding.py b/tests/e2e/singlecard/ops/test_rotary_embedding.py index a3504a8..c750f01 100644 --- a/tests/e2e/singlecard/ops/test_rotary_embedding.py +++ b/tests/e2e/singlecard/ops/test_rotary_embedding.py @@ -17,11 +17,12 @@ enable_custom_op() # Only Neox style true scenario is supported for now IS_NEOX_STYLE = [True] DTYPES = [torch.half] -HEAD_SIZES = [64, 96, 128, 256] +HEAD_SIZES = [64, 64, 96, 128, 256] ROTARY_DIMS = [None, 32] # None means rotary dim == head size NUM_HEADS = [17] # Arbitrary values for testing BATCH_SIZES = [5] # Arbitrary values for testing SEQ_LENS = [11, 4096] # Arbitrary values for testing +NUM_TOKENS = [10, 21] SEEDS = [0] DEVICES = [f"npu:{0}"] # Set tolerance to 1 for quant ops @@ -198,3 +199,146 @@ def test_rotary_embedding_quant_with_leading_dim( ref_key, atol=DEFAULT_ATOL, rtol=DEFAULT_RTOL) + + +class ModelwithRotaryEmbedding(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3) + self.rope = RotaryEmbedding( + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + ) + self.o_proj = nn.Linear(num_heads * head_size, hidden_size) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph + qkv = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(3, dim=-1) + query, key = torch.ops._C.rotary_embedding( + positions, + q, + k, + self.rope.head_size, + self.rope.cos_sin_cache, + self.rope.is_neox_style, + ) + query = query.view(q.shape) + key = key.view(k.shape) + o = self.o_proj(query) + return o + + +# The first graph seems will have some accuracy issue when directly run pytest on the ops folder, +# add a warmup graph replay for workaround +ACL_GRPAH_FIRST_RUN = True + + +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("num_tokens", BATCH_SIZES) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() +def test_capture_rotary_embedding_in_aclgraph( + is_neox_style: bool, + num_tokens: int, + num_heads: int, + head_size: int, + rotary_dim: int, + dtype: torch.dtype, + seed: int, + device: str, + max_position_embeddings: int = 8192, + base: int = 10000, +): + """Test if the rotary embedding can be captured in aclgraph.""" + torch.manual_seed(seed) + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size + model = ModelwithRotaryEmbedding( + hidden_size=num_heads * head_size, + num_heads=num_heads, + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + ) + + def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input): + # Validate if the rotary_embedding custom kernel is indeed inside the graph by + # string match + graph = str(gm.graph) + assert "_C.rotary_embedding" in graph + return gm + + static_positions = torch.randint(0, max_position_embeddings, + (num_tokens, )) + static_hidden_states = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device="npu") + compiled_model = torch.compile(model, backend=custom_op_checking_backend) + stream = torch.npu.Stream() + stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(stream): + # warmup the fx graph before capture + for i in range(3): + static_output = compiled_model(static_positions, + static_hidden_states, + offsets=None) + stream.wait_stream(torch.npu.current_stream()) + + aclgraph = torch.npu.NPUGraph() + + with torch.npu.graph(aclgraph): + # Capture the model in aclgraph. + static_output = compiled_model(static_positions, static_hidden_states) + # Capture the model in aclgraph. + random_filled_positions = torch.randint(0, + max_position_embeddings, + (num_tokens, ), + device="npu") + random_filled_hidden_states = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device="npu") + static_positions.copy_(random_filled_positions) + static_hidden_states.copy_(random_filled_hidden_states) + + aclgraph.replay() + global ACL_GRPAH_FIRST_RUN + if ACL_GRPAH_FIRST_RUN: + ACL_GRPAH_FIRST_RUN = False + return + output_reference = model(static_positions, static_hidden_states) + torch.testing.assert_close(static_output, + output_reference, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) diff --git a/vllm_ascend/meta_registration.py b/vllm_ascend/meta_registration.py new file mode 100644 index 0000000..600b5e7 --- /dev/null +++ b/vllm_ascend/meta_registration.py @@ -0,0 +1,86 @@ +import torch +from torch.library import Library + +# This file provides a template and registration utilities for writing "meta" implementations +# of custom operators in Python for the vllm_ascend project. +# +# We offer two ways to implement meta implementations for custom ops: +# 1. Python meta implementation (as shown in this file): Write a Python function that +# takes the same arguments as your operator and returns empty tensors with the correct +# shapes and dtypes. This is useful for rapid prototyping and for ops that are only +# used in Python. +# 2. C++ meta implementation: You can also implement the meta function in C++ for better +# performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp` +# for examples of C++ meta implementations and how to register them. +# +# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which +# is essential for supporting `torch.compile` and aclgraph. + +# How to add a new meta implementation in Python: +# ------------------------------------- +# 1. Write a Python function that takes the same arguments as your operator, and returns +# empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes and dtypes. +# Do NOT perform any real computation or allocate device memory. +# +# 2. Register your meta function using `register_meta_if_necessary`, providing: +# - The namespace (usually "_C" for custom ops) +# - The operator name (as registered in C++) +# - The Python meta function +# - (Optional) The overload name, if your op has overloads +# +# 3. The registration utility will check if a meta implementation already exists for your op, +# and only register if necessary. This avoids duplicate registrations. +# +# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask. +# +# 5. When developing new custom ops, always provide a meta implementation to enable tracing, +# export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile` +# and aclgraph. +# +# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors + +lib = Library("_C", "IMPL") + + +def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""): + if overload != "": + op_name = op_name + "." + overload + schema_to_find = ns + "::" + op_name + meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key( + "Meta") + if schema_to_find in meta_impl_list: + return + lib.impl(op_name, fn, "Meta") + + +def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool): + + num_tokens = positions.numel() + query_hidden_size = query.numel() // num_tokens + key_hidden_size = key.numel() // num_tokens + num_heads = query_hidden_size // head_size + num_kv_heads = key_hidden_size // head_size + + query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size) + key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size) + return query_dst, key_dst + + +def get_masked_input_and_mask_meta(input: torch.Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int): + + masked_input = torch.empty_like(input) + mask = torch.empty_like(input).to(torch.bool) + + return masked_input, mask + + +register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta) +register_meta_if_necessary("_C", "get_masked_input_and_mask", + get_masked_input_and_mask_meta) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index ee620b4..7c0f77f 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -214,8 +214,12 @@ def enable_custom_op(): if _CUSTOM_OP_ENABLED is not None: return _CUSTOM_OP_ENABLED try: + # isort: off # register custom ops into torch_library here import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 + # register the meta implementation for custom kernel if necessary + import vllm_ascend.meta_registration # type: ignore # noqa: F401 + # isort: on _CUSTOM_OP_ENABLED = True except ImportError: _CUSTOM_OP_ENABLED = False