From 46cee945b31ea5bafff95be62bac5ae52d3a307e Mon Sep 17 00:00:00 2001 From: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com> Date: Fri, 30 Jan 2026 11:21:37 +0800 Subject: [PATCH] [doc][npugraph_ex]add npugraph_ex introduction doc (#6306) ### What this PR does / why we need it? As part of the preparation work for the [RFC](https://github.com/vllm-project/vllm-ascend/issues/6214) We have added a documentation about npugraph_ex, which mainly explains and introduces its usage and FX graph optimization. The introduction to FX graph optimization also includes specific explanations of the default passes, the implementation methods for custom fusion passes, and how to capture the FX graph during the optimization process through environment variable configuration. --------- Signed-off-by: chencangtao Co-authored-by: chencangtao --- .../developer_guide/feature_guide/index.md | 1 + .../feature_guide/npugraph_ex.md | 101 ++++++++++++++++++ docs/source/user_guide/feature_guide/index.md | 1 + .../user_guide/feature_guide/npugraph_ex.md | 35 ++++++ 4 files changed, 138 insertions(+) create mode 100644 docs/source/developer_guide/feature_guide/npugraph_ex.md create mode 100644 docs/source/user_guide/feature_guide/npugraph_ex.md diff --git a/docs/source/developer_guide/feature_guide/index.md b/docs/source/developer_guide/feature_guide/index.md index 9c92246c..3f4f2d2c 100644 --- a/docs/source/developer_guide/feature_guide/index.md +++ b/docs/source/developer_guide/feature_guide/index.md @@ -14,4 +14,5 @@ KV_Cache_Pool_Guide add_custom_aclnn_op context_parallel quantization +npugraph_ex ::: diff --git a/docs/source/developer_guide/feature_guide/npugraph_ex.md b/docs/source/developer_guide/feature_guide/npugraph_ex.md new file mode 100644 index 00000000..1d446812 --- /dev/null +++ b/docs/source/developer_guide/feature_guide/npugraph_ex.md @@ -0,0 +1,101 @@ +# Npugraph_ex + +## How it works? + +Optimization based on Fx graphs, can be considered an acceleration solution for the aclgraph mode. + +You can get its code [here](https://gitcode.com/Ascend/torchair) + +## Default Fx Graph Optimization + +### Fx Graph pass + +- For the intermediate nodes of the model, replace the non-in-place operators contained in the nodes with in-place operators to reduce memory movement during computation and improve performance. +- For the original input parameters of the model, if they include in-place operators, Dynamo's Functionalize process will replace the in-place operators with a form of non-in-place operators + copy operators. npugraph_ex will reverse this process, restoring the in-place operators and reducing memory movement. + +### Fx fusion pass + +npugraph_ex now provides three default operator fusion passes, and more will be added in the future. + +Operator combinations that meet the replacement rules can be replaced with the corresponding fused operators. + +You can get the default fusion pass list [here](https://www.hiascend.com/document/detail/zh/Pytorch/730/modthirdparty/torchairuseguide/torchair_00017.html) + +## Custom fusion pass + +Users can register a custom graph fusion pass in TorchAir to modify PyTorch FX graphs. The registration relies on the register_replacement API. + +Below is the declaration of this API and a demo of its usage. + +```python +register_replacement(search_fn, replace_fn, example_inputs, trace_fn=fwd_only, extra_check=_return_true, search_fn_pattern=None) +``` + +|Parameter Name| Input/Output |Explanation|Is necessary| +|--|--------------|---|-------| +|search_fn|Input|This function is the operator combination or calculation logic that you want to recognize in the FX graph, such as the operator combination that needs to be fused|Yes| +|replace_fn|Input|When the combination corresponding to search_fn is found in the target graph, this function's computation logic will replace the original subgraph to achieve operator fusion or optimization.|Yes| +|example_inputs|Input|Example input tensors used to track search_fn and replace_fn. The shape and dtype of the input should match the actual scenario.|Yes| +|trace_fn|Input|By default, only the forward computation graph is tracked, which is suitable for optimization during the inference phase; if training scenarios need to be supported, a function that supports backward tracking can be provided.|No| +|extra_check|Input|Find the extra verification function after operator fusion. The function's input parameter must be a Match object from torch._inductor.pattern_matcher, and it is used for further custom checks on the matching result, such as checking whether the fused operators are on the same stream, checking the device type, checking the input shapes, and so on.|No| +|search_fn_pattern|Input|A custom pattern object is generally unnecessary to provide. Its definition follows the rules of the native PyTorch MultiOutputPattern object. After passing this parameter, search_fn will no longer be used to match operator combinations; instead, this parameter will be used directly as the matching rule.|No| + +Usage Example + +```python +import functools +import torch, torch_npu, torchair + +from torch._inductor.pattern_matcher import Match +from torch._subclasses.fake_tensor import FakeTensorMode +from torchair.core.utils import logger + +# Assume fusing the add operator and the npu_rms_norm operator into the npu_add_rms_norm operator +# Define a search_fn to find the operator combinations in the original FX graph before fusion. +def search_fn(x1, x2, gamma): + xOut = torch.add(x1, x2) + y, _ = torch_npu.npu_rms_norm(xOut, gamma) + return y, xOut + +# Define a replace_fn, that is, a fusion operator, used to replace operator combinations in the FX graph +def replace_fn(x1, x2, gamma): + y, _, xOut = torch_npu.npu_add_rms_norm( + x1, x2, gamma + ) + return y, xOut + +# extra_check can pass in additional validation logic. Here, it is used to check whether the last dimension of the first input parameter x1 is a specific value; if it is not the specific value, fusion is not allowed. +def extra_check(match: Match): + x1 = match.kwargs.get("x1") + + if x1 is None: + return False + if not hasattr(x1, "meta") or "val" not in x1.meta: + return False + + a_shape = x1.meta["val"].shape + return a_shape[-1] == 7168 + + +# Define some sample inputs to trace search_fn and replace_fn into an FX graph +fake_mode = FakeTensorMode() +with fake_mode: + # sizes/values don't actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + input_tensor = functools.partial(torch.empty, (1, 1, 2), device="npu", dtype=torch.float16) + kwargs_tensor = functools.partial(torch.empty, 2, device="npu", dtype=torch.float16) + + # Call the torchair.register_replacement API with search_fn, replace_fn, and example_inputs. If there are additional validations, you can pass them in as extra_check. + torchair.register_replacement( + search_fn=search_fn, + replace_fn=replace_fn, + example_inputs=(input_tensor(), input_tensor(), kwargs_tensor()), + extra_check=extra_check + ) +``` + +The default fusion pass in npugraph_ex is also implemented based on this API. You can see more examples of using this API in the vllm-ascned and nugraph_ex code repositories. + +### DFX + +By reusing the TORCH_COMPILE_DEBUG environment variable from the PyTorch community, when TORCH_COMPILE_DEBUG=1 is set, it will output the FX graphs throughout the entire process. diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md index 96872abf..a8318be4 100644 --- a/docs/source/user_guide/feature_guide/index.md +++ b/docs/source/user_guide/feature_guide/index.md @@ -22,4 +22,5 @@ Fine_grained_TP layer_sharding speculative_decoding context_parallel +npugraph_ex ::: diff --git a/docs/source/user_guide/feature_guide/npugraph_ex.md b/docs/source/user_guide/feature_guide/npugraph_ex.md new file mode 100644 index 00000000..fa318e9a --- /dev/null +++ b/docs/source/user_guide/feature_guide/npugraph_ex.md @@ -0,0 +1,35 @@ +# Npugraph_ex + +## Introduction + +As introduced in the [RFC](https://github.com/vllm-project/vllm-ascend/issues/4715), this is a simple ACLGraph graph mode acceleration solution based on Fx graphs. + +## Using npugraph_ex + +Npugraph_ex will be enabled by default in the future, Take Qwen series models as an example to show how to configure it. + +Offline example: + +```python +from vllm import LLM + +model = LLM( + model="path/to/Qwen2-7B-Instruct", + additional_config={ + "npugraph_ex_config": { + "enable": True, + "enable_static_kernel": False, + } + } +) +outputs = model.generate("Hello, how are you?") +``` + +Online example: + +```shell +vllm serve Qwen/Qwen2-7B-Instruct +--additional-config '{"npugraph_ex_config":{"enable":true, "enable_static_kernel":false}}' +``` + +You can find more details about npugraph_ex [here](https://www.hiascend.com/document/detail/zh/Pytorch/730/modthirdparty/torchairuseguide/torchair_00021.html)