[Feat]support sequence parallelism by pass for VL models (#5632)
This commit is contained in:
@@ -30,6 +30,7 @@ class AscendConfig:
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.vllm_config = vllm_config
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
|
||||
xlite_graph_config = additional_config.get("xlite_graph_config", {})
|
||||
@@ -160,6 +161,47 @@ class AscendConfig:
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def update_compile_ranges_split_points(self):
    """Register the Ascend pass thresholds as compile-range split points.

    Reads ``self.vllm_config`` and ``self.npugraph_ex_config`` and updates
    ``vllm_config.compilation_config.compile_ranges_split_points``:

    * When npugraph_ex is enabled, ``ALLREDUCE_NORM_FUSE_THRESHOLD`` is added
      if ``npugraph_ex_config.fuse_allreduce_rms`` is set.
    * Otherwise ``ALLREDUCE_NORM_FUSE_THRESHOLD`` is added unless
      ``additional_config["ascend_compilation_config"]["fuse_allreduce_rms"]``
      is falsy (it defaults to True), and the sequence-parallelism threshold
      from ``get_sp_threshold`` is added when ``pass_config.enable_sp`` is set
      and the model is not a MoE model.

    Each added value is appended to the current list and a sorted copy is
    stored back on the compilation config.
    """
    vllm_config = self.vllm_config
    compilation_config = vllm_config.compilation_config
    npugraph_ex_enabled = self.npugraph_ex_config.enable

    def _add_split_point(point):
        # Append to the list currently held by the config, then store a sorted
        # copy. Re-reading the attribute on every call avoids comparing a list
        # with an alias of itself, which previously left the sequence
        # parallelism threshold unsorted at the end of the list.
        points = compilation_config.compile_ranges_split_points
        points.append(point)
        compilation_config.compile_ranges_split_points = sorted(points)
        return compilation_config.compile_ranges_split_points

    if npugraph_ex_enabled:
        fuse_allreduce_rms = self.npugraph_ex_config.fuse_allreduce_rms
    else:
        # additional_config may be None; fall back to an empty mapping.
        additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
        fuse_allreduce_rms = additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True)

    if fuse_allreduce_rms:
        from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD

        split_points = _add_split_point(ALLREDUCE_NORM_FUSE_THRESHOLD)
        # Lazy %-formatting: the previous message was missing its f-prefix and
        # logged the placeholder text instead of the list.
        logger.debug(
            "set compile_ranges_split_points to %s for matmul and allreduce fusion",
            split_points,
        )

    if not npugraph_ex_enabled:
        from vllm_ascend.utils import is_moe_model

        if compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config):
            from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold

            sp_threshold = get_sp_threshold(vllm_config)
            _add_split_point(sp_threshold)
            logger.debug("add %s to compile_ranges_split_points for sequence parallelism", sp_threshold)
