【main】SP For Qwen3 MoE (#2209)
### What this PR does / why we need it?
This PR adds sequence parallelism (SP) support for Qwen3 MoE. In scenarios
that use AlltoAll, AlltoAllv, or MC2, replacing AllReduce with
Reduce-Scatter and AllGather yields computational savings in norm
operations while also saving one AllGather communication. The feature is
enabled during the prefill phase (P-phase) and delivers notable gains in
long-sequence scenarios (e.g., 16k–25k), with performance improvements
reaching 5%–10%.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
```
compilation_config={
"pass_config":{
"enable_sequence_parallelism": True
}
},
enable_expert_parallel=True,
```
- vLLM version: v0.10.0
- vLLM main:
9edd1db02b
---------
Signed-off-by: libaokui <libaokui@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test.yaml
vendored
1
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -284,6 +284,7 @@ jobs:
|
|||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
|
||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
|
||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
|
||||||
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
|
||||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||||
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
|
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
|
||||||
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
|
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
|
||||||
|
|||||||
@@ -234,3 +234,27 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC():
|
|||||||
},
|
},
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
vllm_model.generate_greedy(prompts, max_tokens)
|
vllm_model.generate_greedy(prompts, max_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sp_for_qwen3_moe() -> None:
    """Smoke-test Qwen3 MoE generation with sequence parallelism enabled.

    Launches Qwen/Qwen3-30B-A3B with ``tensor_parallel_size=2``, expert
    parallelism and the ``enable_sequence_parallelism`` compilation pass
    turned on, then generates a few tokens for a single prompt. The test
    only checks that generation runs to completion; outputs are not
    compared against a reference.
    """
    prompts = ["Hello, my name is"]
    params = SamplingParams(
        max_tokens=5,
        temperature=0.0,
        top_k=50,
        top_p=0.9,
    )
    model_path = snapshot_download("Qwen/Qwen3-30B-A3B")
    sp_compilation_config = {
        "pass_config": {
            "enable_sequence_parallelism": True,
        },
    }

    with VllmRunner(
            model_path,
            dtype="auto",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
            compilation_config=sp_compilation_config,
            enable_expert_parallel=True,
    ) as vllm_model:
        vllm_model.generate(prompts, params)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class TestNPUPlatform(TestBase):
|
|||||||
self.mock_vllm_config.cache_config = MagicMock()
|
self.mock_vllm_config.cache_config = MagicMock()
|
||||||
self.mock_vllm_config.scheduler_config = MagicMock()
|
self.mock_vllm_config.scheduler_config = MagicMock()
|
||||||
self.mock_vllm_config.speculative_config = None
|
self.mock_vllm_config.speculative_config = None
|
||||||
|
self.mock_vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
|
||||||
|
|
||||||
self.mock_ascend_config = MagicMock()
|
self.mock_ascend_config = MagicMock()
|
||||||
self.mock_ascend_config.torchair_graph_config.enabled = False
|
self.mock_ascend_config.torchair_graph_config.enabled = False
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ class AscendMetadata:
|
|||||||
slot_mapping: torch.Tensor = None
|
slot_mapping: torch.Tensor = None
|
||||||
|
|
||||||
enable_dbo_across_dp: bool = False
|
enable_dbo_across_dp: bool = False
|
||||||
|
is_only_prefill: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionMetadataBuilder:
|
class AscendAttentionMetadataBuilder:
|
||||||
@@ -166,7 +167,8 @@ class AscendAttentionMetadataBuilder:
|
|||||||
num_reqs,
|
num_reqs,
|
||||||
num_actual_tokens,
|
num_actual_tokens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
enable_dbo_across_dp: bool = False):
|
enable_dbo_across_dp: bool = False,
|
||||||
|
is_only_prefill: bool = False):
|
||||||
|
|
||||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||||
)
|
)
|
||||||
@@ -203,7 +205,8 @@ class AscendAttentionMetadataBuilder:
|
|||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
attn_state=attn_state,
|
attn_state=attn_state,
|
||||||
enable_dbo_across_dp=enable_dbo_across_dp)
|
enable_dbo_across_dp=enable_dbo_across_dp,
|
||||||
|
is_only_prefill=is_only_prefill)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -223,7 +223,9 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
num_actual_tokens,
|
num_actual_tokens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1,
|
||||||
enable_dbo_across_dp: bool = False):
|
enable_dbo_across_dp: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
|
|
||||||
|
|||||||
@@ -384,6 +384,8 @@ class AscendMLAMetadataBuilder:
|
|||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1,
|
||||||
query_start_loc: torch.Tensor = None,
|
query_start_loc: torch.Tensor = None,
|
||||||
enable_dbo_across_dp: bool = False,
|
enable_dbo_across_dp: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
) -> AscendMLAMetadata:
|
) -> AscendMLAMetadata:
|
||||||
assert self._num_decodes + self._num_prefills == num_reqs
|
assert self._num_decodes + self._num_prefills == num_reqs
|
||||||
|
|
||||||
|
|||||||
@@ -16,14 +16,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
from typing import Optional
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
|
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||||
get_tp_group)
|
get_tp_group)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
@@ -44,8 +45,11 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
|||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
PPMissingLayer, extract_layer_index,
|
PPMissingLayer, extract_layer_index,
|
||||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
|
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
|
||||||
|
init_metadata_for_sp)
|
||||||
|
|
||||||
|
|
||||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||||
@@ -96,6 +100,7 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attn_metadata=None,
|
attn_metadata=None,
|
||||||
|
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||||
):
|
):
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
@@ -114,6 +119,7 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
|||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
enable_force_load_balance=enable_force_load_balance,
|
||||||
shared_experts=None,
|
shared_experts=None,
|
||||||
|
_metadata_for_padding=_metadata_for_padding,
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -155,14 +161,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
|||||||
layer_idx = extract_layer_index(prefix)
|
layer_idx = extract_layer_index(prefix)
|
||||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||||
config.mlp_only_layers)
|
config.mlp_only_layers)
|
||||||
use_aclgraph = (vllm_config is not None
|
self.use_aclgraph = (vllm_config is not None
|
||||||
and vllm_config.compilation_config.level
|
and vllm_config.compilation_config.level
|
||||||
== CompilationLevel.PIECEWISE
|
== CompilationLevel.PIECEWISE
|
||||||
and not vllm_config.model_config.enforce_eager)
|
and not vllm_config.model_config.enforce_eager)
|
||||||
if (layer_idx not in mlp_only_layers) and (
|
if (layer_idx not in mlp_only_layers) and (
|
||||||
config.num_experts > 0 and
|
config.num_experts > 0 and
|
||||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||||
if not use_aclgraph:
|
if not self.use_aclgraph:
|
||||||
# FIXME: custom sparse moe block doesn't work with aclgraph.
|
# FIXME: custom sparse moe block doesn't work with aclgraph.
|
||||||
self.mlp = CustomSparseMoeBlock(config=config,
|
self.mlp = CustomSparseMoeBlock(config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@@ -182,6 +188,60 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
|||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
self.enable_sequence_parallelism = (
|
||||||
|
vllm_config.compilation_config.pass_config.
|
||||||
|
enable_sequence_parallelism if vllm_config is not None else False)
|
||||||
|
|
||||||
|
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    _metadata_for_padding: Optional[MetadataForPadding] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one decoder layer, optionally in sequence-parallel (SP) mode.

    Args:
        positions: Position tensor forwarded to ``self.self_attn``.
        hidden_states: Layer input activations.
        residual: Residual stream from the previous layer, or ``None``
            when there is no residual yet (the input is then used as the
            residual).
        _metadata_for_padding: SP padding metadata. SP collectives are
            used only when it is not ``None`` and its
            ``not_dummy_and_is_prefill`` flag is set.

    Returns:
        The ``(hidden_states, residual)`` pair produced by this layer.
    """
    # Evaluate the SP predicate once so every branch below agrees on it.
    sp_active = bool(_metadata_for_padding is not None and
                     _metadata_for_padding.not_dummy_and_is_prefill)

    # o_proj keeps its own reduction (reduce_results=True) unless SP is both
    # configured and active for this batch. The original author's note says
    # this avoids precision issues on steps where SP is not applied
    # (SP is only enabled for prefill).
    self.self_attn.o_proj.reduce_results = not (
        self.enable_sequence_parallelism and sp_active)

    # Self Attention
    if residual is None:
        residual = hidden_states
        if sp_active:
            # Keep only this rank's slice of the (padded) residual stream.
            residual = _metadata_for_padding.padding_slice(residual)
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)
        # NOTE(review): the source's indentation was lost; this gather is
        # placed on the fused-norm path because that is where the
        # activations arrive as a per-rank slice. On the residual-is-None
        # path they are still full-length. Confirm against the repository.
        if sp_active:
            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
                hidden_states)

    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    if sp_active:
        # Replaces the reduction o_proj skipped above with a reduce-scatter.
        hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
            hidden_states)

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(
        hidden_states, residual)

    if not self.use_aclgraph:
        hidden_states = self.mlp(
            hidden_states, _metadata_for_padding=_metadata_for_padding)
    else:
        # The mlp used on the aclgraph path is called without SP metadata.
        hidden_states = self.mlp(hidden_states)

    return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||||
@@ -216,6 +276,45 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
|
|||||||
make_empty_intermediate_tensors_factory(
|
make_empty_intermediate_tensors_factory(
|
||||||
["hidden_states", "residual"], config.hidden_size))
|
["hidden_states", "residual"], config.hidden_size))
|
||||||
|
|
||||||
|
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    _metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this pipeline stage's decoder layers.

    On the first pipeline rank the input is ``inputs_embeds`` when given,
    otherwise the embedding of ``input_ids``; on later ranks it is read
    from ``intermediate_tensors``. Ranks other than the last return an
    ``IntermediateTensors`` holding ``hidden_states`` and ``residual``.
    The last rank applies the final norm and, when sequence parallelism
    is active for this batch, all-gathers and unpads the result before
    returning it.
    """
    pp_group = get_pp_group()

    if not pp_group.is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    else:
        residual = None
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

    for layer_idx in range(self.start_layer, self.end_layer):
        hidden_states, residual = self.layers[layer_idx](
            positions,
            hidden_states,
            residual,
            _metadata_for_padding=_metadata_for_padding)

    if not pp_group.is_last_rank:
        # Hand the running state to the next pipeline stage.
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)

    if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
        hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
            hidden_states)

    return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
@@ -253,6 +352,7 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.model.make_empty_intermediate_tensors)
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
|
||||||
# Set MoE hyperparameters
|
# Set MoE hyperparameters
|
||||||
self.expert_weights: list[torch.Tensor] = []
|
self.expert_weights: list[torch.Tensor] = []
|
||||||
|
|
||||||
@@ -273,3 +373,16 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
|||||||
self.num_moe_layers = len(self.moe_layers)
|
self.num_moe_layers = len(self.moe_layers)
|
||||||
self.num_expert_groups = 1
|
self.num_expert_groups = 1
|
||||||
self.num_shared_experts = 0
|
self.num_shared_experts = 0
|
||||||
|
|
||||||
|
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Build the sequence-parallel metadata and run the wrapped model.

    The metadata is created by ``init_metadata_for_sp`` from ``input_ids``
    and ``self.enable_sequence_parallelism`` and is passed to
    ``self.model`` as its fifth positional argument.
    """
    sp_metadata = init_metadata_for_sp(input_ids,
                                       self.enable_sequence_parallelism)
    return self.model(input_ids, positions, intermediate_tensors,
                      inputs_embeds, sp_metadata)
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||||
|
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||||
get_all_reduce_merge_state,
|
get_all_reduce_merge_state,
|
||||||
@@ -1347,7 +1348,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
shared_experts: Optional[Any] = None,
|
shared_experts: Optional[Any] = None,
|
||||||
gate=None,
|
gate=None,
|
||||||
replace_allreduce: bool = False):
|
replace_allreduce: bool = False,
|
||||||
|
_metadata_for_padding: Optional[MetadataForPadding] = None):
|
||||||
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
@@ -1381,7 +1383,17 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||||
shared_hidden_states = shared_experts(hidden_states)
|
shared_hidden_states = shared_experts(hidden_states)
|
||||||
|
|
||||||
|
mc2_mask = forward_context.mc2_mask
|
||||||
|
|
||||||
|
enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
if enable_sp:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
|
||||||
|
chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
|
||||||
|
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||||
|
replace_allreduce = True
|
||||||
|
|
||||||
if (fused_moe_state not in [
|
if (fused_moe_state not in [
|
||||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||||
FusedMoEState.NaiveMulticast
|
FusedMoEState.NaiveMulticast
|
||||||
|
|||||||
120
vllm_ascend/ops/sequence_parallel.py
Normal file
120
vllm_ascend/ops/sequence_parallel.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||||
|
get_tp_group, tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_reduce_scatter)
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataForPadding:
    """Padding bookkeeping and collectives used for sequence parallelism.

    Stores the padded and unpadded lengths of dimension 0, the pad amount,
    and the tensor-parallel (TP) group size and rank. Its helpers pad,
    slice, scatter and gather tensors along dimension 0.
    """

    def __init__(self,
                 padding_flag=False,
                 lengths_sum_padding=0,
                 lengths_sum_unpadding=0,
                 pad_size=0,
                 not_dummy_and_is_prefill=False):
        tp_group = get_tp_group()
        self.tp_size = tp_group.world_size
        self.tp_rank_in_group = tp_group.rank_in_group

        self.padding_flag = padding_flag
        self.not_dummy_and_is_prefill = not_dummy_and_is_prefill

        self.lengths_sum_padding = lengths_sum_padding
        self.lengths_sum_unpadding = lengths_sum_unpadding
        self.pad_size = pad_size

        # The padded length must divide evenly across the TP ranks.
        assert self.lengths_sum_padding % self.tp_size == 0
        self.slice_size = self.lengths_sum_padding // self.tp_size

        # Boolean mask over the padded length: True for the first
        # `lengths_sum_unpadding` positions, False for the padding tail.
        mask = torch.zeros(
            self.lengths_sum_padding,
            dtype=torch.bool,
            device=NPUPlatform.device_type,
        )
        mask[:lengths_sum_unpadding] = True
        self.mc2_mask = mask

    def padding_aligned_reduce_scatter(self,
                                       data: torch.Tensor) -> torch.Tensor:
        """Pad ``data`` if required, then reduce-scatter it along dim 0.

        When ``padding_flag`` is set, ``pad_size`` zero entries are
        appended at the end of the second-to-last dimension first.
        """
        if self.padding_flag:
            to_scatter = F.pad(data, (0, 0, 0, self.pad_size))
        else:
            to_scatter = data
        return tensor_model_parallel_reduce_scatter(to_scatter, 0)

    def allgather_unpadding_aligned(self,
                                    padded_data: torch.Tensor) -> torch.Tensor:
        """All-gather ``padded_data`` along dim 0 and drop the padding.

        When ``padding_flag`` is set, only the first
        ``lengths_sum_unpadding`` entries of the gathered tensor are kept.
        """
        gathered = tensor_model_parallel_all_gather(padded_data, 0)
        if not self.padding_flag:
            return gathered
        return gathered[:self.lengths_sum_unpadding]

    def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
        """Pad ``data`` by ``pad_size`` and return this rank's slice.

        The pad is applied unconditionally (a zero ``pad_size`` adds
        nothing). The slice is ``slice_size`` entries along dim 0,
        starting at ``tp_rank_in_group * slice_size``.
        """
        padded = F.pad(data, (0, 0, 0, self.pad_size))
        first = self.tp_rank_in_group * self.slice_size
        return padded[first:first + self.slice_size]

    def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
        """Pad ``data`` if required and return this rank's chunk.

        The tensor is split into ``tp_size`` chunks along dim 0 with
        ``torch.tensor_split``; no communication call is made here.
        """
        if self.padding_flag:
            to_split = F.pad(data, (0, 0, 0, self.pad_size))
        else:
            to_split = data
        chunks = torch.tensor_split(to_split, self.tp_size, dim=0)
        return chunks[self.tp_rank_in_group]
|
||||||
|
|
||||||
|
|
||||||
|
def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
    """Create the ``MetadataForPadding`` for the current forward pass.

    An inactive metadata object (``padding_flag=False``,
    ``not_dummy_and_is_prefill=False``) is returned when sequence
    parallelism is disabled, when the forward context has no attention
    metadata, or when that metadata signals no prefill work (neither a
    truthy ``is_only_prefill`` nor ``num_prefills > 0``).

    Otherwise ``input_ids.shape[0]`` is rounded up to a multiple of the
    tensor-parallel world size and the returned metadata records the
    padded length, the unpadded length, the pad size and whether any
    padding is needed.
    """
    if not enable_sequence_parallelism:
        return MetadataForPadding(padding_flag=False,
                                  not_dummy_and_is_prefill=False)

    attn_metadata = get_forward_context().attn_metadata
    tp_size = get_tensor_model_parallel_world_size()

    if attn_metadata is None:
        return MetadataForPadding(padding_flag=False,
                                  not_dummy_and_is_prefill=False)

    # Either signal is enough to treat the batch as prefill.
    only_prefill = (hasattr(attn_metadata, 'is_only_prefill')
                    and attn_metadata.is_only_prefill)
    any_prefill = (hasattr(attn_metadata, 'num_prefills')
                   and attn_metadata.num_prefills > 0)
    if not (only_prefill or any_prefill):
        return MetadataForPadding(padding_flag=False,
                                  not_dummy_and_is_prefill=False)

    lengths_sum_unpadding = input_ids.shape[0]
    # Round up to the next multiple of tp_size.
    lengths_sum_padding = (
        (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
    pad_size = lengths_sum_padding - lengths_sum_unpadding
    padding_flag = lengths_sum_unpadding != lengths_sum_padding

    return MetadataForPadding(
        lengths_sum_unpadding=lengths_sum_unpadding,
        lengths_sum_padding=lengths_sum_padding,
        padding_flag=padding_flag,
        pad_size=pad_size,
        not_dummy_and_is_prefill=True)
|
||||||
@@ -195,6 +195,12 @@ class NPUPlatform(Platform):
|
|||||||
ascend_config.ascend_scheduler_config)
|
ascend_config.ascend_scheduler_config)
|
||||||
vllm_config.scheduler_config = ascend_scheduler_config
|
vllm_config.scheduler_config = ascend_scheduler_config
|
||||||
|
|
||||||
|
if compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
|
if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV."
|
||||||
|
)
|
||||||
|
|
||||||
# register Ascend CustomOp
|
# register Ascend CustomOp
|
||||||
register_ascend_customop()
|
register_ascend_customop()
|
||||||
|
|
||||||
|
|||||||
@@ -1160,6 +1160,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
with_prefill = attn_state not in [
|
with_prefill = attn_state not in [
|
||||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||||
]
|
]
|
||||||
|
|
||||||
|
is_only_prefill = bool(np.all(num_valid_tokens != 1))
|
||||||
|
extra_builder_kwargs['is_only_prefill'] = is_only_prefill
|
||||||
|
|
||||||
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||||
attn_state,
|
attn_state,
|
||||||
total_num_scheduled_tokens)
|
total_num_scheduled_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user