[Feat] support sequence parallelism by pass for VL models (#5632)
This commit is contained in:
@@ -48,6 +48,7 @@ from vllm_ascend.utils import (
|
||||
update_aclgraph_sizes,
|
||||
update_cudagraph_capture_sizes,
|
||||
is_310p,
|
||||
enable_flash_comm_v1,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -198,32 +199,9 @@ class NPUPlatform(Platform):
|
||||
if not isinstance(ascend_compilation_config, dict)
|
||||
else ascend_compilation_config
|
||||
)
|
||||
ascend_config.update_compile_ranges_split_points()
|
||||
|
||||
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
|
||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||
logger.debug(
|
||||
"set compile_ranges_split_points to "
|
||||
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||
)
|
||||
|
||||
npugraph_ex_config = ascend_config.npugraph_ex_config
|
||||
if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
|
||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||
|
||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||
logger.debug(
|
||||
"set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||
)
|
||||
|
||||
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
|
||||
if model_config and hasattr(model_config.hf_text_config, "index_topk"):
|
||||
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
|
||||
|
||||
ascend_fusion_config = ascend_config.ascend_fusion_config
|
||||
@@ -417,15 +395,19 @@ class NPUPlatform(Platform):
|
||||
)
|
||||
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
|
||||
|
||||
if is_vl_model(vllm_config):
|
||||
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
|
||||
int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
|
||||
):
|
||||
raise ValueError(
|
||||
"Currently, VL models doesn't support "
|
||||
"FLASHCOMM in vllm-ascend. We will fix this in the future. "
|
||||
"Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
|
||||
)
|
||||
if enable_flash_comm_v1():
|
||||
assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
|
||||
Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
|
||||
For optimal performance with VL models, we recommend enabling Sequence Parallelism \
|
||||
via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
|
||||
|
||||
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
|
||||
"Flash Comm v1 is only supported when tp_size > 1."
|
||||
)
|
||||
|
||||
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
|
||||
"Flash Comm v1 requires enable_expert_parallel=True for MoE models."
|
||||
)
|
||||
|
||||
# Set "PYTORCH_NPU_ALLOC_CONF=expandable_segments:True" by default to optimize NPU memory management.
|
||||
# Find more details at https://docs.vllm.ai/projects/ascend/en/latest/faqs.html#how-to-handle-the-out-of-memory-issue
|
||||
@@ -626,16 +608,16 @@ class NPUPlatform(Platform):
|
||||
# communication methods.
|
||||
mmrs_fusion = True
|
||||
if is_moe_model(vllm_config):
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
|
||||
flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
|
||||
mmrs_fusion = False
|
||||
else:
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
|
||||
flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
|
||||
|
||||
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
|
||||
flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
|
||||
pad_size = None
|
||||
padded_length = None
|
||||
if sp_enabled or flashcomm_v2_enabled:
|
||||
if flash_comm_v1_enabled or flashcomm_v2_enabled:
|
||||
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
@@ -643,7 +625,7 @@ class NPUPlatform(Platform):
|
||||
dp_world_size = get_dp_group().world_size
|
||||
if dp_world_size > 1 and dp_metadata is not None:
|
||||
max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
|
||||
if sp_enabled or flashcomm_v2_enabled:
|
||||
if flash_comm_v1_enabled or flashcomm_v2_enabled:
|
||||
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
|
||||
pad_size = padded_length - num_tokens
|
||||
else:
|
||||
@@ -664,7 +646,7 @@ class NPUPlatform(Platform):
|
||||
"capturing": capturing,
|
||||
"mmrs_fusion": mmrs_fusion,
|
||||
"num_tokens": num_tokens,
|
||||
"sp_enabled": sp_enabled,
|
||||
"flash_comm_v1_enabled": flash_comm_v1_enabled,
|
||||
"flashcomm_v2_enabled": flashcomm_v2_enabled,
|
||||
"pad_size": pad_size,
|
||||
"padded_length": padded_length,
|
||||
|
||||
Reference in New Issue
Block a user