[Feat]support sequence parallelism by pass for VL models (#5632)

This commit is contained in:
realliujiaxu
2026-02-27 08:27:41 +08:00
committed by GitHub
parent ed175d6d92
commit 5def28dcd3
22 changed files with 460 additions and 101 deletions

View File

@@ -132,6 +132,8 @@ e2e-multicard-2-cards:
estimated_time: 215
- name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
estimated_time: 90
- name: tests/e2e/multicard/2-cards/test_sp_pass.py
estimated_time: 300
e2e-multicard-4-cards:
# TODO: recover skipped tests

View File

@@ -45,6 +45,7 @@ The following table lists additional configuration options available in vLLM Asc
| `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. |
| `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear |
| `sp_threshold` | int | `1000` | For dense models, only num_tokens > threshold will enable sequence parallelism. |
The details of each configuration option are as follows:

View File

@@ -24,4 +24,5 @@ speculative_decoding
context_parallel
npugraph_ex
weight_prefetch
sequence_parallelism
:::

View File

@@ -0,0 +1,57 @@
# Sequence Parallelism
## What is Sequence Parallelism
Sequence Parallelism (SP) was first introduced in [Megatron](https://arxiv.org/pdf/2205.05198), with the original intention of reducing training activation memory. The core modification was changing `Allreduce->LayerNorm` to `ReduceScatter->LayerNorm->Allgather`. This technique was later applied to inference by vllm. It should be noted that splitting Allreduce into ReduceScatter and Allgather does not inherently bring performance benefits; it reduces the computation load of LayerNorm, but this gain is minimal. The real benefits of SP come from:
1. LLM inference deployment often uses quantization. Taking INT8 quantization commonly used on NPUs as an example, after LayerNorm, a Quant operator quantizes the hidden states from BF16 to INT8. The communication volume of Allgather is halved, and the time consumption is almost halved.
2. ReduceScatter and Allgather can be fused with the preceding and following Matmul operations respectively into communication-computation parallel operators, reducing latency.
## How to Use
Currently, vllm-ascend has implemented Sequence Parallelism for VL-class models based on the Inductor pass. It can be enabled in the following way:
```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
--tensor-parallel-size 2 \
--compilation-config '{"pass_config": {"enable_sp": true}}' \
--additional_config={"sp_threshold": 1000}
```
- `"pass_config": {"enable_sp": true}`: This is the switch for SP. Since SP relies on graph mode, it must be enabled and is not supported in eager mode.
- `--additional_config={"sp_threshold": 1000}`: Based on our experiments, when the number of tokens is small (empirical value is less than 1000), SP can actually bring negative benefits. This is because when the communication volume is small, the fixed overhead of the communication operator becomes the dominant factor. Therefore, when one communication operator (Allreduce) is split into two communication operators (ReduceScatter+Allgather), the end-to-end latency often becomes longer. Thus, we have reserved the `sp_threshold`parameter; SP will only take effect when `num_tokens >= sp_threshold`. **The default value is 1000, which generally does not need to be modified.** `sp_threshold` will be appended into `compile_ranges_split_points`, which is a parameter provided by vllm that splits the graph compilation range `[1, max_num_batched_tokens]` into `{[1, split_points[0]], [split_points[0] + 1, split_points[1]], ..., [split_points[-1] + 1, max_num_batched_tokens]}`, and sequentially checks whether the `is_applicable_for_range` of the pass returns `True`.
Without modifying `sp_threshold`, the simplest way and recommended way to enable SP is:
```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
--tensor-parallel-size 2 \
--compilation-config '{"pass_config": {"enable_sp": true}}'
```
## Difference Between SP and Flash Comm V1
[Flash Comm V1 (FC1)](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/ascend-inference-cluster-flashcomm.md) is an enhanced version of Sequence Parallelism developed based on NPU. The enhancements include:
1. For models using the MLA structure, Allgather is postponed until after QKV projection, further reducing communication volume.
2. For MoE models, Allgather is postponed until after Gating+DynamicQuant, also aiming to reduce communication volume.
FC1 is a unique optimization in vllm-ascend, currently implemented based on Custom OP, but it is difficult to support VL-class models (reasons detailed in [[RFC]: support sequence parallelism by pass](https://github.com/vllm-project/vllm-ascend/issues/5712) ). Therefore, currently FC1 and SP are complementary.
## Support Matrix
### Without Quantization
| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | graph | x | x | x |
| Flash Comm V1 | x | x | eager/graph | eager/graph |
### With Quantization
SP currently does not support quantization and is under adaptation.
| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | x | x | x | x |
| Flash Comm V1 | x | x | eager/graph | eager/graph |

View File

@@ -0,0 +1,64 @@
import os
import pytest
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODELS = [
"Qwen/Qwen3-VL-2B-Instruct",
]
@pytest.mark.parametrize("model", MODELS)
def test_qwen3_vl_sp_tp2(model: str) -> None:
prompts = [
"Hello, my name is", "The capital of the United States is",
"The capital of France is", "The future of AI is"
]
sampling_params = SamplingParams(max_tokens=10, temperature=0.0)
with VllmRunner(
model,
max_model_len=1024,
tensor_parallel_size=2,
compilation_config={
"cudagraph_capture_sizes": [2, 4],
"cudagraph_mode": "FULL_DECODE_ONLY",
"pass_config": {"enable_sp": False}
},
additional_config={"npugraph_ex_config": {"enable": False}}
) as runner:
no_sp_outputs = runner.model.generate(prompts, sampling_params)
with VllmRunner(
model,
max_model_len=1024,
tensor_parallel_size=2,
compilation_config={
"cudagraph_capture_sizes": [2, 4],
"cudagraph_mode": "FULL_DECODE_ONLY",
"pass_config": {"enable_sp": True}
},
additional_config={"sp_threshold": 10, "npugraph_ex_config": {"enable": False}}
) as runner:
sp_outputs = runner.model.generate(
prompts, sampling_params)
no_sp_outputs_list = []
for output in no_sp_outputs:
no_sp_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
sp_outputs_list = []
for output in sp_outputs:
sp_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
check_outputs_equal(
outputs_0_lst=no_sp_outputs_list,
outputs_1_lst=sp_outputs_list,
name_0="no_sp_outputs",
name_1="sp_outputs",
)

View File

@@ -157,7 +157,7 @@ class TestAscendMultiHeadLatentAttention(TestBase):
hidden_states = torch.randn(3, self.hidden_size)
mock_forward_context = MagicMock(spec=ForwardContext)
mock_forward_context.sp_enabled = False
mock_forward_context.flash_comm_v1_enabled = False
mock_get_forward_context.return_value = mock_forward_context
mock_mla_forward.return_value = (3, self.hidden_size)

View File

@@ -390,7 +390,7 @@ class TestEagleProposerDummyRun(TestBase):
# cpu does not support parallel-group, let alone `sp`
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
**{"return_value.sp_enabled": False})
**{"return_value.flash_comm_v1_enabled": False})
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_basic(self, mock_context, mock_get_context):
num_tokens = 32
@@ -406,7 +406,7 @@ class TestEagleProposerDummyRun(TestBase):
# cpu does not support parallel-group, let alone `sp`
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
**{"return_value.sp_enabled": False})
**{"return_value.flash_comm_v1_enabled": False})
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
mock_context.return_value.__enter__.return_value = None
@@ -426,7 +426,7 @@ class TestEagleProposerDummyRun(TestBase):
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
mock_return_context.capturing = True
# cpu does not support parallel-group, let alone `sp`
mock_return_context.sp_enabled = False
mock_return_context.flash_comm_v1_enabled = False
mock_get_context.return_value = mock_return_context
self.proposer.use_cuda_graph = True
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
@@ -449,7 +449,7 @@ class TestEagleProposerDummyRun(TestBase):
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
mock_return_context.capturing = False
# cpu does not support parallel-group, let alone `sp`
mock_return_context.sp_enabled = False
mock_return_context.flash_comm_v1_enabled = False
mock_get_context.return_value = mock_return_context
self.proposer.use_cuda_graph = True
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`

View File

@@ -30,6 +30,7 @@ class AscendConfig:
"""
def __init__(self, vllm_config: "VllmConfig"):
self.vllm_config = vllm_config
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
xlite_graph_config = additional_config.get("xlite_graph_config", {})
@@ -160,6 +161,47 @@ class AscendConfig:
stacklevel=2,
)
def update_compile_ranges_split_points(self):
vllm_config = self.vllm_config
if self.npugraph_ex_config.enable:
if self.npugraph_ex_config.fuse_allreduce_rms:
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
logger.debug(
"set compile_ranges_split_points to "
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
)
else:
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
logger.debug(
"set compile_ranges_split_points to "
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
)
from vllm_ascend.utils import is_moe_model
if vllm_config.compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config):
from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold
sp_threshold = get_sp_threshold(vllm_config)
new_compile_ranges_split_points.append(sp_threshold)
logger.debug(f"add {sp_threshold} to compile_ranges_split_points for sequence parallelism")
if len(new_compile_ranges_split_points) > len(vllm_config.compilation_config.compile_ranges_split_points):
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
class FinegrainedTPConfig:
"""

View File

@@ -12,7 +12,7 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (
AscendDeviceType,
enable_sp,
enable_flash_comm_v1,
flashcomm2_enable,
get_ascend_device_type,
has_layer_idx,
@@ -92,22 +92,22 @@ def set_ascend_forward_context(
# main model and drafter model may have different architecture
is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
if is_context_moe_model:
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None
mmrs_fusion = False
elif is_draft_model:
# TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
# Disable it to avoid more problems.
sp_enabled = False
flash_comm_v1_enabled = False
else:
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None and num_tokens > 1000
forward_context.mmrs_fusion = mmrs_fusion
forward_context.num_tokens = num_tokens
forward_context.sp_enabled = sp_enabled
forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
forward_context.pad_size = 0
if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size
@@ -131,7 +131,7 @@ def set_ascend_forward_context(
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
forward_context.padded_length = padded_length

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import functools
from collections.abc import Callable
from typing import Any
@@ -129,6 +130,10 @@ class AscendCompiler(CompilerInterface):
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
npugraph_ex_config = get_ascend_config().npugraph_ex_config
if npugraph_ex_config.enable:
assert hasattr(self, "vllm_config")

View File

@@ -70,3 +70,8 @@ class GraphFusionPassManager:
from .passes.allreduce_rmsnorm_fusion_pass import MatmulAllReduceAddRMSNormPass
self.passes.append(MatmulAllReduceAddRMSNormPass(config))
if config.compilation_config.pass_config.enable_sp:
from .passes.sequence_parallelism import AscendSequenceParallelismPass
self.passes.append(AscendSequenceParallelismPass(config))

View File

@@ -0,0 +1,202 @@
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm_ascend.utils import is_moe_model, vllm_version_is
if vllm_version_is("0.15.0"):
from vllm.compilation.vllm_inductor_pass import VllmInductorPass # type: ignore
else:
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
from vllm.logger import logger
SP_THRESHOLD = 1000
def get_sp_threshold(config: VllmConfig):
if is_moe_model(config):
return 1
additional_config = config.additional_config if config.additional_config is not None else {}
return additional_config.get("sp_threshold", SP_THRESHOLD)
class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns."""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
):
self.eps = epsilon
self.dtype = dtype
self.device = device
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return tensor_model_parallel_all_reduce(x)
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.reduce_scatter(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.dtype, device="npu", **kws)
class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.dtype, device="npu", **kws)
def get_inputs(self):
"""
Generate example inputs.
"""
input = self.empty(8, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
return [input, weight, residual]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
x = self._all_reduce(input)
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x, residual, weight, None, self.eps)
return result, residual
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual)
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
reduce_scatter, residual, weight, None, self.eps
)
all_gather = self._all_gather(result)
return all_gather, residual
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
input = self.empty(8, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
return [input, weight, residual]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor:
x = self._all_reduce(input)
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(x, residual, weight, None, self.eps)
return result
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor:
reduce_scatter = self._reduce_scatter(input)
residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual)
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(reduce_scatter, residual, weight, None, self.eps)
all_gather = self._all_gather(result)
return all_gather
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
input = self.empty(8, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
deepstack_input_embeds = self.empty(8, 16)
return [input, weight, residual, deepstack_input_embeds]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
deepstack_input_embeds: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
x = self._all_reduce(input)
add_ = x + deepstack_input_embeds
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
return result, residual
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
deepstack_input_embeds: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
chunk = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
add_ = reduce_scatter + chunk
residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual)
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
all_gather = self._all_gather(result)
return all_gather, residual
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendSequenceParallelismPass(VllmInductorPass):
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")
for epsilon in [1e-5, 1e-6]:
AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
self.min_tokens = get_sp_threshold(config)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""
applicable = compile_range.start >= self.min_tokens
logger.debug(f"SequenceParallelismPass {compile_range=} {applicable=}")
return applicable

View File

@@ -440,7 +440,7 @@ class AscendFusedMoE(FusedMoE):
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled,
replace_allreduce=forward_context.flash_comm_v1_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type,
)

View File

@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
from vllm_ascend.utils import enable_sp, maybe_trans_nz
from vllm_ascend.utils import enable_flash_comm_v1, maybe_trans_nz
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
@@ -240,7 +240,7 @@ class AscendRowParallelLinear(RowParallelLinear):
disable_tp: bool = False,
):
# TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM.
if enable_sp():
if enable_flash_comm_v1():
compilation_config = get_current_vllm_config().compilation_config
unique_prefix = prefix
if prefix in compilation_config.static_forward_context:

View File

@@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
enable_sp,
enable_flash_comm_v1,
flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
get_weight_prefetch_method,
@@ -368,7 +368,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
else:
output = output_parallel
if not forward_context.sp_enabled:
if not forward_context.flash_comm_v1_enabled:
# flashcomm1 not enabled
output = get_tp_group().all_gather(output, 0)
if num_padding_tokens > 0:
@@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
# Matrix multiply.
assert self.quant_method is not None
if enable_sp():
if enable_flash_comm_v1():
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
# Trigger async broadcast before matmul to overlap communication.
@@ -515,15 +515,15 @@ class SequenceRowParallelOp(CustomRowParallelOp):
assert self.quant_method is not None
try:
forward_context = get_forward_context()
sp_enabled = forward_context.sp_enabled
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
mmrs_fusion = forward_context.mmrs_fusion
except AssertionError:
sp_enabled = False
flash_comm_v1_enabled = False
mmrs_fusion = False
x = input_parallel
if not sp_enabled:
if not flash_comm_v1_enabled:
output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
return tensor_model_parallel_all_reduce(output_parallel)
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
return Flashcomm2OshardQKVParallelOp(layer)
if enable_sp():
if enable_flash_comm_v1():
if "shared_expert" in prefix:
return None
sp_column_prefix = [
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
if flashcomm2_enable():
if "o_proj" in prefix or "out_proj" in prefix:
return Flashcomm2OProjRowParallelOp(layer)
if enable_sp():
if enable_flash_comm_v1():
if "shared_expert" in prefix:
return None
sp_row_prefixes = [

View File

@@ -150,7 +150,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
kv_cache: torch.Tensor | None = None,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
need_gather_q_kv = get_forward_context().sp_enabled
need_gather_q_kv = get_forward_context().flash_comm_v1_enabled
output_shape = hidden_states.shape
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)

View File

@@ -26,8 +26,6 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch
return residual
if x.size(0) != residual.size(0):
sp_enabled = forward_context.sp_enabled
assert sp_enabled is True, "Currently, this situation only occurs when sp is enabled"
pad_size = forward_context.pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,8 +42,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
except AssertionError:
return x
sp_enabled = forward_context.sp_enabled
if sp_enabled and label:
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
if flash_comm_v1_enabled and label:
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
x = tensor_model_parallel_all_gather(x, 0)
@@ -75,7 +73,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
except AssertionError:
return tensor_model_parallel_all_reduce(x)
if not getattr(forward_context, "sp_enabled", False):
if not getattr(forward_context, "flash_comm_v1_enabled", False):
return tensor_model_parallel_all_reduce(x)
dp_metadata = forward_context.dp_metadata
@@ -99,7 +97,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().sp_enabled and label:
if get_forward_context().flash_comm_v1_enabled and label:
return torch.empty(
(x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
)
@@ -108,7 +106,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().sp_enabled:
if get_forward_context().flash_comm_v1_enabled:
return torch.empty(
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
)
@@ -141,7 +139,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} or forward_context.sp_enabled:
if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
or forward_context.flash_comm_v1_enabled
):
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -161,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
forward_context = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
num_tokens = input_parallel.size(0)
if forward_context.sp_enabled:
if forward_context.flash_comm_v1_enabled:
num_tokens = num_tokens // self.tp_size
output = torch.empty(
size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
@@ -203,7 +204,7 @@ def _rope_forward_oot_impl_fake(
direct_register_custom_op(
op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: x,
fake_impl=lambda x, residual: torch.empty_like(x),
mutates_args=[],
dispatch_key="PrivateUse1",
)

View File

@@ -48,6 +48,7 @@ from vllm_ascend.utils import (
update_aclgraph_sizes,
update_cudagraph_capture_sizes,
is_310p,
enable_flash_comm_v1,
)
if TYPE_CHECKING:
@@ -198,32 +199,9 @@ class NPUPlatform(Platform):
if not isinstance(ascend_compilation_config, dict)
else ascend_compilation_config
)
ascend_config.update_compile_ranges_split_points()
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
logger.debug(
"set compile_ranges_split_points to "
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
)
npugraph_ex_config = ascend_config.npugraph_ex_config
if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
logger.debug(
"set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
)
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
if model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
ascend_fusion_config = ascend_config.ascend_fusion_config
@@ -417,15 +395,19 @@ class NPUPlatform(Platform):
)
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
if is_vl_model(vllm_config):
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
):
raise ValueError(
"Currently, VL models doesn't support "
"FLASHCOMM in vllm-ascend. We will fix this in the future. "
"Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
)
if enable_flash_comm_v1():
assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
For optimal performance with VL models, we recommend enabling Sequence Parallelism \
via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
"Flash Comm v1 is only supported when tp_size > 1."
)
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
"Flash Comm v1 requires enable_expert_parallel=True for MoE models."
)
# Set "PYTORCH_NPU_ALLOC_CONF=expandable_segments:True" by default to optimize NPU memory management.
# Find more details at https://docs.vllm.ai/projects/ascend/en/latest/faqs.html#how-to-handle-the-out-of-memory-issue
@@ -626,16 +608,16 @@ class NPUPlatform(Platform):
# communication methods.
mmrs_fusion = True
if is_moe_model(vllm_config):
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
mmrs_fusion = False
else:
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
pad_size = None
padded_length = None
if sp_enabled or flashcomm_v2_enabled:
if flash_comm_v1_enabled or flashcomm_v2_enabled:
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
if num_tokens is None and attn_metadata is not None:
@@ -643,7 +625,7 @@ class NPUPlatform(Platform):
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and dp_metadata is not None:
max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
if sp_enabled or flashcomm_v2_enabled:
if flash_comm_v1_enabled or flashcomm_v2_enabled:
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
else:
@@ -664,7 +646,7 @@ class NPUPlatform(Platform):
"capturing": capturing,
"mmrs_fusion": mmrs_fusion,
"num_tokens": num_tokens,
"sp_enabled": sp_enabled,
"flash_comm_v1_enabled": flash_comm_v1_enabled,
"flashcomm_v2_enabled": flashcomm_v2_enabled,
"pad_size": pad_size,
"padded_length": padded_length,

View File

@@ -1171,7 +1171,7 @@ class EagleProposer(VllmEagleProposer):
positions = positions.squeeze(-1)
else:
forward_context = get_forward_context()
if forward_context.sp_enabled:
if forward_context.flash_comm_v1_enabled:
hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
return hidden_states, positions
@@ -1191,7 +1191,7 @@ class EagleProposer(VllmEagleProposer):
hidden_states = last_hidden_states
else:
forward_context = get_forward_context()
if forward_context.sp_enabled:
if forward_context.flash_comm_v1_enabled:
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True
)

View File

@@ -724,6 +724,19 @@ def matmul_allreduce_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
def enable_flash_comm_v1():
return (
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
)
def enable_sp_by_pass(vllm_config: VllmConfig):
return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
global _ENABLE_SP
if _ENABLE_SP is None:
@@ -731,29 +744,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
_ENABLE_SP = (
vllm_config.compilation_config.pass_config.enable_sp
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
)
_ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1()
if not _ENABLE_SP and enable_shared_expert_dp:
_ENABLE_SP = True
logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")
if not _ENABLE_SP:
return _ENABLE_SP
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
"Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
)
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
"Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
)
return _ENABLE_SP
@@ -1113,7 +1109,7 @@ def enable_dsa_cp() -> bool:
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
return bool(is_ds_v32 and enable_sp())
return bool(is_ds_v32 and enable_flash_comm_v1())
@lru_cache(maxsize=1)

View File

@@ -112,6 +112,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (
enable_flash_comm_v1,
enable_sp,
is_drafter_moe_model,
is_moe_model,
@@ -1677,7 +1678,7 @@ class NPUModelRunner(GPUModelRunner):
self.speculative_config,
positions.shape[0],
)
if get_forward_context().sp_enabled and not isinstance(hidden_states, IntermediateTensors):
if get_forward_context().flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors):
hidden_states = self._all_gather_hidden_states_and_aux(hidden_states)
return hidden_states
@@ -1685,7 +1686,7 @@ class NPUModelRunner(GPUModelRunner):
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if enable_sp():
if enable_sp(self.vllm_config):
return round_up(num_scheduled_tokens, tp_size)
return num_scheduled_tokens
@@ -2223,7 +2224,7 @@ class NPUModelRunner(GPUModelRunner):
# tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
# to incorrect memory estimation and potentially causing OOM.
intermediate_tokens = num_tokens_padded
if enable_sp():
if enable_flash_comm_v1():
tp_size = get_tensor_model_parallel_world_size()
intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
if self.intermediate_tensors is None:

View File

@@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.utils import (
AscendDeviceType,
check_ascend_device_type,
enable_sp,
enable_flash_comm_v1,
get_ascend_device_type,
register_ascend_customop,
)
@@ -376,7 +376,7 @@ class NPUWorker(WorkerBase):
if forward_pass and not get_pp_group().is_first_rank:
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
# it will conflict with the all-gather operation in flashcomm1.
if enable_sp():
if enable_flash_comm_v1():
all_gather_group = None
else:
all_gather_group = get_tp_group()
@@ -393,7 +393,7 @@ class NPUWorker(WorkerBase):
assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
# it will conflict with the all-gather operation in flashcomm1.
if enable_sp():
if enable_flash_comm_v1():
all_gather_group = None
else:
all_gather_group = get_tp_group()