[Feat]support sequence parallelism by pass for VL models (#5632)
This commit is contained in:
@@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
||||
from vllm_ascend.utils import (
|
||||
enable_dsa_cp,
|
||||
enable_dsa_cp_with_layer_shard,
|
||||
enable_sp,
|
||||
enable_flash_comm_v1,
|
||||
flashcomm2_enable,
|
||||
get_flashcomm2_reorgnized_batch_ids,
|
||||
get_weight_prefetch_method,
|
||||
@@ -368,7 +368,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
if not forward_context.sp_enabled:
|
||||
if not forward_context.flash_comm_v1_enabled:
|
||||
# flashcomm1 not enabled
|
||||
output = get_tp_group().all_gather(output, 0)
|
||||
if num_padding_tokens > 0:
|
||||
@@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
||||
|
||||
# Trigger async broadcast before matmul to overlap communication.
|
||||
@@ -515,15 +515,15 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
assert self.quant_method is not None
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
|
||||
mmrs_fusion = forward_context.mmrs_fusion
|
||||
except AssertionError:
|
||||
sp_enabled = False
|
||||
flash_comm_v1_enabled = False
|
||||
mmrs_fusion = False
|
||||
|
||||
x = input_parallel
|
||||
|
||||
if not sp_enabled:
|
||||
if not flash_comm_v1_enabled:
|
||||
output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
|
||||
return tensor_model_parallel_all_reduce(output_parallel)
|
||||
|
||||
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
|
||||
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
||||
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
|
||||
return Flashcomm2OshardQKVParallelOp(layer)
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
sp_column_prefix = [
|
||||
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
|
||||
if flashcomm2_enable():
|
||||
if "o_proj" in prefix or "out_proj" in prefix:
|
||||
return Flashcomm2OProjRowParallelOp(layer)
|
||||
if enable_sp():
|
||||
if enable_flash_comm_v1():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
sp_row_prefixes = [
|
||||
|
||||
Reference in New Issue
Block a user