[Feat] support sequence parallelism bypass for VL models (#5632)
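The hunks below rename the sequence-parallel gate: forward_context.sp_enabled becomes forward_context.flash_comm_v1_enabled, and the enable_sp() helper becomes enable_flash_comm_v1(), so the flashcomm-v1 path is named explicitly rather than as generic SP. For orientation, here is a minimal sketch of the communication pattern that flag gates (pad the token dimension to a multiple of the TP size and reduce-scatter on the way out of a row-parallel projection; all-gather and unpad on the way back in). The helper names and the plain torch.distributed collectives are assumptions for illustration, not the vllm-ascend ops themselves (those live behind torch.ops.vllm.maybe_all_gather_and_maybe_unpad / maybe_pad_and_reduce, shown in the diff).

# Illustrative sketch only: assumed helper names, plain torch.distributed collectives.
import torch
import torch.distributed as dist
import torch.nn.functional as F

def pad_and_reduce_scatter(x: torch.Tensor, tp_size: int) -> torch.Tensor:
    """Pad the token dim of a [tokens, hidden] tensor to a multiple of tp_size, then reduce-scatter it."""
    pad = (-x.size(0)) % tp_size
    if pad > 0:
        x = F.pad(x, (0, 0, 0, pad))  # pad rows, same layout as F.pad(residual, (0, 0, 0, pad_size)) below
    out = torch.empty(x.size(0) // tp_size, *x.shape[1:], dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(out, x)
    return out  # each rank keeps 1/tp_size of the padded tokens

def all_gather_and_unpad(x: torch.Tensor, num_tokens: int, tp_size: int) -> torch.Tensor:
    """Gather the token shards from all TP ranks and drop the padding rows."""
    out = torch.empty(x.size(0) * tp_size, *x.shape[1:], dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(out, x)
    return out[:num_tokens]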
@@ -440,7 +440,7 @@ class AscendFusedMoE(FusedMoE):
         hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
             hidden_states=hidden_states,
             router_logits=router_logits,
-            replace_allreduce=forward_context.sp_enabled,
+            replace_allreduce=forward_context.flash_comm_v1_enabled,
             enable_shared_expert_dp=self.enable_shared_expert_dp,
             quant_type=self.quant_type,
         )
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
-from vllm_ascend.utils import enable_sp, maybe_trans_nz
+from vllm_ascend.utils import enable_flash_comm_v1, maybe_trans_nz


 class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
@@ -240,7 +240,7 @@ class AscendRowParallelLinear(RowParallelLinear):
         disable_tp: bool = False,
     ):
         # TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM.
-        if enable_sp():
+        if enable_flash_comm_v1():
             compilation_config = get_current_vllm_config().compilation_config
             unique_prefix = prefix
             if prefix in compilation_config.static_forward_context:
@@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
 from vllm_ascend.utils import (
     enable_dsa_cp,
     enable_dsa_cp_with_layer_shard,
-    enable_sp,
+    enable_flash_comm_v1,
     flashcomm2_enable,
     get_flashcomm2_reorgnized_batch_ids,
     get_weight_prefetch_method,
@@ -368,7 +368,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
         else:
             output = output_parallel

-        if not forward_context.sp_enabled:
+        if not forward_context.flash_comm_v1_enabled:
             # flashcomm1 not enabled
             output = get_tp_group().all_gather(output, 0)
             if num_padding_tokens > 0:
@@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
         # Matrix multiply.
         assert self.quant_method is not None

-        if enable_sp():
+        if enable_flash_comm_v1():
             input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)

         # Trigger async broadcast before matmul to overlap communication.
@@ -515,15 +515,15 @@ class SequenceRowParallelOp(CustomRowParallelOp):
         assert self.quant_method is not None
         try:
             forward_context = get_forward_context()
-            sp_enabled = forward_context.sp_enabled
+            flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
             mmrs_fusion = forward_context.mmrs_fusion
         except AssertionError:
-            sp_enabled = False
+            flash_comm_v1_enabled = False
             mmrs_fusion = False

         x = input_parallel

-        if not sp_enabled:
+        if not flash_comm_v1_enabled:
             output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
             return tensor_model_parallel_all_reduce(output_parallel)
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
     if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
         if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
             return Flashcomm2OshardQKVParallelOp(layer)
-    if enable_sp():
+    if enable_flash_comm_v1():
         if "shared_expert" in prefix:
             return None
         sp_column_prefix = [
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
     if flashcomm2_enable():
         if "o_proj" in prefix or "out_proj" in prefix:
             return Flashcomm2OProjRowParallelOp(layer)
-    if enable_sp():
+    if enable_flash_comm_v1():
         if "shared_expert" in prefix:
             return None
         sp_row_prefixes = [
@@ -150,7 +150,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
         kv_cache: torch.Tensor | None = None,
         attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
-        need_gather_q_kv = get_forward_context().sp_enabled
+        need_gather_q_kv = get_forward_context().flash_comm_v1_enabled
         output_shape = hidden_states.shape
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
@@ -26,8 +26,6 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch
         return residual

     if x.size(0) != residual.size(0):
-        sp_enabled = forward_context.sp_enabled
-        assert sp_enabled is True, "Currently, this situation only occurs when sp is enabled"
         pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,8 +42,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
     except AssertionError:
         return x

-    sp_enabled = forward_context.sp_enabled
-    if sp_enabled and label:
+    flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
+    if flash_comm_v1_enabled and label:
         dp_metadata = forward_context.dp_metadata
         if dp_metadata is None or not is_ep_comm:
             x = tensor_model_parallel_all_gather(x, 0)
@@ -75,7 +73,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)

-    if not getattr(forward_context, "sp_enabled", False):
+    if not getattr(forward_context, "flash_comm_v1_enabled", False):
         return tensor_model_parallel_all_reduce(x)

     dp_metadata = forward_context.dp_metadata
@@ -99,7 +97,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor


 def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
-    if get_forward_context().sp_enabled and label:
+    if get_forward_context().flash_comm_v1_enabled and label:
         return torch.empty(
             (x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
         )
@@ -108,7 +106,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c


 def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
-    if get_forward_context().sp_enabled:
+    if get_forward_context().flash_comm_v1_enabled:
         return torch.empty(
             (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
         )
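The two fake implementations above only describe output shapes for graph tracing: the all-gather fake multiplies the token dimension by the TP world size, while the pad-and-reduce fake divides it, so a gather followed by a reduce-scatter round-trips the per-rank token count. A tiny sanity check of that shape arithmetic, with assumed sizes:

# Shape arithmetic only; sizes are assumed for illustration.
tp_size = 4                # assumed tensor-parallel world size
tokens_per_rank = 128      # assumed per-rank token count after padding

gathered = tokens_per_rank * tp_size   # _maybe_all_gather_and_maybe_unpad_fake: 512 rows
scattered = gathered // tp_size        # _maybe_pad_and_reduce_fake: back to 128 rows
assert scattered == tokens_per_rank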
@@ -141,7 +139,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
 def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
-    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} or forward_context.sp_enabled:
+    if (
+        moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
+        or forward_context.flash_comm_v1_enabled
+    ):
         return final_hidden_states
     else:
         return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -161,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
     forward_context = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
     num_tokens = input_parallel.size(0)
-    if forward_context.sp_enabled:
+    if forward_context.flash_comm_v1_enabled:
         num_tokens = num_tokens // self.tp_size
     output = torch.empty(
         size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
@@ -203,7 +204,7 @@ def _rope_forward_oot_impl_fake(
 direct_register_custom_op(
     op_name="maybe_chunk_residual",
     op_func=_maybe_chunk_residual_impl,
-    fake_impl=lambda x, residual: x,
+    fake_impl=lambda x, residual: torch.empty_like(x),
     mutates_args=[],
     dispatch_key="PrivateUse1",
 )
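The last hunk also changes the fake implementation of maybe_chunk_residual from `lambda x, residual: x` to `lambda x, residual: torch.empty_like(x)`. A fake (meta) implementation that returns one of its inputs declares the output as an alias of that input, which can break aliasing checks during custom-op tracing; returning a freshly allocated tensor with the output's shape and dtype avoids that. A minimal sketch of the same pattern with the stock torch.library decorator API (the op name and both bodies here are illustrative, not the vllm-ascend registration, which goes through direct_register_custom_op as shown above):

import torch

@torch.library.custom_op("demo::maybe_chunk_residual", mutates_args=())
def maybe_chunk_residual(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Illustrative eager body: trim residual to x's token count (clone so the
    # output does not alias an input).
    return residual[: x.size(0)].clone()

@maybe_chunk_residual.register_fake
def _(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Fake/meta body: describe shape and dtype only, never return an input.
    return torch.empty_like(x)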