|
||||
|
||||
|
||||
class FinegrainedTPConfig:
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,7 @@ import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import (
|
||||
AscendDeviceType,
|
||||
enable_sp,
|
||||
enable_flash_comm_v1,
|
||||
flashcomm2_enable,
|
||||
get_ascend_device_type,
|
||||
has_layer_idx,
|
||||
@@ -92,22 +92,22 @@ def set_ascend_forward_context(
|
||||
# main model and drafter model may have different architecture
|
||||
is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
|
||||
if is_context_moe_model:
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
|
||||
flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None
|
||||
mmrs_fusion = False
|
||||
elif is_draft_model:
|
||||
# TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
|
||||
# Disable it to avoid more problems.
|
||||
sp_enabled = False
|
||||
flash_comm_v1_enabled = False
|
||||
else:
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
|
||||
|
||||
flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None and num_tokens > 1000
|
||||
forward_context.mmrs_fusion = mmrs_fusion
|
||||
forward_context.num_tokens = num_tokens
|
||||
forward_context.sp_enabled = sp_enabled
|
||||
forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
|
||||
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
|
||||
forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
|
||||
|
||||
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
|
||||
forward_context.pad_size = 0
|
||||
if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
|
||||
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
|
||||
forward_context.pad_size = pad_size
|
||||
|
||||
@@ -131,7 +131,7 @@ def set_ascend_forward_context(
|
||||
dp_world_size = get_dp_group().world_size
|
||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
|
||||
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
|
||||
if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
|
||||
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
|
||||
pad_size = padded_length - num_tokens
|
||||
forward_context.padded_length = padded_length
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import copy
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -129,6 +130,10 @@ class AscendCompiler(CompilerInterface):
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
npugraph_ex_config = get_ascend_config().npugraph_ex_config
|
||||
if npugraph_ex_config.enable:
|
||||
assert hasattr(self, "vllm_config")
|
||||
|
||||
@@ -70,3 +70,8 @@ class GraphFusionPassManager:
|
||||
from .passes.allreduce_rmsnorm_fusion_pass import MatmulAllReduceAddRMSNormPass
|
||||
|
||||
self.passes.append(MatmulAllReduceAddRMSNormPass(config))
|
||||
|
||||
if config.compilation_config.pass_config.enable_sp:
|
||||
from .passes.sequence_parallelism import AscendSequenceParallelismPass
|
||||
|
||||
self.passes.append(AscendSequenceParallelismPass(config))
|
||||
|
||||
202
vllm_ascend/compilation/passes/sequence_parallelism.py
Normal file
202
vllm_ascend/compilation/passes/sequence_parallelism.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm_ascend.utils import is_moe_model, vllm_version_is
|
||||
|
||||
if vllm_version_is("0.15.0"):
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass # type: ignore
|
||||
else:
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.logger import logger
|
||||
|
||||
SP_THRESHOLD = 1000
|
||||
|
||||
|
||||
def get_sp_threshold(config: VllmConfig):
    """Return the token-count threshold used by the sequence-parallelism pass.

    MoE models (as reported by ``is_moe_model``) always get ``1``. Otherwise
    the value stored under ``"sp_threshold"`` in ``config.additional_config``
    is returned, falling back to ``SP_THRESHOLD`` when the key is absent or
    ``additional_config`` is None.
    """
    if is_moe_model(config):
        return 1

    overrides = config.additional_config
    if overrides is None:
        overrides = {}
    return overrides.get("sp_threshold", SP_THRESHOLD)
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper:
    """Shared state and collective wrappers for the sequence-parallel patterns.

    Keeps the epsilon, dtype and device handed in by the subclass together
    with the tensor-parallel group, its world size and this rank's index
    inside the group.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
    ):
        self.eps = epsilon
        self.dtype = dtype
        self.device = device

        group = get_tp_group()
        self.tp_group = group
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = group.rank_in_group

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` through ``tensor_model_parallel_all_reduce``."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Call ``torch.ops.vllm.reduce_scatter`` on dim 0 over the TP group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.reduce_scatter(x, dim=0, world_size=self.tp_size, group_name=group_name)

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """Call ``torch.ops.vllm.all_gather`` on dim 0 over the TP group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=group_name)

    def empty(self, *args, **kws):
        """Allocate an uninitialized tensor of ``self.dtype`` on the "npu" device."""
        # NOTE(review): the device is hard-coded to "npu"; ``self.device`` is not used here.
        return torch.empty(*args, dtype=self.dtype, device="npu", **kws)
|
||||
|
||||
|
||||
class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Pattern for ``all_reduce -> npu_add_rms_norm_bias`` returning output and residual.

    The matched subgraph all-reduces ``input`` and feeds it, with ``residual``
    and ``weight``, into ``torch.ops._C_ascend.npu_add_rms_norm_bias``. The
    replacement reduce-scatters ``input`` instead, runs the same op on the
    scattered tensor and all-gathers only the first output.

    ``empty`` is inherited from ``_SequenceParallelPatternHelper``; the
    identical override that used to live here was removed.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        """Build the example ``[input, weight, residual]`` tensors passed to ``register_replacement``."""
        example_input = self.empty(8, 16)
        example_weight = self.empty(16)
        example_residual = self.empty(8, 16)
        return [example_input, example_weight, example_residual]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                reduced, residual, weight, None, self.eps
            )
            return normed, new_residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            # Let the custom op bring the residual in line with the scattered tensor.
            local_residual = torch.ops.vllm.maybe_chunk_residual(scattered, residual)
            normed, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                scattered, local_residual, weight, None, self.eps
            )
            gathered = self._all_gather(normed)
            # The residual is returned without being gathered.
            return gathered, new_residual

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Pattern for ``all_reduce -> npu_add_rms_norm_bias`` where only the first output is used.

    The replacement reduce-scatters ``input``, runs the same op on the
    scattered tensor and all-gathers its first output.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        """Build the example ``[input, weight, residual]`` tensors passed to ``register_replacement``."""
        example_input = self.empty(8, 16)
        example_weight = self.empty(16)
        example_residual = self.empty(8, 16)
        return [example_input, example_weight, example_residual]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> torch.Tensor:
            reduced = self._all_reduce(input)
            normed, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(reduced, residual, weight, None, self.eps)
            return normed

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> torch.Tensor:
            scattered = self._reduce_scatter(input)
            # Let the custom op bring the residual in line with the scattered tensor.
            local_residual = torch.ops.vllm.maybe_chunk_residual(scattered, residual)
            normed, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(
                scattered, local_residual, weight, None, self.eps
            )
            return self._all_gather(normed)

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Pattern for ``all_reduce + deepstack_input_embeds -> npu_add_rms_norm_bias``.

    Same shape as the plain middle pattern, except that
    ``deepstack_input_embeds`` is added to the all-reduced tensor before the
    norm. In the replacement the addend is this rank's chunk of
    ``deepstack_input_embeds`` and it is added to the reduce-scattered tensor.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        """Build the example ``[input, weight, residual, deepstack_input_embeds]`` tensors."""
        example_input = self.empty(8, 16)
        example_weight = self.empty(16)
        example_residual = self.empty(8, 16)
        example_deepstack = self.empty(8, 16)
        return [example_input, example_weight, example_residual, example_deepstack]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            deepstack_input_embeds: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            summed = reduced + deepstack_input_embeds
            normed, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                summed, residual, weight, None, self.eps
            )
            return normed, new_residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            deepstack_input_embeds: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            # Split the embeds into tp_size chunks and keep the one at this rank's index.
            local_embeds = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
            summed = scattered + local_embeds
            local_residual = torch.ops.vllm.maybe_chunk_residual(scattered, residual)
            normed, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                summed, local_residual, weight, None, self.eps
            )
            gathered = self._all_gather(normed)
            # The residual is returned without being gathered.
            return gathered, new_residual

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendSequenceParallelismPass(VllmInductorPass):
    """Inductor pass that applies the Ascend sequence-parallelism patterns.

    Each of the three pattern classes is registered once per epsilon value
    (1e-5 and 1e-6). ``min_tokens`` holds the value of ``get_sp_threshold``
    and gates the compile ranges the pass applies to.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")

        # Registration order is kept: for every epsilon, middle -> last -> Qwen3-VL middle.
        pattern_classes = (
            AscendMiddleAllReduceRMSNormPattern,
            AscendLastAllReduceRMSNormPattern,
            AscendQwen3VLMiddleAllReduceRMSNormPattern,
        )
        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in pattern_classes:
                pattern_cls(config, epsilon).register(self.patterns)

        self.min_tokens = get_sp_threshold(config)

    def __call__(self, graph: torch.fx.Graph):
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.begin()
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        self.end_and_log()

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True when ``compile_range.start`` is at least ``self.min_tokens``."""
        applicable = self.min_tokens <= compile_range.start
        logger.debug(f"SequenceParallelismPass {compile_range=} {applicable=}")
        return applicable
