[Feat]support sequence parallelism by pass for VL models (#5632)
2 .github/workflows/scripts/config.yaml vendored
@@ -132,6 +132,8 @@ e2e-multicard-2-cards:
     estimated_time: 215
   - name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
     estimated_time: 90
+  - name: tests/e2e/multicard/2-cards/test_sp_pass.py
+    estimated_time: 300

 e2e-multicard-4-cards:
   # TODO: recover skipped tests
@@ -45,6 +45,7 @@ The following table lists additional configuration options available in vLLM Ascend:
 | `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. |
 | `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |
 | `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear. |
+| `sp_threshold` | int | `1000` | For dense models, sequence parallelism is enabled only when num_tokens > threshold. |

 The details of each configuration option are as follows:
@@ -24,4 +24,5 @@ speculative_decoding
 context_parallel
 npugraph_ex
 weight_prefetch
+sequence_parallelism
 :::
57 docs/source/user_guide/feature_guide/sequence_parallelism.md Normal file
@@ -0,0 +1,57 @@
# Sequence Parallelism

## What is Sequence Parallelism
Sequence Parallelism (SP) was first introduced in [Megatron](https://arxiv.org/pdf/2205.05198), originally to reduce training activation memory. The core modification was changing `Allreduce->LayerNorm` to `ReduceScatter->LayerNorm->Allgather`. The technique was later applied to inference by vLLM. Note that splitting Allreduce into ReduceScatter and Allgather does not inherently bring performance benefits; it reduces the computation load of LayerNorm, but this gain is minimal. The real benefits of SP come from:

1. LLM inference deployments often use quantization. Taking the INT8 quantization commonly used on NPUs as an example, after LayerNorm a Quant operator quantizes the hidden states from BF16 to INT8, so the communication volume of the Allgather is halved and its time cost nearly halves as well.
2. ReduceScatter and Allgather can be fused with the preceding and following Matmul operations, respectively, into communication-computation parallel operators, reducing latency.

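The equivalence behind this split can be checked numerically. Below is a minimal NumPy simulation of two TP ranks (an illustration only, not the vllm-ascend implementation; the array shapes are arbitrary):

```python
import numpy as np

# Each TP rank holds a partial sum of the activations (num_tokens x hidden).
tp = 2
partials = [np.ones((4, 8)) * (r + 1) for r in range(tp)]

# Allreduce: every rank receives the full sum over all ranks.
allreduce_out = sum(partials)

# ReduceScatter: each rank keeps only its 1/tp slice of the token dimension.
chunks = np.split(allreduce_out, tp, axis=0)

# Allgather: concatenating the slices reproduces the Allreduce result, so the
# rewrite is numerically equivalent while letting LayerNorm (and a Quant op,
# if present) run on the smaller per-rank slice in between.
gathered = np.concatenate(chunks, axis=0)
assert np.array_equal(gathered, allreduce_out)
```

As the text above notes, the win comes not from the split itself but from what runs between the two collectives: a Quant op that halves the Allgather volume, or Matmul fusion with each collective.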
## How to Use

Currently, vllm-ascend implements Sequence Parallelism for VL models via an Inductor pass. It can be enabled as follows:

```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
    --tensor-parallel-size 2 \
    --compilation-config '{"pass_config": {"enable_sp": true}}' \
    --additional_config='{"sp_threshold": 1000}'
```

- `"pass_config": {"enable_sp": true}`: This is the switch for SP. Since SP relies on graph mode, graph mode must be enabled; SP is not supported in eager mode.
- `--additional_config='{"sp_threshold": 1000}'`: Based on our experiments, when the number of tokens is small (empirically, fewer than 1000), SP can actually bring negative benefits: when the communication volume is small, the fixed overhead of the communication operator dominates, so splitting one communication operator (Allreduce) into two (ReduceScatter + Allgather) often increases end-to-end latency. We therefore reserve the `sp_threshold` parameter; SP only takes effect when `num_tokens >= sp_threshold`. **The default value is 1000, which generally does not need to be modified.** `sp_threshold` is appended to `compile_ranges_split_points`, a parameter provided by vLLM that splits the graph compilation range `[1, max_num_batched_tokens]` into `{[1, split_points[0]], [split_points[0] + 1, split_points[1]], ..., [split_points[-1] + 1, max_num_batched_tokens]}` and, for each range, checks whether the pass's `is_applicable_for_range` returns `True`.

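The range-splitting behavior described above can be sketched as follows (a simplified illustration; `compile_ranges` is a hypothetical helper, not vLLM's actual code):

```python
def compile_ranges(split_points, max_num_batched_tokens):
    """Partition [1, max_num_batched_tokens] at the given split points,
    mirroring how compile_ranges_split_points divides the compilation range."""
    points = sorted(p for p in split_points if p < max_num_batched_tokens)
    ranges, start = [], 1
    for p in points:
        ranges.append((start, p))
        start = p + 1
    ranges.append((start, max_num_batched_tokens))
    return ranges

# With sp_threshold appended as a split point, the SP pass applies only to
# ranges whose start is >= the threshold (see is_applicable_for_range).
print(compile_ranges([1000], 8192))  # [(1, 1000), (1001, 8192)]
```

Here the pass would be skipped for `[1, 1000]` and applied for `[1001, 8192]`, since only the second range starts at or above the threshold.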
Without modifying `sp_threshold`, the simplest and recommended way to enable SP is:

```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
    --tensor-parallel-size 2 \
    --compilation-config '{"pass_config": {"enable_sp": true}}'
```

## Difference Between SP and Flash Comm V1

[Flash Comm V1 (FC1)](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/ascend-inference-cluster-flashcomm.md) is an enhanced version of Sequence Parallelism developed for NPUs. The enhancements include:

1. For models using the MLA structure, the Allgather is postponed until after the QKV projection, further reducing communication volume.
2. For MoE models, the Allgather is postponed until after Gating+DynamicQuant, also to reduce communication volume.

FC1 is an optimization unique to vllm-ascend, currently implemented as a Custom OP, but it is difficult to extend to VL models (reasons detailed in [[RFC]: support sequence parallelism by pass](https://github.com/vllm-project/vllm-ascend/issues/5712)). Therefore, FC1 and SP are currently complementary.

## Support Matrix

### Without Quantization

|                      | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | graph      | x        | x              | x            |
| Flash Comm V1        | x          | x        | eager/graph    | eager/graph  |

### With Quantization

SP does not yet support quantization; adaptation is in progress.

|                      | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | x          | x        | x              | x            |
| Flash Comm V1        | x          | x        | eager/graph    | eager/graph  |
64 tests/e2e/multicard/2-cards/test_sp_pass.py Normal file
@@ -0,0 +1,64 @@
import os

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal

MODELS = [
    "Qwen/Qwen3-VL-2B-Instruct",
]


@pytest.mark.parametrize("model", MODELS)
def test_qwen3_vl_sp_tp2(model: str) -> None:
    prompts = [
        "Hello, my name is", "The capital of the United States is",
        "The capital of France is", "The future of AI is"
    ]
    sampling_params = SamplingParams(max_tokens=10, temperature=0.0)

    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            compilation_config={
                "cudagraph_capture_sizes": [2, 4],
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "pass_config": {"enable_sp": False}
            },
            additional_config={"npugraph_ex_config": {"enable": False}}
    ) as runner:
        no_sp_outputs = runner.model.generate(prompts, sampling_params)

    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            compilation_config={
                "cudagraph_capture_sizes": [2, 4],
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "pass_config": {"enable_sp": True}
            },
            additional_config={"sp_threshold": 10, "npugraph_ex_config": {"enable": False}}
    ) as runner:
        sp_outputs = runner.model.generate(prompts, sampling_params)

    no_sp_outputs_list = []
    for output in no_sp_outputs:
        no_sp_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    sp_outputs_list = []
    for output in sp_outputs:
        sp_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    check_outputs_equal(
        outputs_0_lst=no_sp_outputs_list,
        outputs_1_lst=sp_outputs_list,
        name_0="no_sp_outputs",
        name_1="sp_outputs",
    )
@@ -157,7 +157,7 @@ class TestAscendMultiHeadLatentAttention(TestBase):
         hidden_states = torch.randn(3, self.hidden_size)

         mock_forward_context = MagicMock(spec=ForwardContext)
-        mock_forward_context.sp_enabled = False
+        mock_forward_context.flash_comm_v1_enabled = False
         mock_get_forward_context.return_value = mock_forward_context

         mock_mla_forward.return_value = (3, self.hidden_size)
@@ -390,7 +390,7 @@ class TestEagleProposerDummyRun(TestBase):

     # cpu does not support parallel-group, let alone `sp`
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
-           **{"return_value.sp_enabled": False})
+           **{"return_value.flash_comm_v1_enabled": False})
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
     def test_dummy_run_basic(self, mock_context, mock_get_context):
         num_tokens = 32
@@ -406,7 +406,7 @@ class TestEagleProposerDummyRun(TestBase):

     # cpu does not support parallel-group, let alone `sp`
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
-           **{"return_value.sp_enabled": False})
+           **{"return_value.flash_comm_v1_enabled": False})
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
     def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
         mock_context.return_value.__enter__.return_value = None
@@ -426,7 +426,7 @@ class TestEagleProposerDummyRun(TestBase):
         mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
         mock_return_context.capturing = True
         # cpu does not support parallel-group, let alone `sp`
-        mock_return_context.sp_enabled = False
+        mock_return_context.flash_comm_v1_enabled = False
         mock_get_context.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
@@ -449,7 +449,7 @@ class TestEagleProposerDummyRun(TestBase):
         mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
         mock_return_context.capturing = False
         # cpu does not support parallel-group, let alone `sp`
-        mock_return_context.sp_enabled = False
+        mock_return_context.flash_comm_v1_enabled = False
         mock_get_context.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
@@ -30,6 +30,7 @@ class AscendConfig:
     """

     def __init__(self, vllm_config: "VllmConfig"):
+        self.vllm_config = vllm_config
         additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}

         xlite_graph_config = additional_config.get("xlite_graph_config", {})
@@ -160,6 +161,47 @@ class AscendConfig:
                 stacklevel=2,
             )

+    def update_compile_ranges_split_points(self):
+        vllm_config = self.vllm_config
+        new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
+        if self.npugraph_ex_config.enable:
+            if self.npugraph_ex_config.fuse_allreduce_rms:
+                from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
+
+                new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
+                new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
+                vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
+                logger.debug(
+                    "set compile_ranges_split_points to %s for matmul and allreduce fusion",
+                    new_compile_ranges_split_points,
+                )
+        else:
+            if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
+                from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
+
+                new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
+                new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
+                vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
+                logger.debug(
+                    "set compile_ranges_split_points to %s for matmul and allreduce fusion",
+                    new_compile_ranges_split_points,
+                )
+
+        from vllm_ascend.utils import is_moe_model
+
+        if vllm_config.compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config):
+            from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold
+
+            sp_threshold = get_sp_threshold(vllm_config)
+            new_compile_ranges_split_points.append(sp_threshold)
+            logger.debug("add %s to compile_ranges_split_points for sequence parallelism", sp_threshold)
+        if len(new_compile_ranges_split_points) > len(vllm_config.compilation_config.compile_ranges_split_points):
+            new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
+            vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points


 class FinegrainedTPConfig:
     """
@@ -12,7 +12,7 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import (
     AscendDeviceType,
-    enable_sp,
+    enable_flash_comm_v1,
     flashcomm2_enable,
     get_ascend_device_type,
     has_layer_idx,
@@ -92,22 +92,22 @@ def set_ascend_forward_context(
     # main model and drafter model may have different architecture
     is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
     if is_context_moe_model:
-        sp_enabled = enable_sp(vllm_config) and num_tokens is not None
+        flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None
         mmrs_fusion = False
     elif is_draft_model:
         # TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
         # Disable it to avoid more problems.
-        sp_enabled = False
+        flash_comm_v1_enabled = False
     else:
-        sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
+        flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None and num_tokens > 1000
     forward_context.mmrs_fusion = mmrs_fusion
     forward_context.num_tokens = num_tokens
-    forward_context.sp_enabled = sp_enabled
+    forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
     # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
     forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None

-    if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
-        forward_context.pad_size = 0
+    if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
+        pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
+        forward_context.pad_size = pad_size
@@ -131,7 +131,7 @@ def set_ascend_forward_context(
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
             max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
-            if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
+            if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
                 padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                 pad_size = padded_length - num_tokens
                 forward_context.padded_length = padded_length
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import copy
 import functools
 from collections.abc import Callable
 from typing import Any
@@ -129,6 +130,10 @@ class AscendCompiler(CompilerInterface):
         compile_range: Range,
         key: str | None = None,
     ) -> tuple[Callable | None, Any | None]:
+        # inductor can inplace modify the graph, so we need to copy it
+        # see https://github.com/pytorch/pytorch/issues/138980
+        graph = copy.deepcopy(graph)
+
         npugraph_ex_config = get_ascend_config().npugraph_ex_config
         if npugraph_ex_config.enable:
             assert hasattr(self, "vllm_config")
@@ -70,3 +70,8 @@ class GraphFusionPassManager:
         from .passes.allreduce_rmsnorm_fusion_pass import MatmulAllReduceAddRMSNormPass

         self.passes.append(MatmulAllReduceAddRMSNormPass(config))

+        if config.compilation_config.pass_config.enable_sp:
+            from .passes.sequence_parallelism import AscendSequenceParallelismPass
+
+            self.passes.append(AscendSequenceParallelismPass(config))
202 vllm_ascend/compilation/passes/sequence_parallelism.py Normal file
@@ -0,0 +1,202 @@
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm_ascend.utils import is_moe_model, vllm_version_is

if vllm_version_is("0.15.0"):
    from vllm.compilation.vllm_inductor_pass import VllmInductorPass  # type: ignore
else:
    from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
from vllm.logger import logger

SP_THRESHOLD = 1000


def get_sp_threshold(config: VllmConfig):
    if is_moe_model(config):
        return 1

    additional_config = config.additional_config if config.additional_config is not None else {}
    return additional_config.get("sp_threshold", SP_THRESHOLD)


class _SequenceParallelPatternHelper:
    """Helper for sequence parallelism patterns."""

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
    ):
        self.eps = epsilon
        self.dtype = dtype
        self.device = device
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tp_group().rank_in_group

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.vllm.reduce_scatter(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)

    def empty(self, *args, **kws):
        return torch.empty(*args, dtype=self.dtype, device="npu", **kws)


class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        """
        Generate example inputs.
        """
        input = self.empty(8, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        return [input, weight, residual]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            x = self._all_reduce(input)
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x, residual, weight, None, self.eps)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(input)
            residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual)
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                reduce_scatter, residual, weight, None, self.eps
            )
            all_gather = self._all_gather(result)
            return all_gather, residual

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        input = self.empty(8, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        return [input, weight, residual]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> torch.Tensor:
            x = self._all_reduce(input)
            result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(x, residual, weight, None, self.eps)

            return result

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(input)
            residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual)
            result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(reduce_scatter, residual, weight, None, self.eps)
            all_gather = self._all_gather(result)
            return all_gather

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        input = self.empty(8, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        deepstack_input_embeds = self.empty(8, 16)
        return [input, weight, residual, deepstack_input_embeds]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            deepstack_input_embeds: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            x = self._all_reduce(input)
            add_ = x + deepstack_input_embeds
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            deepstack_input_embeds: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(input)
            chunk = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
            add_ = reduce_scatter + chunk
            residual = torch.ops.vllm.maybe_chunk_residual(reduce_scatter, residual)
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
            all_gather = self._all_gather(result)
            return all_gather, residual

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


class AscendSequenceParallelismPass(VllmInductorPass):
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")

        for epsilon in [1e-5, 1e-6]:
            AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
            AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
            AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)

        self.min_tokens = get_sp_threshold(config)

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        self.end_and_log()

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Check if the pass is applicable for the current configuration.
        """
        applicable = compile_range.start >= self.min_tokens
        logger.debug(f"SequenceParallelismPass {compile_range=} {applicable=}")
        return applicable
@@ -440,7 +440,7 @@ class AscendFusedMoE(FusedMoE):
         hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
             hidden_states=hidden_states,
             router_logits=router_logits,
-            replace_allreduce=forward_context.sp_enabled,
+            replace_allreduce=forward_context.flash_comm_v1_enabled,
             enable_shared_expert_dp=self.enable_shared_expert_dp,
             quant_type=self.quant_type,
         )
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
-from vllm_ascend.utils import enable_sp, maybe_trans_nz
+from vllm_ascend.utils import enable_flash_comm_v1, maybe_trans_nz


 class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
@@ -240,7 +240,7 @@ class AscendRowParallelLinear(RowParallelLinear):
         disable_tp: bool = False,
     ):
         # TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM.
-        if enable_sp():
+        if enable_flash_comm_v1():
             compilation_config = get_current_vllm_config().compilation_config
             unique_prefix = prefix
             if prefix in compilation_config.static_forward_context:
@@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
 from vllm_ascend.utils import (
     enable_dsa_cp,
     enable_dsa_cp_with_layer_shard,
-    enable_sp,
+    enable_flash_comm_v1,
     flashcomm2_enable,
     get_flashcomm2_reorgnized_batch_ids,
     get_weight_prefetch_method,
@@ -368,7 +368,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
         else:
             output = output_parallel

-        if not forward_context.sp_enabled:
+        if not forward_context.flash_comm_v1_enabled:
             # flashcomm1 not enabled
             output = get_tp_group().all_gather(output, 0)
             if num_padding_tokens > 0:
@@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
         # Matrix multiply.
         assert self.quant_method is not None

-        if enable_sp():
+        if enable_flash_comm_v1():
             input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)

             # Trigger async broadcast before matmul to overlap communication.
@@ -515,15 +515,15 @@ class SequenceRowParallelOp(CustomRowParallelOp):
         assert self.quant_method is not None
         try:
             forward_context = get_forward_context()
-            sp_enabled = forward_context.sp_enabled
+            flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
             mmrs_fusion = forward_context.mmrs_fusion
         except AssertionError:
-            sp_enabled = False
+            flash_comm_v1_enabled = False
             mmrs_fusion = False

         x = input_parallel

-        if not sp_enabled:
+        if not flash_comm_v1_enabled:
             output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
             return tensor_model_parallel_all_reduce(output_parallel)
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
     if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
         if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
             return Flashcomm2OshardQKVParallelOp(layer)
-    if enable_sp():
+    if enable_flash_comm_v1():
         if "shared_expert" in prefix:
             return None
         sp_column_prefix = [
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
     if flashcomm2_enable():
         if "o_proj" in prefix or "out_proj" in prefix:
             return Flashcomm2OProjRowParallelOp(layer)
-    if enable_sp():
+    if enable_flash_comm_v1():
         if "shared_expert" in prefix:
             return None
         sp_row_prefixes = [
@@ -150,7 +150,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
         kv_cache: torch.Tensor | None = None,
         attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
-        need_gather_q_kv = get_forward_context().sp_enabled
+        need_gather_q_kv = get_forward_context().flash_comm_v1_enabled
         output_shape = hidden_states.shape
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
@@ -26,8 +26,6 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch
         return residual

     if x.size(0) != residual.size(0):
-        sp_enabled = forward_context.sp_enabled
-        assert sp_enabled is True, "Currently, this situation only occurs when sp is enabled"
         pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,8 +42,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
     except AssertionError:
         return x

-    sp_enabled = forward_context.sp_enabled
-    if sp_enabled and label:
+    flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
+    if flash_comm_v1_enabled and label:
         dp_metadata = forward_context.dp_metadata
         if dp_metadata is None or not is_ep_comm:
             x = tensor_model_parallel_all_gather(x, 0)
@@ -75,7 +73,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)

-    if not getattr(forward_context, "sp_enabled", False):
+    if not getattr(forward_context, "flash_comm_v1_enabled", False):
         return tensor_model_parallel_all_reduce(x)

     dp_metadata = forward_context.dp_metadata
@@ -99,7 +97,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor


 def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
-    if get_forward_context().sp_enabled and label:
+    if get_forward_context().flash_comm_v1_enabled and label:
         return torch.empty(
             (x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
         )
@@ -108,7 +106,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c


def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
    if get_forward_context().sp_enabled:
    if get_forward_context().flash_comm_v1_enabled:
        return torch.empty(
            (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
        )
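The two `_fake` implementations above only need to get output shapes right for graph tracing: an all-gather multiplies the token dimension by the TP world size, while a reduce-scatter divides it. A small sketch of that bookkeeping (the helper names are illustrative):

```python
def all_gather_fake_shape(shape, tp_size):
    # maybe_all_gather_and_maybe_unpad fake: token dim grows by tp_size.
    return (shape[0] * tp_size, *shape[1:])

def reduce_scatter_fake_shape(shape, tp_size):
    # maybe_pad_and_reduce fake: token dim shrinks by tp_size.
    assert shape[0] % tp_size == 0, "token dim must already be padded to a multiple of tp_size"
    return (shape[0] // tp_size, *shape[1:])
```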
@@ -141,7 +139,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:

def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
    forward_context = get_forward_context()
    moe_comm_type = forward_context.moe_comm_type
    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} or forward_context.sp_enabled:
    if (
        moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
        or forward_context.flash_comm_v1_enabled
    ):
        return final_hidden_states
    else:
        return tensor_model_parallel_all_reduce(final_hidden_states)
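`_maybe_all_reduce_tensor_model_parallel_impl` skips the final all-reduce whenever an all-to-all / MC2 MoE path or flash comm v1 has already left each rank with fully reduced tokens. The predicate can be isolated as a tiny helper (hypothetical function, with string stand-ins for the `MoECommType` enum members):

```python
def needs_final_all_reduce(moe_comm_type, flash_comm_v1_enabled):
    # Tokens are already fully reduced on each rank when an all-to-all / MC2
    # MoE path or flash comm v1 handled the communication.
    already_reduced = moe_comm_type in {"alltoall", "mc2", "fused_mc2"}
    return not (already_reduced or flash_comm_v1_enabled)
```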
@@ -161,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
    forward_context = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    num_tokens = input_parallel.size(0)
    if forward_context.sp_enabled:
    if forward_context.flash_comm_v1_enabled:
        num_tokens = num_tokens // self.tp_size
    output = torch.empty(
        size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
@@ -203,7 +204,7 @@ def _rope_forward_oot_impl_fake(
direct_register_custom_op(
    op_name="maybe_chunk_residual",
    op_func=_maybe_chunk_residual_impl,
    fake_impl=lambda x, residual: x,
    fake_impl=lambda x, residual: torch.empty_like(x),
    mutates_args=[],
    dispatch_key="PrivateUse1",
)

@@ -48,6 +48,7 @@ from vllm_ascend.utils import (
    update_aclgraph_sizes,
    update_cudagraph_capture_sizes,
    is_310p,
    enable_flash_comm_v1,
)

if TYPE_CHECKING:
@@ -198,32 +199,9 @@ class NPUPlatform(Platform):
            if not isinstance(ascend_compilation_config, dict)
            else ascend_compilation_config
        )
        ascend_config.update_compile_ranges_split_points()

        if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
            from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD

            new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
            new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
            new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
            vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
            logger.debug(
                "set compile_ranges_split_points to "
                "{new_compile_ranges_split_points} for matmul and allreduce fusion"
            )

        npugraph_ex_config = ascend_config.npugraph_ex_config
        if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
            from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD

            new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
            new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
            new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
            vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
            logger.debug(
                "set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
            )

        elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
        if model_config and hasattr(model_config.hf_text_config, "index_topk"):
            vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")

        ascend_fusion_config = ascend_config.ascend_fusion_config
@@ -417,15 +395,19 @@ class NPUPlatform(Platform):
            )
            vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size

        if is_vl_model(vllm_config):
            if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
                int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
            ):
                raise ValueError(
                    "Currently, VL models doesn't support "
                    "FLASHCOMM in vllm-ascend. We will fix this in the future. "
                    "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
                )
        if enable_flash_comm_v1():
            assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
For optimal performance with VL models, we recommend enabling Sequence Parallelism \
via --compilation-config '{"pass_config": {"enable_sp": true}}'."""

            assert vllm_config.parallel_config.tensor_parallel_size > 1, (
                "Flash Comm v1 is only supported when tp_size > 1."
            )

            assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
                "Flash Comm v1 requires enable_expert_parallel=True for MoE models."
            )

        # Set "PYTORCH_NPU_ALLOC_CONF=expandable_segments:True" by default to optimize NPU memory management.
        # Find more details at https://docs.vllm.ai/projects/ascend/en/latest/faqs.html#how-to-handle-the-out-of-memory-issue
@@ -626,16 +608,16 @@ class NPUPlatform(Platform):
        # communication methods.
        mmrs_fusion = True
        if is_moe_model(vllm_config):
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
            mmrs_fusion = False
        else:
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000

        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
        flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
        pad_size = None
        padded_length = None
        if sp_enabled or flashcomm_v2_enabled:
        if flash_comm_v1_enabled or flashcomm_v2_enabled:
            pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

        if num_tokens is None and attn_metadata is not None:
@@ -643,7 +625,7 @@ class NPUPlatform(Platform):
        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and dp_metadata is not None:
            max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
            if sp_enabled or flashcomm_v2_enabled:
            if flash_comm_v1_enabled or flashcomm_v2_enabled:
                padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
            else:
@@ -664,7 +646,7 @@ class NPUPlatform(Platform):
            "capturing": capturing,
            "mmrs_fusion": mmrs_fusion,
            "num_tokens": num_tokens,
            "sp_enabled": sp_enabled,
            "flash_comm_v1_enabled": flash_comm_v1_enabled,
            "flashcomm_v2_enabled": flashcomm_v2_enabled,
            "pad_size": pad_size,
            "padded_length": padded_length,

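The hunks above compute two related quantities: a pad that rounds the local token count up to a multiple of `tp_world_size`, and (under data parallelism) a padded length derived from the maximum token count across DP ranks. Worked out in plain Python, using the exact expressions from the diff:

```python
def pad_for_tp(num_tokens, tp_world_size):
    # Extra tokens needed so the sequence splits evenly across TP ranks.
    return (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

def padded_length_across_dp(max_tokens_across_dp, tp_world_size):
    # Round the DP-wide maximum up to the next multiple of tp_world_size.
    return (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
```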
@@ -1171,7 +1171,7 @@ class EagleProposer(VllmEagleProposer):
            positions = positions.squeeze(-1)
        else:
            forward_context = get_forward_context()
            if forward_context.sp_enabled:
            if forward_context.flash_comm_v1_enabled:
                hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
        return hidden_states, positions

@@ -1191,7 +1191,7 @@ class EagleProposer(VllmEagleProposer):
            hidden_states = last_hidden_states
        else:
            forward_context = get_forward_context()
            if forward_context.sp_enabled:
            if forward_context.flash_comm_v1_enabled:
                last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                    last_hidden_states.contiguous(), True
                )

@@ -724,6 +724,19 @@ def matmul_allreduce_enable() -> bool:
    return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE


def enable_flash_comm_v1():
    return (
        envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
        # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
        # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
        or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
    )


def enable_sp_by_pass(vllm_config: VllmConfig):
    return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp

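`enable_flash_comm_v1` keeps the legacy `VLLM_ASCEND_ENABLE_FLASHCOMM` variable working via `bool(int(os.getenv(...)))`. Note that this parse accepts only integer strings; a quick illustration (the `legacy_flashcomm_flag` helper is a sketch, not the library function):

```python
import os

def legacy_flashcomm_flag(name="VLLM_ASCEND_ENABLE_FLASHCOMM"):
    # "0"/unset -> False, any non-zero integer string -> True;
    # a non-integer string would raise ValueError, matching the original parse.
    return bool(int(os.getenv(name, "0")))

os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "1"
enabled = legacy_flashcomm_flag()
del os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"]
```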
def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
    global _ENABLE_SP
    if _ENABLE_SP is None:
@@ -731,29 +744,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
        from vllm.config import get_current_vllm_config

        vllm_config = get_current_vllm_config()
        _ENABLE_SP = (
            vllm_config.compilation_config.pass_config.enable_sp
            or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
            # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
            # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
            or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
        )
        _ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1()

        if not _ENABLE_SP and enable_shared_expert_dp:
            _ENABLE_SP = True
            logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")

        if not _ENABLE_SP:
            return _ENABLE_SP

        assert vllm_config.parallel_config.tensor_parallel_size > 1, (
            "Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
        )

        assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
            "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
        )

    return _ENABLE_SP

@@ -1113,7 +1109,7 @@ def enable_dsa_cp() -> bool:
    is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
        vllm_config.model_config.hf_text_config, "index_topk"
    )
    return bool(is_ds_v32 and enable_sp())
    return bool(is_ds_v32 and enable_flash_comm_v1())


@lru_cache(maxsize=1)

@@ -112,6 +112,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (
    enable_flash_comm_v1,
    enable_sp,
    is_drafter_moe_model,
    is_moe_model,
@@ -1677,7 +1678,7 @@ class NPUModelRunner(GPUModelRunner):
            self.speculative_config,
            positions.shape[0],
        )
        if get_forward_context().sp_enabled and not isinstance(hidden_states, IntermediateTensors):
        if get_forward_context().flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors):
            hidden_states = self._all_gather_hidden_states_and_aux(hidden_states)
        return hidden_states

@@ -1685,7 +1686,7 @@ class NPUModelRunner(GPUModelRunner):
        # Pad tokens to multiple of tensor_parallel_size when
        # enabled collective fusion for SP
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        if enable_sp():
        if enable_sp(self.vllm_config):
            return round_up(num_scheduled_tokens, tp_size)
        return num_scheduled_tokens

@@ -2223,7 +2224,7 @@ class NPUModelRunner(GPUModelRunner):
        # tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
        # to incorrect memory estimation and potentially causing OOM.
        intermediate_tokens = num_tokens_padded
        if enable_sp():
        if enable_flash_comm_v1():
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
        if self.intermediate_tensors is None:

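The model-runner hunks above first pad the scheduled token count up to a TP multiple, then size intermediate tensors at the per-rank share of that padded count. Both are simple integer round-ups; a sketch assuming `round_up` behaves like vLLM's ceil-to-multiple helper:

```python
def round_up(x, multiple):
    # Smallest multiple of `multiple` that is >= x.
    return (x + multiple - 1) // multiple * multiple

def intermediate_tokens(num_tokens_padded, tp_size):
    # Per-rank slice once the sequence is scattered across TP ranks.
    return (num_tokens_padded + tp_size - 1) // tp_size
```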
@@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.utils import (
    AscendDeviceType,
    check_ascend_device_type,
    enable_sp,
    enable_flash_comm_v1,
    get_ascend_device_type,
    register_ascend_customop,
)
@@ -376,7 +376,7 @@ class NPUWorker(WorkerBase):
        if forward_pass and not get_pp_group().is_first_rank:
            # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
            # it will conflict with the all-gather operation in flashcomm1.
            if enable_sp():
            if enable_flash_comm_v1():
                all_gather_group = None
            else:
                all_gather_group = get_tp_group()
@@ -393,7 +393,7 @@ class NPUWorker(WorkerBase):
        assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
        # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
        # it will conflict with the all-gather operation in flashcomm1.
        if enable_sp():
        if enable_flash_comm_v1():
            all_gather_group = None
        else:
            all_gather_group = get_tp_group()