[Refactor] [SP]The sequence parallelism characteristics in the MoE and Dense models are integrated into a single solution. (#3085)
What this PR does / why we need it?
there are two sets of sp implementations for moe and dense models. One
is called sequence_parallelism, and the other is flashcomm_v1.
We did the following things:
Merge two sets of code with the same implementation into one.
Remove the implementation of sequence_parallelism, as this solution
cannot support aclgraph.
Does this PR introduce any user-facing change?
No
How was this patch tested?
e2e&ut
- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
|||||||
set_forward_context)
|
set_forward_context)
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
from vllm_ascend.utils import enable_sp
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEState(Enum):
|
class FusedMoEState(Enum):
|
||||||
@@ -101,21 +102,19 @@ def set_ascend_forward_context(
|
|||||||
# due to multiple warmups before actual capturing
|
# due to multiple warmups before actual capturing
|
||||||
forward_context.capturing = False
|
forward_context.capturing = False
|
||||||
|
|
||||||
# set for flashcomm_v1, 1000 is the batchsize concurrency threshold for enabling the flashcomm_v1 feature.
|
# set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
|
||||||
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
|
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
|
||||||
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
|
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
|
||||||
# the performance may degrade due to the switching of communication methods.
|
# the performance may degrade due to the switching of communication methods.
|
||||||
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
|
sp_enabled = enable_sp() and \
|
||||||
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
|
|
||||||
tp_world_size > 1 and \
|
tp_world_size > 1 and \
|
||||||
num_tokens is not None and num_tokens > 1000
|
num_tokens is not None and num_tokens > 1000
|
||||||
|
|
||||||
if flashcomm_v1_enabled:
|
if sp_enabled:
|
||||||
pad_size = (tp_world_size -
|
pad_size = (tp_world_size -
|
||||||
(num_tokens % tp_world_size)) % tp_world_size
|
(num_tokens % tp_world_size)) % tp_world_size
|
||||||
forward_context.pad_size = pad_size
|
forward_context.pad_size = pad_size
|
||||||
|
forward_context.sp_enabled = sp_enabled
|
||||||
forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
|
|
||||||
|
|
||||||
# set this for rope forward_oot using
|
# set this for rope forward_oot using
|
||||||
forward_context.is_first_layer = True
|
forward_context.is_first_layer = True
|
||||||
|
|||||||
@@ -163,7 +163,6 @@ class AscendMetadata:
|
|||||||
|
|
||||||
# *************************** Other Properties *************************** #
|
# *************************** Other Properties *************************** #
|
||||||
enable_dbo_across_dp: bool = False
|
enable_dbo_across_dp: bool = False
|
||||||
is_only_prefill: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionMetadataBuilder:
|
class AscendAttentionMetadataBuilder:
|
||||||
@@ -236,8 +235,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
attn_state=attn_state,
|
attn_state=attn_state,
|
||||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
|
||||||
is_only_prefill=common_attn_metadata.is_only_prefill)
|
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
def build_for_graph_capture(
|
def build_for_graph_capture(
|
||||||
|
|||||||
@@ -17,14 +17,14 @@
|
|||||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
|
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||||
get_tp_group)
|
get_tp_group)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
@@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
|||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
PPMissingLayer, extract_layer_index,
|
PPMissingLayer, extract_layer_index,
|
||||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
|
|
||||||
init_metadata_for_sp)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||||
@@ -100,7 +97,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attn_metadata=None,
|
attn_metadata=None,
|
||||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
|
||||||
):
|
):
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
@@ -119,7 +115,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
|||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
enable_force_load_balance=enable_force_load_balance,
|
||||||
shared_experts=None,
|
shared_experts=None,
|
||||||
_metadata_for_padding=_metadata_for_padding,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -188,60 +183,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
|||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
self.enable_sequence_parallelism = (
|
|
||||||
vllm_config.compilation_config.pass_config.
|
|
||||||
enable_sequence_parallelism if vllm_config is not None else False)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
residual: Optional[torch.Tensor],
|
|
||||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
# To prevent precision issues during the decoder phase when only prefilling enables SP
|
|
||||||
if not self.enable_sequence_parallelism:
|
|
||||||
self.self_attn.o_proj.reduce_results = True
|
|
||||||
else:
|
|
||||||
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
|
|
||||||
|
|
||||||
# Self Attention
|
|
||||||
if residual is None:
|
|
||||||
residual = hidden_states
|
|
||||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
|
||||||
residual = _metadata_for_padding.padding_slice(residual)
|
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
else:
|
|
||||||
hidden_states, residual = self.input_layernorm(
|
|
||||||
hidden_states, residual)
|
|
||||||
|
|
||||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
|
||||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
|
||||||
hidden_states)
|
|
||||||
|
|
||||||
hidden_states = self.self_attn(
|
|
||||||
positions=positions,
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
)
|
|
||||||
|
|
||||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
|
||||||
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
|
|
||||||
hidden_states)
|
|
||||||
|
|
||||||
# Fully Connected
|
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
|
||||||
hidden_states, residual)
|
|
||||||
|
|
||||||
if not self.use_aclgraph:
|
|
||||||
hidden_states = self.mlp(
|
|
||||||
hidden_states, _metadata_for_padding=_metadata_for_padding)
|
|
||||||
else:
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states, residual
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||||
@@ -277,45 +218,6 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
|
|||||||
make_empty_intermediate_tensors_factory(
|
make_empty_intermediate_tensors_factory(
|
||||||
["hidden_states", "residual"], config.hidden_size))
|
["hidden_states", "residual"], config.hidden_size))
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
||||||
if get_pp_group().is_first_rank:
|
|
||||||
if inputs_embeds is not None:
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
else:
|
|
||||||
hidden_states = self.get_input_embeddings(input_ids)
|
|
||||||
residual = None
|
|
||||||
else:
|
|
||||||
assert intermediate_tensors is not None
|
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
|
||||||
residual = intermediate_tensors["residual"]
|
|
||||||
for i in range(self.start_layer, self.end_layer):
|
|
||||||
layer = self.layers[i]
|
|
||||||
hidden_states, residual = layer(
|
|
||||||
positions,
|
|
||||||
hidden_states,
|
|
||||||
residual,
|
|
||||||
_metadata_for_padding=_metadata_for_padding)
|
|
||||||
if not get_pp_group().is_last_rank:
|
|
||||||
return IntermediateTensors({
|
|
||||||
"hidden_states": hidden_states,
|
|
||||||
"residual": residual
|
|
||||||
})
|
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
|
||||||
|
|
||||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
|
||||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
|
||||||
hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||||
|
|
||||||
@@ -340,7 +242,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.model.make_empty_intermediate_tensors)
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
|
|
||||||
# Set MoE hyperparameters
|
# Set MoE hyperparameters
|
||||||
self.expert_weights: list[torch.Tensor] = []
|
self.expert_weights: list[torch.Tensor] = []
|
||||||
|
|
||||||
@@ -361,16 +262,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
|||||||
self.num_moe_layers = len(self.moe_layers)
|
self.num_moe_layers = len(self.moe_layers)
|
||||||
self.num_expert_groups = 1
|
self.num_expert_groups = 1
|
||||||
self.num_shared_experts = 0
|
self.num_shared_experts = 0
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
||||||
_metadata_for_padding = init_metadata_for_sp(
|
|
||||||
input_ids, self.enable_sequence_parallelism)
|
|
||||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
|
||||||
inputs_embeds, _metadata_for_padding)
|
|
||||||
return hidden_states
|
|
||||||
|
|||||||
@@ -216,7 +216,9 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||||
hidden_states=hidden_states, router_logits=router_logits)
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
replace_allreduce=forward_context.sp_enabled)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply(
|
final_hidden_states = self.quant_method.apply(
|
||||||
|
|||||||
@@ -21,8 +21,7 @@ from typing import Any, Callable, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
get_tensor_model_parallel_world_size)
|
|
||||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||||
get_tp_group)
|
get_tp_group)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
@@ -42,7 +41,6 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
|||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
|
||||||
get_all_reduce_merge_state,
|
get_all_reduce_merge_state,
|
||||||
get_rm_router_logits_state, is_310p,
|
get_rm_router_logits_state, is_310p,
|
||||||
@@ -360,8 +358,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
shared_experts: Optional[Any] = None,
|
shared_experts: Optional[Any] = None,
|
||||||
gate=None,
|
gate=None,
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False):
|
||||||
_metadata_for_padding: Optional[MetadataForPadding] = None):
|
|
||||||
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
@@ -379,13 +376,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||||
shared_hidden_states = shared_experts(hidden_states)
|
shared_hidden_states = shared_experts(hidden_states)
|
||||||
|
|
||||||
enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
|
if forward_context.sp_enabled:
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
if enable_sp:
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
|
|
||||||
chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
|
|
||||||
mc2_mask = chunk_mc2_mask[tp_rank]
|
|
||||||
replace_allreduce = True
|
replace_allreduce = True
|
||||||
|
|
||||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||||
|
|||||||
@@ -48,8 +48,9 @@ from vllm.distributed.parallel_state import get_tp_group
|
|||||||
|
|
||||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||||
get_otp_group)
|
get_otp_group)
|
||||||
from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
|
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
|
||||||
mlp_tp_enable, oproj_tp_enable)
|
matmul_allreduce_enable, mlp_tp_enable,
|
||||||
|
oproj_tp_enable)
|
||||||
|
|
||||||
|
|
||||||
class CustomTensorParallelOp:
|
class CustomTensorParallelOp:
|
||||||
@@ -82,10 +83,17 @@ class CustomTensorParallelOp:
|
|||||||
self.skip_bias_add = self.layer.skip_bias_add
|
self.skip_bias_add = self.layer.skip_bias_add
|
||||||
self.return_bias = self.layer.return_bias
|
self.return_bias = self.layer.return_bias
|
||||||
self.quant_method = self.layer.quant_method
|
self.quant_method = self.layer.quant_method
|
||||||
|
self.prefix = self.layer.prefix
|
||||||
|
|
||||||
|
def apply_impl(self, input_):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
# Replace layer.forward to customize the layer computation process.
|
# Replace layer.forward to customize the layer computation process.
|
||||||
def apply(self, input_):
|
def apply(self, input_):
|
||||||
raise NotImplementedError
|
output, output_bias = self.apply_impl(input_)
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class CustomColumnParallelOp(CustomTensorParallelOp):
|
class CustomColumnParallelOp(CustomTensorParallelOp):
|
||||||
@@ -113,6 +121,14 @@ class CustomRowParallelOp(CustomTensorParallelOp):
|
|||||||
self.reduce_results = self.layer.reduce_results
|
self.reduce_results = self.layer.reduce_results
|
||||||
self.input_size_per_partition = self.layer.input_size_per_partition
|
self.input_size_per_partition = self.layer.input_size_per_partition
|
||||||
|
|
||||||
|
def apply(self, input_):
|
||||||
|
output, output_bias = self.apply_impl(input_)
|
||||||
|
if dense_optim_enable():
|
||||||
|
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class MLPColumnParallelOp(CustomColumnParallelOp):
|
class MLPColumnParallelOp(CustomColumnParallelOp):
|
||||||
|
|
||||||
@@ -123,7 +139,7 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
|
|||||||
def comm_group(self):
|
def comm_group(self):
|
||||||
return get_mlp_tp_group()
|
return get_mlp_tp_group()
|
||||||
|
|
||||||
def apply(
|
def apply_impl(
|
||||||
self,
|
self,
|
||||||
input_: torch.Tensor,
|
input_: torch.Tensor,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
@@ -134,14 +150,12 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
|
|||||||
output = self.quant_method.apply(self.layer, input_parallel, bias)
|
output = self.quant_method.apply(self.layer, input_parallel, bias)
|
||||||
|
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
if not self.return_bias:
|
|
||||||
return output
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):
|
class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
|
||||||
|
|
||||||
def apply(
|
def apply_impl(
|
||||||
self, input_: torch.Tensor
|
self, input_: torch.Tensor
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
"""Linear layer with column parallelism.
|
"""Linear layer with column parallelism.
|
||||||
@@ -164,18 +178,16 @@ class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):
|
|||||||
else:
|
else:
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
if not self.return_bias:
|
|
||||||
return output
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class DenseOptimQKVParallelOp(CustomColumnParallelOp):
|
class SequenceQKVParallelOp(CustomColumnParallelOp):
|
||||||
|
|
||||||
def __init__(self, layer, prefix):
|
def __init__(self, layer, prefix):
|
||||||
super().__init__(layer)
|
super().__init__(layer)
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
def apply(
|
def apply_impl(
|
||||||
self, input_: torch.Tensor
|
self, input_: torch.Tensor
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
"""Linear layer with column parallelism.
|
"""Linear layer with column parallelism.
|
||||||
@@ -201,8 +213,6 @@ class DenseOptimQKVParallelOp(CustomColumnParallelOp):
|
|||||||
else:
|
else:
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
if not self.return_bias:
|
|
||||||
return output
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
@@ -215,7 +225,7 @@ class MLPRowParallelOp(CustomRowParallelOp):
|
|||||||
def comm_group(self):
|
def comm_group(self):
|
||||||
return get_mlp_tp_group()
|
return get_mlp_tp_group()
|
||||||
|
|
||||||
def apply(
|
def apply_impl(
|
||||||
self, input_: torch.Tensor
|
self, input_: torch.Tensor
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
if self.input_is_parallel:
|
if self.input_is_parallel:
|
||||||
@@ -234,8 +244,6 @@ class MLPRowParallelOp(CustomRowParallelOp):
|
|||||||
output = self.comm_group.reduce_scatter(output_parallel, 0)
|
output = self.comm_group.reduce_scatter(output_parallel, 0)
|
||||||
|
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
if not self.return_bias:
|
|
||||||
return output
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
@@ -248,7 +256,7 @@ class OProjRowParallelOp(CustomRowParallelOp):
|
|||||||
def comm_group(self):
|
def comm_group(self):
|
||||||
return get_otp_group()
|
return get_otp_group()
|
||||||
|
|
||||||
def apply(
|
def apply_impl(
|
||||||
self,
|
self,
|
||||||
input_: torch.Tensor,
|
input_: torch.Tensor,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
@@ -294,8 +302,6 @@ class OProjRowParallelOp(CustomRowParallelOp):
|
|||||||
|
|
||||||
# Handle bias return based on configuration
|
# Handle bias return based on configuration
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
if not self.return_bias:
|
|
||||||
return output
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
def update_attrs(self):
|
def update_attrs(self):
|
||||||
@@ -311,7 +317,7 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
|||||||
super().__init__(layer)
|
super().__init__(layer)
|
||||||
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
|
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
|
||||||
|
|
||||||
def apply(
|
def apply_impl(
|
||||||
self, input_: torch.Tensor
|
self, input_: torch.Tensor
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
if self.input_is_parallel:
|
if self.input_is_parallel:
|
||||||
@@ -335,8 +341,6 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
|||||||
bias=bias_)
|
bias=bias_)
|
||||||
|
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
if not self.return_bias:
|
|
||||||
return output
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -359,13 +363,13 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
|||||||
self.weight_t = self.layer.weight.t()
|
self.weight_t = self.layer.weight.t()
|
||||||
|
|
||||||
|
|
||||||
class DenseOptimRowParallelOp(CustomRowParallelOp):
|
class SequenceRowParallelOp(CustomRowParallelOp):
|
||||||
|
|
||||||
def __init__(self, layer, prefix):
|
def __init__(self, layer, prefix):
|
||||||
super().__init__(layer)
|
super().__init__(layer)
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
def apply(
|
def apply_impl(
|
||||||
self, input_: torch.Tensor
|
self, input_: torch.Tensor
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
"""Linear layer with column parallelism.
|
"""Linear layer with column parallelism.
|
||||||
@@ -391,12 +395,8 @@ class DenseOptimRowParallelOp(CustomRowParallelOp):
|
|||||||
input_parallel,
|
input_parallel,
|
||||||
bias=bias_)
|
bias=bias_)
|
||||||
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
|
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
|
||||||
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
|
|
||||||
|
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
|
||||||
if not self.return_bias:
|
|
||||||
return output
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
def update_attrs(self):
|
def update_attrs(self):
|
||||||
@@ -407,23 +407,22 @@ class DenseOptimRowParallelOp(CustomRowParallelOp):
|
|||||||
|
|
||||||
def get_column_parallel_op(
|
def get_column_parallel_op(
|
||||||
disable_tp, prefix, layer
|
disable_tp, prefix, layer
|
||||||
) -> Tuple[
|
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
|
||||||
Optional[Union[MLPColumnParallelOp, DenseOptimMergedColumnParallelOp,
|
SequenceQKVParallelOp]], int, int]:
|
||||||
DenseOptimQKVParallelOp]], int, int]:
|
|
||||||
if disable_tp:
|
if disable_tp:
|
||||||
return None, 0, 1
|
return None, 0, 1
|
||||||
|
|
||||||
custom_op: Optional[Union[
|
custom_op: Optional[Union[
|
||||||
MLPColumnParallelOp,
|
MLPColumnParallelOp,
|
||||||
DenseOptimMergedColumnParallelOp,
|
SequenceMergedColumnParallelOp,
|
||||||
DenseOptimQKVParallelOp,
|
SequenceQKVParallelOp,
|
||||||
]] = None
|
]] = None
|
||||||
if "gate_up_proj" in prefix and mlp_tp_enable():
|
if "gate_up_proj" in prefix and mlp_tp_enable():
|
||||||
custom_op = MLPColumnParallelOp(layer)
|
custom_op = MLPColumnParallelOp(layer)
|
||||||
elif "gate_up_proj" in prefix and dense_optim_enable():
|
elif "gate_up_proj" in prefix and enable_sp():
|
||||||
custom_op = DenseOptimMergedColumnParallelOp(layer)
|
custom_op = SequenceMergedColumnParallelOp(layer)
|
||||||
elif dense_optim_enable():
|
elif enable_sp():
|
||||||
custom_op = DenseOptimQKVParallelOp(layer, prefix)
|
custom_op = SequenceQKVParallelOp(layer, prefix)
|
||||||
|
|
||||||
if custom_op is not None:
|
if custom_op is not None:
|
||||||
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
||||||
@@ -435,21 +434,21 @@ def get_row_parallel_op(
|
|||||||
disable_tp, prefix, layer
|
disable_tp, prefix, layer
|
||||||
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||||
MatmulAllreduceRowParallelOp,
|
MatmulAllreduceRowParallelOp,
|
||||||
DenseOptimRowParallelOp]], int, int]:
|
SequenceRowParallelOp]], int, int]:
|
||||||
if disable_tp:
|
if disable_tp:
|
||||||
return None, 0, 1
|
return None, 0, 1
|
||||||
|
|
||||||
custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||||
MatmulAllreduceRowParallelOp,
|
MatmulAllreduceRowParallelOp,
|
||||||
DenseOptimRowParallelOp]] = None
|
SequenceRowParallelOp]] = None
|
||||||
if "down_proj" in prefix and mlp_tp_enable():
|
if "down_proj" in prefix and mlp_tp_enable():
|
||||||
custom_op = MLPRowParallelOp(layer)
|
custom_op = MLPRowParallelOp(layer)
|
||||||
elif "o_proj" in prefix and oproj_tp_enable():
|
elif "o_proj" in prefix and oproj_tp_enable():
|
||||||
custom_op = OProjRowParallelOp(layer)
|
custom_op = OProjRowParallelOp(layer)
|
||||||
elif matmul_allreduce_enable():
|
elif matmul_allreduce_enable():
|
||||||
custom_op = MatmulAllreduceRowParallelOp(layer)
|
custom_op = MatmulAllreduceRowParallelOp(layer)
|
||||||
elif dense_optim_enable():
|
elif enable_sp():
|
||||||
custom_op = DenseOptimRowParallelOp(layer, prefix)
|
custom_op = SequenceRowParallelOp(layer, prefix)
|
||||||
|
|
||||||
if custom_op is not None:
|
if custom_op is not None:
|
||||||
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
||||||
|
|||||||
@@ -133,11 +133,15 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
|||||||
"""
|
"""
|
||||||
self.replace_allreduce = replace_allreduce
|
self.replace_allreduce = replace_allreduce
|
||||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
mc2_mask = forward_context.mc2_mask
|
||||||
|
if self.tp_size > 1:
|
||||||
|
# Also slice mc2_mask
|
||||||
|
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
|
||||||
|
mc2_mask = split_mc2_mask[self.tp_rank]
|
||||||
|
|
||||||
if not self.replace_allreduce:
|
if not self.replace_allreduce:
|
||||||
self.num_tokens, _ = hidden_states.shape
|
self.num_tokens, _ = hidden_states.shape
|
||||||
forward_context = get_forward_context()
|
|
||||||
mc2_mask = forward_context.mc2_mask
|
|
||||||
target_pad_length = forward_context.padded_num_tokens
|
target_pad_length = forward_context.padded_num_tokens
|
||||||
pad_size = target_pad_length - self.num_tokens
|
pad_size = target_pad_length - self.num_tokens
|
||||||
|
|
||||||
@@ -149,23 +153,16 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
|||||||
(0, 0, 0, pad_size))
|
(0, 0, 0, pad_size))
|
||||||
|
|
||||||
# Slice across TP ranks
|
# Slice across TP ranks
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1 and not self.enable_shared_expert_dp:
|
||||||
if not self.enable_shared_expert_dp:
|
split_hidden_states = torch.tensor_split(hidden_states,
|
||||||
split_hidden_states = torch.tensor_split(hidden_states,
|
self.tp_size,
|
||||||
self.tp_size,
|
dim=0)
|
||||||
dim=0)
|
split_router_logits = torch.tensor_split(router_logits,
|
||||||
split_router_logits = torch.tensor_split(router_logits,
|
self.tp_size,
|
||||||
self.tp_size,
|
dim=0)
|
||||||
dim=0)
|
hidden_states = split_hidden_states[self.tp_rank]
|
||||||
hidden_states = split_hidden_states[self.tp_rank]
|
router_logits = split_router_logits[self.tp_rank]
|
||||||
router_logits = split_router_logits[self.tp_rank]
|
self.split_hidden_states = split_hidden_states # Save for finalize
|
||||||
self.split_hidden_states = split_hidden_states # Save for finalize
|
|
||||||
|
|
||||||
# Also slice mc2_mask
|
|
||||||
split_mc2_mask = torch.tensor_split(mc2_mask,
|
|
||||||
self.tp_size,
|
|
||||||
dim=0)
|
|
||||||
mc2_mask = split_mc2_mask[self.tp_rank]
|
|
||||||
|
|
||||||
return hidden_states, router_logits, mc2_mask
|
return hidden_states, router_logits, mc2_mask
|
||||||
|
|
||||||
|
|||||||
@@ -20,10 +20,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
|
|||||||
return residual
|
return residual
|
||||||
|
|
||||||
if x.size(0) != residual.size(0):
|
if x.size(0) != residual.size(0):
|
||||||
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
sp_enabled = forward_context.sp_enabled
|
||||||
assert flashcomm_v1_enabled is True, (
|
assert sp_enabled is True, ("Currently, this situation only occurs "
|
||||||
"Currently, this situation only occurs "
|
"when sp is enabled")
|
||||||
"when flashcomm_v1 is enabled")
|
|
||||||
pad_size = forward_context.pad_size
|
pad_size = forward_context.pad_size
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||||
@@ -41,8 +40,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
|||||||
except AssertionError:
|
except AssertionError:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
sp_enabled = forward_context.sp_enabled
|
||||||
if flashcomm_v1_enabled and label:
|
if sp_enabled and label:
|
||||||
x = tensor_model_parallel_all_gather(x, 0)
|
x = tensor_model_parallel_all_gather(x, 0)
|
||||||
pad_size = forward_context.pad_size
|
pad_size = forward_context.pad_size
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
@@ -56,8 +55,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
|||||||
except AssertionError:
|
except AssertionError:
|
||||||
return tensor_model_parallel_all_reduce(x)
|
return tensor_model_parallel_all_reduce(x)
|
||||||
|
|
||||||
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
sp_enabled = forward_context.sp_enabled
|
||||||
if flashcomm_v1_enabled:
|
if sp_enabled:
|
||||||
pad_size = forward_context.pad_size
|
pad_size = forward_context.pad_size
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
x = F.pad(x, (0, 0, 0, pad_size))
|
x = F.pad(x, (0, 0, 0, pad_size))
|
||||||
|
|||||||
@@ -282,12 +282,6 @@ class NPUPlatform(Platform):
|
|||||||
ascend_config.ascend_scheduler_config)
|
ascend_config.ascend_scheduler_config)
|
||||||
vllm_config.scheduler_config = ascend_scheduler_config
|
vllm_config.scheduler_config = ascend_scheduler_config
|
||||||
|
|
||||||
if compilation_config.pass_config.enable_sequence_parallelism:
|
|
||||||
if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
|
|
||||||
raise NotImplementedError(
|
|
||||||
"For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls,
|
def get_attn_backend_cls(cls,
|
||||||
selected_backend,
|
selected_backend,
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
|
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
|
||||||
init_metadata_for_sp)
|
init_metadata_for_sp)
|
||||||
|
|
||||||
|
|
||||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|||||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
||||||
determine_default_log2phy_map)
|
determine_default_log2phy_map)
|
||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
|
||||||
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
||||||
|
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
|
||||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||||
get_all_reduce_merge_state,
|
get_all_reduce_merge_state,
|
||||||
|
|||||||
@@ -590,6 +590,14 @@ def dense_optim_enable() -> bool:
|
|||||||
return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
|
return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
|
||||||
|
|
||||||
|
|
||||||
|
def enable_sp() -> bool:
|
||||||
|
from vllm.config import get_cached_compilation_config
|
||||||
|
|
||||||
|
return (
|
||||||
|
get_cached_compilation_config().pass_config.enable_sequence_parallelism
|
||||||
|
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)
|
||||||
|
|
||||||
|
|
||||||
def is_moe_model(vllm_config: VllmConfig):
|
def is_moe_model(vllm_config: VllmConfig):
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
return any('experts' in key.lower() for key in config.to_dict())
|
return any('experts' in key.lower() for key in config.to_dict())
|
||||||
|
|||||||
@@ -1582,7 +1582,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
update_attn_params(self.update_stream, forward_context,
|
update_attn_params(self.update_stream, forward_context,
|
||||||
positions.shape[0])
|
positions.shape[0])
|
||||||
|
|
||||||
if get_forward_context().flashcomm_v1_enabled:
|
if get_forward_context().sp_enabled:
|
||||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||||
pad_size = get_forward_context().pad_size
|
pad_size = get_forward_context().pad_size
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user