diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml index 9784cbf7..acca617c 100644 --- a/.github/workflows/scripts/config.yaml +++ b/.github/workflows/scripts/config.yaml @@ -132,6 +132,8 @@ e2e-multicard-2-cards: estimated_time: 215 - name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py estimated_time: 90 + - name: tests/e2e/multicard/2-cards/test_sp_pass.py + estimated_time: 300 e2e-multicard-4-cards: # TODO: recover skipped tests diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index 4a7a2380..6396eb05 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -45,6 +45,7 @@ The following table lists additional configuration options available in vLLM Asc | `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. | | `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). | | `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear | +| `sp_threshold` | int | `1000` | For dense models, only num_tokens > threshold will enable sequence parallelism. | The details of each configuration option are as follows: diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md index 0e15a081..ef0d2555 100644 --- a/docs/source/user_guide/feature_guide/index.md +++ b/docs/source/user_guide/feature_guide/index.md @@ -24,4 +24,5 @@ speculative_decoding context_parallel npugraph_ex weight_prefetch +sequence_parallelism ::: diff --git a/docs/source/user_guide/feature_guide/sequence_parallelism.md b/docs/source/user_guide/feature_guide/sequence_parallelism.md new file mode 100644 index 00000000..a6243b91 --- /dev/null +++ b/docs/source/user_guide/feature_guide/sequence_parallelism.md @@ -0,0 +1,57 @@ +# Sequence Parallelism + +## What is Sequence Parallelism + +Sequence Parallelism (SP) was first introduced in [Megatron](https://arxiv.org/pdf/2205.05198), with the original intention of reducing training activation memory. The core modification was changing `Allreduce->LayerNorm` to `ReduceScatter->LayerNorm->Allgather`. This technique was later applied to inference by vllm. It should be noted that splitting Allreduce into ReduceScatter and Allgather does not inherently bring performance benefits; it reduces the computation load of LayerNorm, but this gain is minimal. The real benefits of SP come from: + +1. LLM inference deployment often uses quantization. Taking INT8 quantization commonly used on NPUs as an example, after LayerNorm, a Quant operator quantizes the hidden states from BF16 to INT8. The communication volume of Allgather is halved, and the time consumption is almost halved. +2. ReduceScatter and Allgather can be fused with the preceding and following Matmul operations respectively into communication-computation parallel operators, reducing latency. + +## How to Use + +Currently, vllm-ascend has implemented Sequence Parallelism for VL-class models based on the Inductor pass. It can be enabled in the following way: + +```bash +vllm serve Qwen/Qwen3-VL-2B-Instruct \ + --tensor-parallel-size 2 \ + --compilation-config '{"pass_config": {"enable_sp": true}}' \ + --additional_config={"sp_threshold": 1000} +``` + +- `"pass_config": {"enable_sp": true}`: This is the switch for SP. Since SP relies on graph mode, it must be enabled and is not supported in eager mode. +- `--additional_config={"sp_threshold": 1000}`: Based on our experiments, when the number of tokens is small (empirical value is less than 1000), SP can actually bring negative benefits. This is because when the communication volume is small, the fixed overhead of the communication operator becomes the dominant factor. Therefore, when one communication operator (Allreduce) is split into two communication operators (ReduceScatter+Allgather), the end-to-end latency often becomes longer. Thus, we have reserved the `sp_threshold`parameter; SP will only take effect when `num_tokens >= sp_threshold`. **The default value is 1000, which generally does not need to be modified.** `sp_threshold` will be appended into `compile_ranges_split_points`, which is a parameter provided by vllm that splits the graph compilation range `[1, max_num_batched_tokens]` into `{[1, split_points[0]], [split_points[0] + 1, split_points[1]], ..., [split_points[-1] + 1, max_num_batched_tokens]}`, and sequentially checks whether the `is_applicable_for_range` of the pass returns `True`. + +Without modifying `sp_threshold`, the simplest way and recommended way to enable SP is: + +```bash +vllm serve Qwen/Qwen3-VL-2B-Instruct \ + --tensor-parallel-size 2 \ + --compilation-config '{"pass_config": {"enable_sp": true}}' +``` + +## Difference Between SP and Flash Comm V1 + +[Flash Comm V1 (FC1)](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/ascend-inference-cluster-flashcomm.md) is an enhanced version of Sequence Parallelism developed based on NPU. The enhancements include: + +1. For models using the MLA structure, Allgather is postponed until after QKV projection, further reducing communication volume. +2. For MoE models, Allgather is postponed until after Gating+DynamicQuant, also aiming to reduce communication volume. + +FC1 is a unique optimization in vllm-ascend, currently implemented based on Custom OP, but it is difficult to support VL-class models (reasons detailed in [[RFC]: support sequence parallelism by pass](https://github.com/vllm-project/vllm-ascend/issues/5712) ). Therefore, currently FC1 and SP are complementary. + +## Support Matrix + +### Without Quantization + +| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE | +| -------------------- | ---------- | -------- | -------------- | ------------ | +| Sequence Parallelism | graph | x | x | x | +| Flash Comm V1 | x | x | eager/graph | eager/graph | + +### With Quantization + +SP currently does not support quantization and is under adaptation. + +| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE | +| -------------------- | ---------- | -------- | -------------- | ------------ | +| Sequence Parallelism | x | x | x | x | +| Flash Comm V1 | x | x | eager/graph | eager/graph | diff --git a/tests/e2e/multicard/2-cards/test_sp_pass.py b/tests/e2e/multicard/2-cards/test_sp_pass.py new file mode 100644 index 00000000..f5ac722f --- /dev/null +++ b/tests/e2e/multicard/2-cards/test_sp_pass.py @@ -0,0 +1,64 @@ +import os + +import pytest +from vllm import SamplingParams + +from tests.e2e.conftest import VllmRunner +from tests.e2e.model_utils import check_outputs_equal + +MODELS = [ + "Qwen/Qwen3-VL-2B-Instruct", +] + + +@pytest.mark.parametrize("model", MODELS) +def test_qwen3_vl_sp_tp2(model: str) -> None: + prompts = [ + "Hello, my name is", "The capital of the United States is", + "The capital of France is", "The future of AI is" + ] + sampling_params = SamplingParams(max_tokens=10, temperature=0.0) + + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + compilation_config={ + "cudagraph_capture_sizes": [2, 4], + "cudagraph_mode": "FULL_DECODE_ONLY", + "pass_config": {"enable_sp": False} + }, + additional_config={"npugraph_ex_config": {"enable": False}} + ) as runner: + no_sp_outputs = runner.model.generate(prompts, sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + compilation_config={ + "cudagraph_capture_sizes": [2, 4], + "cudagraph_mode": "FULL_DECODE_ONLY", + "pass_config": {"enable_sp": True} + }, + additional_config={"sp_threshold": 10, "npugraph_ex_config": {"enable": False}} + ) as runner: + sp_outputs = runner.model.generate( + prompts, sampling_params) + + no_sp_outputs_list = [] + for output in no_sp_outputs: + no_sp_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + sp_outputs_list = [] + for output in sp_outputs: + sp_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + check_outputs_equal( + outputs_0_lst=no_sp_outputs_list, + outputs_1_lst=sp_outputs_list, + name_0="no_sp_outputs", + name_1="sp_outputs", + ) diff --git a/tests/ut/ops/test_mla.py b/tests/ut/ops/test_mla.py index 22503a57..8080179f 100644 --- a/tests/ut/ops/test_mla.py +++ b/tests/ut/ops/test_mla.py @@ -157,7 +157,7 @@ class TestAscendMultiHeadLatentAttention(TestBase): hidden_states = torch.randn(3, self.hidden_size) mock_forward_context = MagicMock(spec=ForwardContext) - mock_forward_context.sp_enabled = False + mock_forward_context.flash_comm_v1_enabled = False mock_get_forward_context.return_value = mock_forward_context mock_mla_forward.return_value = (3, self.hidden_size) diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index 4c363505..5f9ab2f7 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -390,7 +390,7 @@ class TestEagleProposerDummyRun(TestBase): # cpu does not support parallel-group, let alone `sp` @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context", - **{"return_value.sp_enabled": False}) + **{"return_value.flash_comm_v1_enabled": False}) @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") def test_dummy_run_basic(self, mock_context, mock_get_context): num_tokens = 32 @@ -406,7 +406,7 @@ class TestEagleProposerDummyRun(TestBase): # cpu does not support parallel-group, let alone `sp` @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context", - **{"return_value.sp_enabled": False}) + **{"return_value.flash_comm_v1_enabled": False}) @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") def test_dummy_run_with_prefill(self, mock_context, mock_get_context): mock_context.return_value.__enter__.return_value = None @@ -426,7 +426,7 @@ class TestEagleProposerDummyRun(TestBase): mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL mock_return_context.capturing = True # cpu does not support parallel-group, let alone `sp` - mock_return_context.sp_enabled = False + mock_return_context.flash_comm_v1_enabled = False mock_get_context.return_value = mock_return_context self.proposer.use_cuda_graph = True # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` @@ -449,7 +449,7 @@ class TestEagleProposerDummyRun(TestBase): mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL mock_return_context.capturing = False # cpu does not support parallel-group, let alone `sp` - mock_return_context.sp_enabled = False + mock_return_context.flash_comm_v1_enabled = False mock_get_context.return_value = mock_return_context self.proposer.use_cuda_graph = True # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 3bc55be3..34e72d68 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -30,6 +30,7 @@ class AscendConfig: """ def __init__(self, vllm_config: "VllmConfig"): + self.vllm_config = vllm_config additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} xlite_graph_config = additional_config.get("xlite_graph_config", {}) @@ -160,6 +161,47 @@ class AscendConfig: stacklevel=2, ) + def update_compile_ranges_split_points(self): + vllm_config = self.vllm_config + if self.npugraph_ex_config.enable: + if self.npugraph_ex_config.fuse_allreduce_rms: + from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD + + new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points + new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD) + new_compile_ranges_split_points = sorted(new_compile_ranges_split_points) + vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points + logger.debug( + "set compile_ranges_split_points to " + "{new_compile_ranges_split_points} for matmul and allreduce fusion" + ) + + else: + new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points + if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True): + from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD + + new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points + new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD) + new_compile_ranges_split_points = sorted(new_compile_ranges_split_points) + vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points + logger.debug( + "set compile_ranges_split_points to " + "{new_compile_ranges_split_points} for matmul and allreduce fusion" + ) + + from vllm_ascend.utils import is_moe_model + + if vllm_config.compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config): + from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold + + sp_threshold = get_sp_threshold(vllm_config) + new_compile_ranges_split_points.append(sp_threshold) + logger.debug(f"add {sp_threshold} to compile_ranges_split_points for sequence parallelism") + if len(new_compile_ranges_split_points) > len(vllm_config.compilation_config.compile_ranges_split_points): + new_compile_ranges_split_points = sorted(new_compile_ranges_split_points) + vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points + class FinegrainedTPConfig: """ diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index dcd53535..acaf79fe 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -12,7 +12,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import ( AscendDeviceType, - enable_sp, + enable_flash_comm_v1, flashcomm2_enable, get_ascend_device_type, has_layer_idx, @@ -92,22 +92,22 @@ def set_ascend_forward_context( # main model and drafter model may have different architecture is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config) if is_context_moe_model: - sp_enabled = enable_sp(vllm_config) and num_tokens is not None + flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None mmrs_fusion = False elif is_draft_model: # TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`. # Disable it to avoid more problems. - sp_enabled = False + flash_comm_v1_enabled = False else: - sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000 - + flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None and num_tokens > 1000 forward_context.mmrs_fusion = mmrs_fusion forward_context.num_tokens = num_tokens - forward_context.sp_enabled = sp_enabled + forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None - if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled: + forward_context.pad_size = 0 + if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled: pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size forward_context.pad_size = pad_size @@ -131,7 +131,7 @@ def set_ascend_forward_context( dp_world_size = get_dp_group().world_size if dp_world_size > 1 and forward_context.dp_metadata is not None: max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item() - if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled: + if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled: padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size pad_size = padded_length - num_tokens forward_context.padded_length = padded_length diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 69c2a377..2c67a185 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import copy import functools from collections.abc import Callable from typing import Any @@ -129,6 +130,10 @@ class AscendCompiler(CompilerInterface): compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) + npugraph_ex_config = get_ascend_config().npugraph_ex_config if npugraph_ex_config.enable: assert hasattr(self, "vllm_config") diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 6ec6b1d0..29275fec 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -70,3 +70,8 @@ class GraphFusionPassManager: from .passes.allreduce_rmsnorm_fusion_pass import MatmulAllReduceAddRMSNormPass self.passes.append(MatmulAllReduceAddRMSNormPass(config)) + + if config.compilation_config.pass_config.enable_sp: + from .passes.sequence_parallelism import AscendSequenceParallelismPass + + self.passes.append(AscendSequenceParallelismPass(config)) diff --git a/vllm_ascend/compilation/passes/sequence_parallelism.py b/vllm_ascend/compilation/passes/sequence_parallelism.py new file mode 100644 index 00000000..d60d2f0e --- /dev/null +++ b/vllm_ascend/compilation/passes/sequence_parallelism.py @@ -0,0 +1,202 @@ +import torch +import torch._inductor.pattern_matcher as pm +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm_ascend.utils import is_moe_model, vllm_version_is + +if vllm_version_is("0.15.0"): + from vllm.compilation.vllm_inductor_pass import VllmInductorPass # type: ignore +else: + from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass +from vllm.config import VllmConfig +from vllm.config.utils import Range +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce +from vllm.logger import logger + +SP_THRESHOLD = 1000 + + +def get_sp_threshold(config: VllmConfig): + if is_moe_model(config): + return 1 + + additional_config = config.additional_config if config.additional_config is not None else {} + return additional_config.get("sp_threshold", SP_THRESHOLD) + + +class _SequenceParallelPatternHelper: + """Helper for sequence parallelism patterns.""" + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + ): + self.eps = epsilon + self.dtype = dtype + self.device = device + self.tp_group = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + + def _all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return tensor_model_parallel_all_reduce(x) + + def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.reduce_scatter(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name) + + def _all_gather(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name) + + def empty(self, *args, **kws): + return torch.empty(*args, dtype=self.dtype, device="npu", **kws) + + +class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device()) + + def empty(self, *args, **kws): + return torch.empty(*args, dtype=self.dtype, device="npu", **kws) + + def get_inputs(self): + """ + Generate example inputs. + """ + input = self.empty(8, 16) + weight = self.empty(16) + residual = self.empty(8, 16) + return [input, weight, residual] + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + x = self._all_reduce(input) + result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x, residual, weight, None, self.eps) + + return result, residual + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(input) + residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual) + result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias( + reduce_scatter, residual, weight, None, self.eps + ) + all_gather = self._all_gather(result) + return all_gather, residual + + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) + + +class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device()) + + def get_inputs(self): + input = self.empty(8, 16) + weight = self.empty(16) + residual = self.empty(8, 16) + return [input, weight, residual] + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + x = self._all_reduce(input) + result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(x, residual, weight, None, self.eps) + + return result + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + reduce_scatter = self._reduce_scatter(input) + residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual) + result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(reduce_scatter, residual, weight, None, self.eps) + all_gather = self._all_gather(result) + return all_gather + + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) + + +class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device()) + + def get_inputs(self): + input = self.empty(8, 16) + weight = self.empty(16) + residual = self.empty(8, 16) + deepstack_input_embeds = self.empty(8, 16) + return [input, weight, residual, deepstack_input_embeds] + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + deepstack_input_embeds: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + x = self._all_reduce(input) + add_ = x + deepstack_input_embeds + result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps) + + return result, residual + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + deepstack_input_embeds: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(input) + chunk = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank] + add_ = reduce_scatter + chunk + residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual) + result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps) + all_gather = self._all_gather(result) + return all_gather, residual + + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) + + +class AscendSequenceParallelismPass(VllmInductorPass): + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass") + + for epsilon in [1e-5, 1e-6]: + AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns) + + AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns) + + AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns) + + self.min_tokens = get_sp_threshold(config) + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + self.end_and_log() + + def is_applicable_for_range(self, compile_range: Range) -> bool: + """ + Check if the pass is applicable for the current configuration. + """ + applicable = compile_range.start >= self.min_tokens + logger.debug(f"SequenceParallelismPass {compile_range=} {applicable=}") + return applicable diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index d039d1fc..5f9e8553 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -440,7 +440,7 @@ class AscendFusedMoE(FusedMoE): hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, - replace_allreduce=forward_context.sp_enabled, + replace_allreduce=forward_context.flash_comm_v1_enabled, enable_shared_expert_dp=self.enable_shared_expert_dp, quant_type=self.quant_type, ) diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index b2118606..53bd6e06 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op -from vllm_ascend.utils import enable_sp, maybe_trans_nz +from vllm_ascend.utils import enable_flash_comm_v1, maybe_trans_nz class AscendUnquantizedLinearMethod(UnquantizedLinearMethod): @@ -240,7 +240,7 @@ class AscendRowParallelLinear(RowParallelLinear): disable_tp: bool = False, ): # TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM. - if enable_sp(): + if enable_flash_comm_v1(): compilation_config = get_current_vllm_config().compilation_config unique_prefix = prefix if prefix in compilation_config.static_forward_context: diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 4aa558f4..9c78418f 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager from vllm_ascend.utils import ( enable_dsa_cp, enable_dsa_cp_with_layer_shard, - enable_sp, + enable_flash_comm_v1, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids, get_weight_prefetch_method, @@ -368,7 +368,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): else: output = output_parallel - if not forward_context.sp_enabled: + if not forward_context.flash_comm_v1_enabled: # flashcomm1 not enabled output = get_tp_group().all_gather(output, 0) if num_padding_tokens > 0: @@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp): # Matrix multiply. assert self.quant_method is not None - if enable_sp(): + if enable_flash_comm_v1(): input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) # Trigger async broadcast before matmul to overlap communication. @@ -515,15 +515,15 @@ class SequenceRowParallelOp(CustomRowParallelOp): assert self.quant_method is not None try: forward_context = get_forward_context() - sp_enabled = forward_context.sp_enabled + flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled mmrs_fusion = forward_context.mmrs_fusion except AssertionError: - sp_enabled = False + flash_comm_v1_enabled = False mmrs_fusion = False x = input_parallel - if not sp_enabled: + if not flash_comm_v1_enabled: output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_) return tensor_model_parallel_all_reduce(output_parallel) @@ -649,7 +649,7 @@ def _get_column_parallel_op( if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")): return Flashcomm2OshardQKVParallelOp(layer) - if enable_sp(): + if enable_flash_comm_v1(): if "shared_expert" in prefix: return None sp_column_prefix = [ @@ -688,7 +688,7 @@ def _get_row_parallel_op( if flashcomm2_enable(): if "o_proj" in prefix or "out_proj" in prefix: return Flashcomm2OProjRowParallelOp(layer) - if enable_sp(): + if enable_flash_comm_v1(): if "shared_expert" in prefix: return None sp_row_prefixes = [ diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index 6f02cecd..e7f1f779 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -150,7 +150,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): kv_cache: torch.Tensor | None = None, attn_metadata: AttentionMetadata | None = None, ) -> torch.Tensor: - need_gather_q_kv = get_forward_context().sp_enabled + need_gather_q_kv = get_forward_context().flash_comm_v1_enabled output_shape = hidden_states.shape # FIXME: This does not seem right, should make sure the buffer is fixed output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index f9404f0b..348c936a 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -26,8 +26,6 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch return residual if x.size(0) != residual.size(0): - sp_enabled = forward_context.sp_enabled - assert sp_enabled is True, "Currently, this situation only occurs when sp is enabled" pad_size = forward_context.pad_size if pad_size > 0: residual = F.pad(residual, (0, 0, 0, pad_size)) @@ -44,8 +42,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c except AssertionError: return x - sp_enabled = forward_context.sp_enabled - if sp_enabled and label: + flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled + if flash_comm_v1_enabled and label: dp_metadata = forward_context.dp_metadata if dp_metadata is None or not is_ep_comm: x = tensor_model_parallel_all_gather(x, 0) @@ -75,7 +73,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor except AssertionError: return tensor_model_parallel_all_reduce(x) - if not getattr(forward_context, "sp_enabled", False): + if not getattr(forward_context, "flash_comm_v1_enabled", False): return tensor_model_parallel_all_reduce(x) dp_metadata = forward_context.dp_metadata @@ -99,7 +97,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor: - if get_forward_context().sp_enabled and label: + if get_forward_context().flash_comm_v1_enabled and label: return torch.empty( (x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype ) @@ -108,7 +106,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor: - if get_forward_context().sp_enabled: + if get_forward_context().flash_comm_v1_enabled: return torch.empty( (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype ) @@ -141,7 +139,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None: def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor: forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} or forward_context.sp_enabled: + if ( + moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} + or forward_context.flash_comm_v1_enabled + ): return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -161,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str) forward_context = get_forward_context() self = forward_context.no_compile_layers[layer_name] num_tokens = input_parallel.size(0) - if forward_context.sp_enabled: + if forward_context.flash_comm_v1_enabled: num_tokens = num_tokens // self.tp_size output = torch.empty( size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype @@ -203,7 +204,7 @@ def _rope_forward_oot_impl_fake( direct_register_custom_op( op_name="maybe_chunk_residual", op_func=_maybe_chunk_residual_impl, - fake_impl=lambda x, residual: x, + fake_impl=lambda x, residual: torch.empty_like(x), mutates_args=[], dispatch_key="PrivateUse1", ) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index c5aeab71..ba209319 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -48,6 +48,7 @@ from vllm_ascend.utils import ( update_aclgraph_sizes, update_cudagraph_capture_sizes, is_310p, + enable_flash_comm_v1, ) if TYPE_CHECKING: @@ -198,32 +199,9 @@ class NPUPlatform(Platform): if not isinstance(ascend_compilation_config, dict) else ascend_compilation_config ) + ascend_config.update_compile_ranges_split_points() - if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True): - from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD - - new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points - new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD) - new_compile_ranges_split_points = sorted(new_compile_ranges_split_points) - vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points - logger.debug( - "set compile_ranges_split_points to " - "{new_compile_ranges_split_points} for matmul and allreduce fusion" - ) - - npugraph_ex_config = ascend_config.npugraph_ex_config - if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms: - from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD - - new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points - new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD) - new_compile_ranges_split_points = sorted(new_compile_ranges_split_points) - vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points - logger.debug( - "set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion" - ) - - elif model_config and hasattr(model_config.hf_text_config, "index_topk"): + if model_config and hasattr(model_config.hf_text_config, "index_topk"): vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "") ascend_fusion_config = ascend_config.ascend_fusion_config @@ -417,15 +395,19 @@ class NPUPlatform(Platform): ) vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size - if is_vl_model(vllm_config): - if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool( - int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0")) - ): - raise ValueError( - "Currently, VL models doesn't support " - "FLASHCOMM in vllm-ascend. We will fix this in the future. " - "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0." - ) + if enable_flash_comm_v1(): + assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \ + Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \ + For optimal performance with VL models, we recommend enabling Sequence Parallelism \ + via --compilation-config '{"pass_config": {"enable_sp": true}}'.""" + + assert vllm_config.parallel_config.tensor_parallel_size > 1, ( + "Flash Comm v1 is only supported when tp_size > 1." + ) + + assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, ( + "Flash Comm v1 requires enable_expert_parallel=True for MoE models." + ) # Set "PYTORCH_NPU_ALLOC_CONF=expandable_segments:True" by default to optimize NPU memory management. # Find more details at https://docs.vllm.ai/projects/ascend/en/latest/faqs.html#how-to-handle-the-out-of-memory-issue @@ -626,16 +608,16 @@ class NPUPlatform(Platform): # communication methods. mmrs_fusion = True if is_moe_model(vllm_config): - sp_enabled = enable_sp(vllm_config) and num_tokens is not None + flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None mmrs_fusion = False else: - sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000 + flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000 # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None pad_size = None padded_length = None - if sp_enabled or flashcomm_v2_enabled: + if flash_comm_v1_enabled or flashcomm_v2_enabled: pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size if num_tokens is None and attn_metadata is not None: @@ -643,7 +625,7 @@ class NPUPlatform(Platform): dp_world_size = get_dp_group().world_size if dp_world_size > 1 and dp_metadata is not None: max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item() - if sp_enabled or flashcomm_v2_enabled: + if flash_comm_v1_enabled or flashcomm_v2_enabled: padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size pad_size = padded_length - num_tokens else: @@ -664,7 +646,7 @@ class NPUPlatform(Platform): "capturing": capturing, "mmrs_fusion": mmrs_fusion, "num_tokens": num_tokens, - "sp_enabled": sp_enabled, + "flash_comm_v1_enabled": flash_comm_v1_enabled, "flashcomm_v2_enabled": flashcomm_v2_enabled, "pad_size": pad_size, "padded_length": padded_length, diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 7609e681..15612c19 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -1171,7 +1171,7 @@ class EagleProposer(VllmEagleProposer): positions = positions.squeeze(-1) else: forward_context = get_forward_context() - if forward_context.sp_enabled: + if forward_context.flash_comm_v1_enabled: hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states) return hidden_states, positions @@ -1191,7 +1191,7 @@ class EagleProposer(VllmEagleProposer): hidden_states = last_hidden_states else: forward_context = get_forward_context() - if forward_context.sp_enabled: + if forward_context.flash_comm_v1_enabled: last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( last_hidden_states.contiguous(), True ) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 6a9a236c..dde1699d 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -724,6 +724,19 @@ def matmul_allreduce_enable() -> bool: return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE +def enable_flash_comm_v1(): + return ( + envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1 + # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1 + # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility. + or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) + ) + + +def enable_sp_by_pass(vllm_config: VllmConfig): + return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp + + def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool: global _ENABLE_SP if _ENABLE_SP is None: @@ -731,29 +744,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool: from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() - _ENABLE_SP = ( - vllm_config.compilation_config.pass_config.enable_sp - or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1 - # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1 - # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility. - or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) - ) + _ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1() if not _ENABLE_SP and enable_shared_expert_dp: _ENABLE_SP = True logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True") - if not _ENABLE_SP: - return _ENABLE_SP - - assert vllm_config.parallel_config.tensor_parallel_size > 1, ( - "Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1." - ) - - assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, ( - "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models." - ) - return _ENABLE_SP @@ -1113,7 +1109,7 @@ def enable_dsa_cp() -> bool: is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr( vllm_config.model_config.hf_text_config, "index_topk" ) - return bool(is_ds_v32 and enable_sp()) + return bool(is_ds_v32 and enable_flash_comm_v1()) @lru_cache(maxsize=1) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8d80b719..9a5ae496 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -112,6 +112,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.utils import ( + enable_flash_comm_v1, enable_sp, is_drafter_moe_model, is_moe_model, @@ -1677,7 +1678,7 @@ class NPUModelRunner(GPUModelRunner): self.speculative_config, positions.shape[0], ) - if get_forward_context().sp_enabled and not isinstance(hidden_states, IntermediateTensors): + if get_forward_context().flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors): hidden_states = self._all_gather_hidden_states_and_aux(hidden_states) return hidden_states @@ -1685,7 +1686,7 @@ class NPUModelRunner(GPUModelRunner): # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if enable_sp(): + if enable_sp(self.vllm_config): return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens @@ -2223,7 +2224,7 @@ class NPUModelRunner(GPUModelRunner): # tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading # to incorrect memory estimation and potentially causing OOM. intermediate_tokens = num_tokens_padded - if enable_sp(): + if enable_flash_comm_v1(): tp_size = get_tensor_model_parallel_world_size() intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size if self.intermediate_tensors is None: diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index be6070a2..440bd75a 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton from vllm_ascend.utils import ( AscendDeviceType, check_ascend_device_type, - enable_sp, + enable_flash_comm_v1, get_ascend_device_type, register_ascend_customop, ) @@ -376,7 +376,7 @@ class NPUWorker(WorkerBase): if forward_pass and not get_pp_group().is_first_rank: # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise # it will conflict with the all-gather operation in flashcomm1. - if enable_sp(): + if enable_flash_comm_v1(): all_gather_group = None else: all_gather_group = get_tp_group() @@ -393,7 +393,7 @@ class NPUWorker(WorkerBase): assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise # it will conflict with the all-gather operation in flashcomm1. - if enable_sp(): + if enable_flash_comm_v1(): all_gather_group = None else: all_gather_group = get_tp_group()