|
||||
@@ -440,7 +440,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
replace_allreduce=forward_context.sp_enabled,
|
||||
replace_allreduce=forward_context.flash_comm_v1_enabled,
|
||||
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
||||
quant_type=self.quant_type,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
|
||||
from vllm_ascend.utils import enable_sp, maybe_trans_nz
|
||||
from vllm_ascend.utils import enable_flash_comm_v1, maybe_trans_nz
|
||||
|
||||
|
||||
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
|
||||
@@ -240,7 +240,7 @@ class AscendRowParallelLinear(RowParallelLinear):
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
# TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM.
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
unique_prefix = prefix
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
|
||||
@@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
||||
from vllm_ascend.utils import (
|
||||
enable_dsa_cp,
|
||||
enable_dsa_cp_with_layer_shard,
|
||||
enable_sp,
|
||||
enable_flash_comm_v1,
|
||||
flashcomm2_enable,
|
||||
get_flashcomm2_reorgnized_batch_ids,
|
||||
get_weight_prefetch_method,
|
||||
@@ -368,7 +368,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
if not forward_context.sp_enabled:
|
||||
if not forward_context.flash_comm_v1_enabled:
|
||||
# flashcomm1 not enabled
|
||||
output = get_tp_group().all_gather(output, 0)
|
||||
if num_padding_tokens > 0:
|
||||
@@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
||||
|
||||
# Trigger async broadcast before matmul to overlap communication.
|
||||
@@ -515,15 +515,15 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
assert self.quant_method is not None
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
|
||||
mmrs_fusion = forward_context.mmrs_fusion
|
||||
except AssertionError:
|
||||
sp_enabled = False
|
||||
flash_comm_v1_enabled = False
|
||||
mmrs_fusion = False
|
||||
|
||||
x = input_parallel
|
||||
|
||||
if not sp_enabled:
|
||||
if not flash_comm_v1_enabled:
|
||||
output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
|
||||
return tensor_model_parallel_all_reduce(output_parallel)
|
||||
|
||||
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
|
||||
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
||||
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
|
||||
return Flashcomm2OshardQKVParallelOp(layer)
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
sp_column_prefix = [
|
||||
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
|
||||
if flashcomm2_enable():
|
||||
if "o_proj" in prefix or "out_proj" in prefix:
|
||||
return Flashcomm2OProjRowParallelOp(layer)
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
sp_row_prefixes = [
|
||||
|
||||
@@ -150,7 +150,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
kv_cache: torch.Tensor | None = None,
|
||||
attn_metadata: AttentionMetadata | None = None,
|
||||
) -> torch.Tensor:
|
||||
need_gather_q_kv = get_forward_context().sp_enabled
|
||||
need_gather_q_kv = get_forward_context().flash_comm_v1_enabled
|
||||
output_shape = hidden_states.shape
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
@@ -26,8 +26,6 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch
|
||||
return residual
|
||||
|
||||
if x.size(0) != residual.size(0):
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
assert sp_enabled is True, "Currently, this situation only occurs when sp is enabled"
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||
@@ -44,8 +42,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
|
||||
except AssertionError:
|
||||
return x
|
||||
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
if sp_enabled and label:
|
||||
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
|
||||
if flash_comm_v1_enabled and label:
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
if dp_metadata is None or not is_ep_comm:
|
||||
x = tensor_model_parallel_all_gather(x, 0)
|
||||
@@ -75,7 +73,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
||||
except AssertionError:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
if not getattr(forward_context, "sp_enabled", False):
|
||||
if not getattr(forward_context, "flash_comm_v1_enabled", False):
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
@@ -99,7 +97,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
if get_forward_context().sp_enabled and label:
|
||||
if get_forward_context().flash_comm_v1_enabled and label:
|
||||
return torch.empty(
|
||||
(x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
|
||||
)
|
||||
@@ -108,7 +106,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c
|
||||
|
||||
|
||||
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
if get_forward_context().sp_enabled:
|
||||
if get_forward_context().flash_comm_v1_enabled:
|
||||
return torch.empty(
|
||||
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
|
||||
)
|
||||
@@ -141,7 +139,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
|
||||
def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} or forward_context.sp_enabled:
|
||||
if (
|
||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||
or forward_context.flash_comm_v1_enabled
|
||||
):
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
@@ -161,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
|
||||
forward_context = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
num_tokens = input_parallel.size(0)
|
||||
if forward_context.sp_enabled:
|
||||
if forward_context.flash_comm_v1_enabled:
|
||||
num_tokens = num_tokens // self.tp_size
|
||||
output = torch.empty(
|
||||
size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
|
||||
@@ -203,7 +204,7 @@ def _rope_forward_oot_impl_fake(
|
||||
direct_register_custom_op(
|
||||
op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: x,
|
||||
fake_impl=lambda x, residual: torch.empty_like(x),
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
|
||||
@@ -48,6 +48,7 @@ from vllm_ascend.utils import (
|
||||
update_aclgraph_sizes,
|
||||
update_cudagraph_capture_sizes,
|
||||
is_310p,
|
||||
enable_flash_comm_v1,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -198,32 +199,9 @@ class NPUPlatform(Platform):
|
||||
if not isinstance(ascend_compilation_config, dict)
|
||||
else ascend_compilation_config
|
||||
)
|
||||
ascend_config.update_compile_ranges_split_points()
|
||||
|
||||
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
|
||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||
logger.debug(
|
||||
"set compile_ranges_split_points to "
|
||||
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||
)
|
||||
|
||||
npugraph_ex_config = ascend_config.npugraph_ex_config
|
||||
if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
|
||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||
logger.debug(
|
||||
"set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||
)
|
||||
|
||||
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
|
||||
if model_config and hasattr(model_config.hf_text_config, "index_topk"):
|
||||
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
|
||||
|
||||
ascend_fusion_config = ascend_config.ascend_fusion_config
|
||||
@@ -417,15 +395,19 @@ class NPUPlatform(Platform):
|
||||
)
|
||||
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
|
||||
|
||||
if is_vl_model(vllm_config):
|
||||
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
|
||||
int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
|
||||
):
|
||||
raise ValueError(
|
||||
"Currently, VL models doesn't support "
|
||||
"FLASHCOMM in vllm-ascend. We will fix this in the future. "
|
||||
"Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
|
||||
)
|
||||
if enable_flash_comm_v1():
|
||||
assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
|
||||
Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
|
||||
For optimal performance with VL models, we recommend enabling Sequence Parallelism \
|
||||
via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
|
||||
|
||||
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
|
||||
"Flash Comm v1 is only supported when tp_size > 1."
|
||||
)
|
||||
|
||||
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
|
||||
"Flash Comm v1 requires enable_expert_parallel=True for MoE models."
|
||||
)
|
||||
|
||||
# Set "PYTORCH_NPU_ALLOC_CONF=expandable_segments:True" by default to optimize NPU memory management.
|
||||
# Find more details at https://docs.vllm.ai/projects/ascend/en/latest/faqs.html#how-to-handle-the-out-of-memory-issue
|
||||
@@ -626,16 +608,16 @@ class NPUPlatform(Platform):
|
||||
# communication methods.
|
||||
mmrs_fusion = True
|
||||
if is_moe_model(vllm_config):
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
|
||||
flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
|
||||
mmrs_fusion = False
|
||||
else:
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
|
||||
flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
|
||||
|
||||
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
|
||||
flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
|
||||
pad_size = None
|
||||
padded_length = None
|
||||
if sp_enabled or flashcomm_v2_enabled:
|
||||
if flash_comm_v1_enabled or flashcomm_v2_enabled:
|
||||
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
@@ -643,7 +625,7 @@ class NPUPlatform(Platform):
|
||||
dp_world_size = get_dp_group().world_size
|
||||
if dp_world_size > 1 and dp_metadata is not None:
|
||||
max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
|
||||
if sp_enabled or flashcomm_v2_enabled:
|
||||
if flash_comm_v1_enabled or flashcomm_v2_enabled:
|
||||
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
|
||||
pad_size = padded_length - num_tokens
|
||||
else:
|
||||
@@ -664,7 +646,7 @@ class NPUPlatform(Platform):
|
||||
"capturing": capturing,
|
||||
"mmrs_fusion": mmrs_fusion,
|
||||
"num_tokens": num_tokens,
|
||||
"sp_enabled": sp_enabled,
|
||||
"flash_comm_v1_enabled": flash_comm_v1_enabled,
|
||||
"flashcomm_v2_enabled": flashcomm_v2_enabled,
|
||||
"pad_size": pad_size,
|
||||
"padded_length": padded_length,
|
||||
|
||||
@@ -1171,7 +1171,7 @@ class EagleProposer(VllmEagleProposer):
|
||||
positions = positions.squeeze(-1)
|
||||
else:
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.sp_enabled:
|
||||
if forward_context.flash_comm_v1_enabled:
|
||||
hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
|
||||
return hidden_states, positions
|
||||
|
||||
@@ -1191,7 +1191,7 @@ class EagleProposer(VllmEagleProposer):
|
||||
hidden_states = last_hidden_states
|
||||
else:
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.sp_enabled:
|
||||
if forward_context.flash_comm_v1_enabled:
|
||||
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
last_hidden_states.contiguous(), True
|
||||
)
|
||||
|
||||
@@ -724,6 +724,19 @@ def matmul_allreduce_enable() -> bool:
|
||||
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||
|
||||
|
||||
def enable_flash_comm_v1():
    """Report whether Flash Comm v1 is switched on through the environment.

    ``envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1`` is the primary switch. The
    ``VLLM_ASCEND_ENABLE_FLASHCOMM`` environment variable is still consulted
    for backward compatibility; it is parsed as an integer and defaults to "0".
    """
    flashcomm1 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
    if flashcomm1:
        return flashcomm1

    # Legacy switch, retained for backward compatibility.
    legacy_value = os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")
    return bool(int(legacy_value))
|
||||
|
||||
|
||||
def enable_sp_by_pass(vllm_config: VllmConfig):
    """Report whether sequence parallelism is enabled through the compilation pass.

    Returns False when ``model_config.enforce_eager`` is set; otherwise returns
    ``compilation_config.pass_config.enable_sp``.
    """
    if vllm_config.model_config.enforce_eager:
        return False
    return vllm_config.compilation_config.pass_config.enable_sp
|
||||
|
||||
|
||||
def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
|
||||
global _ENABLE_SP
|
||||
if _ENABLE_SP is None:
|
||||
@@ -731,29 +744,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
_ENABLE_SP = (
|
||||
vllm_config.compilation_config.pass_config.enable_sp
|
||||
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
||||
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
|
||||
)
|
||||
_ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1()
|
||||
|
||||
if not _ENABLE_SP and enable_shared_expert_dp:
|
||||
_ENABLE_SP = True
|
||||
logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")
|
||||
|
||||
if not _ENABLE_SP:
|
||||
return _ENABLE_SP
|
||||
|
||||
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
|
||||
"Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
|
||||
)
|
||||
|
||||
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
|
||||
"Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
|
||||
)
|
||||
|
||||
return _ENABLE_SP
|
||||
|
||||
|
||||
@@ -1113,7 +1109,7 @@ def enable_dsa_cp() -> bool:
|
||||
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
|
||||
vllm_config.model_config.hf_text_config, "index_topk"
|
||||
)
|
||||
return bool(is_ds_v32 and enable_sp())
|
||||
return bool(is_ds_v32 and enable_flash_comm_v1())
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
|
||||
@@ -112,6 +112,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.utils import (
|
||||
enable_flash_comm_v1,
|
||||
enable_sp,
|
||||
is_drafter_moe_model,
|
||||
is_moe_model,
|
||||
@@ -1677,7 +1678,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.speculative_config,
|
||||
positions.shape[0],
|
||||
)
|
||||
if get_forward_context().sp_enabled and not isinstance(hidden_states, IntermediateTensors):
|
||||
if get_forward_context().flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors):
|
||||
hidden_states = self._all_gather_hidden_states_and_aux(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -1685,7 +1686,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
if enable_sp():
|
||||
if enable_sp(self.vllm_config):
|
||||
return round_up(num_scheduled_tokens, tp_size)
|
||||
return num_scheduled_tokens
|
||||
|
||||
@@ -2223,7 +2224,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
|
||||
# to incorrect memory estimation and potentially causing OOM.
|
||||
intermediate_tokens = num_tokens_padded
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
|
||||
if self.intermediate_tensors is None:
|
||||
|
||||
@@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
from vllm_ascend.utils import (
|
||||
AscendDeviceType,
|
||||
check_ascend_device_type,
|
||||
enable_sp,
|
||||
enable_flash_comm_v1,
|
||||
get_ascend_device_type,
|
||||
register_ascend_customop,
|
||||
)
|
||||
@@ -376,7 +376,7 @@ class NPUWorker(WorkerBase):
|
||||
if forward_pass and not get_pp_group().is_first_rank:
|
||||
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
||||
# it will conflict with the all-gather operation in flashcomm1.
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
all_gather_group = None
|
||||
else:
|
||||
all_gather_group = get_tp_group()
|
||||
@@ -393,7 +393,7 @@ class NPUWorker(WorkerBase):
|
||||
assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
|
||||
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
||||
# it will conflict with the all-gather operation in flashcomm1.
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
all_gather_group = None
|
||||
else:
|
||||
all_gather_group = get_tp_group()
|
||||
|
||||
Reference in New Issue
Block a user