[Refactor][SP] Integrate the sequence parallelism features of the MoE and dense models into a single solution. (#3085)
What this PR does / why we need it?
Currently there are two sequence-parallelism (SP) implementations covering the MoE and dense models: one is called sequence_parallelism, the other flashcomm_v1. This PR does the following:
- Merges the two sets of code, which share the same implementation, into one (sketched below).
- Removes the sequence_parallelism implementation, since that solution cannot support aclgraph.
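
For context, a minimal single-process sketch of the communication pattern the retained path uses: pad the token dimension to a multiple of the TP world size, reduce-scatter so each rank keeps one contiguous token shard, then all-gather and unpad when the full sequence is needed again (compare `_maybe_pad_and_reduce` and `_maybe_all_gather_and_maybe_unpad` in the diff). The helper names and the in-process simulation of the collectives are illustrative assumptions, not the vllm_ascend APIs:

```python
import torch
import torch.nn.functional as F


def sp_pad_and_reduce(partials: list[torch.Tensor], tp_size: int) -> list[torch.Tensor]:
    """Simulate reduce-scatter: sum the per-rank partial outputs (the
    all-reduce equivalent), pad the token dim to a multiple of tp_size,
    and hand each rank one contiguous shard."""
    full = torch.stack(partials).sum(dim=0)
    pad_size = (-full.size(0)) % tp_size
    if pad_size > 0:
        full = F.pad(full, (0, 0, 0, pad_size))  # pad rows (token dim)
    return list(torch.tensor_split(full, tp_size, dim=0))


def sp_all_gather_and_unpad(shards: list[torch.Tensor], num_tokens: int) -> torch.Tensor:
    """Simulate all-gather along the token dim, dropping the padding rows."""
    return torch.cat(shards, dim=0)[:num_tokens]


tp_size, num_tokens, hidden = 4, 10, 8
partials = [torch.randn(num_tokens, hidden) for _ in range(tp_size)]
shards = sp_pad_and_reduce(partials, tp_size)           # each shard: 3 x 8
restored = sp_all_gather_and_unpad(shards, num_tokens)  # back to 10 x 8
assert torch.allclose(restored, torch.stack(partials).sum(dim=0))
```

Replacing the per-layer all-reduce with this reduce-scatter/all-gather pair is what lets the dense and MoE paths share a single `sp_enabled` flag in the forward context.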
Does this PR introduce any user-facing change?
No
How was this patch tested?
Tested with e2e and unit tests.
- vLLM version: v0.10.2
- vLLM main: f225ea7dd9
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
@@ -216,7 +216,9 @@ class AscendFusedMoE(FusedMoE):
         forward_context = get_forward_context()
         hidden_states, router_logits = forward_context.moe_comm_method.prepare(
-            hidden_states=hidden_states, router_logits=router_logits)
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            replace_allreduce=forward_context.sp_enabled)

         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
@@ -21,8 +21,7 @@ from typing import Any, Callable, Optional
 import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -42,7 +41,6 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
-from vllm_ascend.ops.sequence_parallel import MetadataForPadding
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
                                get_all_reduce_merge_state,
                                get_rm_router_logits_state, is_310p,
@@ -360,8 +358,7 @@ class AscendFusedMoE(FusedMoE):
                  top_k: Optional[int] = None,
                  shared_experts: Optional[Any] = None,
                  gate=None,
-                 replace_allreduce: bool = False,
-                 _metadata_for_padding: Optional[MetadataForPadding] = None):
+                 replace_allreduce: bool = False):

         assert self.quant_method is not None

@@ -379,13 +376,7 @@ class AscendFusedMoE(FusedMoE):
             # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
             shared_hidden_states = shared_experts(hidden_states)

-        enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
-        tp_size = get_tensor_model_parallel_world_size()
-        if enable_sp:
-            tp_rank = get_tensor_model_parallel_rank()
-            mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
-            chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
-            mc2_mask = chunk_mc2_mask[tp_rank]
+        if forward_context.sp_enabled:
+            replace_allreduce = True

         hidden_states, router_logits = forward_context.moe_comm_method.prepare(
@@ -48,8 +48,9 @@ from vllm.distributed.parallel_state import get_tp_group

 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
-                               mlp_tp_enable, oproj_tp_enable)
+from vllm_ascend.utils import (dense_optim_enable, enable_sp,
+                               matmul_allreduce_enable, mlp_tp_enable,
+                               oproj_tp_enable)


 class CustomTensorParallelOp:
@@ -82,10 +83,17 @@ class CustomTensorParallelOp:
         self.skip_bias_add = self.layer.skip_bias_add
         self.return_bias = self.layer.return_bias
         self.quant_method = self.layer.quant_method
+        self.prefix = self.layer.prefix

+    def apply_impl(self, input_):
+        raise NotImplementedError
+
+    # Replace layer.forward to customize the layer computation process.
     def apply(self, input_):
-        raise NotImplementedError
+        output, output_bias = self.apply_impl(input_)
+        if not self.return_bias:
+            return output
+        return output, output_bias


 class CustomColumnParallelOp(CustomTensorParallelOp):
@@ -113,6 +121,14 @@ class CustomRowParallelOp(CustomTensorParallelOp):
         self.reduce_results = self.layer.reduce_results
         self.input_size_per_partition = self.layer.input_size_per_partition

+    def apply(self, input_):
+        output, output_bias = self.apply_impl(input_)
+        if dense_optim_enable():
+            torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
+        if not self.return_bias:
+            return output
+        return output, output_bias
+

 class MLPColumnParallelOp(CustomColumnParallelOp):

@@ -123,7 +139,7 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
     def comm_group(self):
         return get_mlp_tp_group()

-    def apply(
+    def apply_impl(
         self,
         input_: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -134,14 +150,12 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
         output = self.quant_method.apply(self.layer, input_parallel, bias)

         output_bias = self.bias if self.skip_bias_add else None
-        if not self.return_bias:
-            return output
         return output, output_bias


-class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):
+class SequenceMergedColumnParallelOp(CustomColumnParallelOp):

-    def apply(
+    def apply_impl(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         """Linear layer with column parallelism.
@@ -164,18 +178,16 @@ class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
-        if not self.return_bias:
-            return output
         return output, output_bias


-class DenseOptimQKVParallelOp(CustomColumnParallelOp):
+class SequenceQKVParallelOp(CustomColumnParallelOp):

     def __init__(self, layer, prefix):
         super().__init__(layer)
         self.prefix = prefix

-    def apply(
+    def apply_impl(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         """Linear layer with column parallelism.
@@ -201,8 +213,6 @@ class DenseOptimQKVParallelOp(CustomColumnParallelOp):
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
-        if not self.return_bias:
-            return output
         return output, output_bias


@@ -215,7 +225,7 @@ class MLPRowParallelOp(CustomRowParallelOp):
     def comm_group(self):
         return get_mlp_tp_group()

-    def apply(
+    def apply_impl(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.input_is_parallel:
@@ -234,8 +244,6 @@ class MLPRowParallelOp(CustomRowParallelOp):
             output = self.comm_group.reduce_scatter(output_parallel, 0)

         output_bias = self.bias if self.skip_bias_add else None
-        if not self.return_bias:
-            return output
         return output, output_bias


@@ -248,7 +256,7 @@ class OProjRowParallelOp(CustomRowParallelOp):
     def comm_group(self):
         return get_otp_group()

-    def apply(
+    def apply_impl(
         self,
         input_: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -294,8 +302,6 @@ class OProjRowParallelOp(CustomRowParallelOp):

         # Handle bias return based on configuration
         output_bias = self.bias if self.skip_bias_add else None
-        if not self.return_bias:
-            return output
         return output, output_bias

     def update_attrs(self):
@@ -311,7 +317,7 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
         super().__init__(layer)
         self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)

-    def apply(
+    def apply_impl(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.input_is_parallel:
@@ -335,8 +341,6 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
                                            bias=bias_)

         output_bias = self.bias if self.skip_bias_add else None
-        if not self.return_bias:
-            return output
         return output, output_bias

     @classmethod
@@ -359,13 +363,13 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
         self.weight_t = self.layer.weight.t()


-class DenseOptimRowParallelOp(CustomRowParallelOp):
+class SequenceRowParallelOp(CustomRowParallelOp):

     def __init__(self, layer, prefix):
         super().__init__(layer)
         self.prefix = prefix

-    def apply(
+    def apply_impl(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         """Linear layer with column parallelism.
@@ -391,12 +395,8 @@ class DenseOptimRowParallelOp(CustomRowParallelOp):
                                                   input_parallel,
                                                   bias=bias_)
         output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
-        torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)

         output_bias = self.bias if self.skip_bias_add else None
-
-        if not self.return_bias:
-            return output
         return output, output_bias

     def update_attrs(self):
@@ -407,23 +407,22 @@ class DenseOptimRowParallelOp(CustomRowParallelOp):

 def get_column_parallel_op(
     disable_tp, prefix, layer
-) -> Tuple[
-        Optional[Union[MLPColumnParallelOp, DenseOptimMergedColumnParallelOp,
-                       DenseOptimQKVParallelOp]], int, int]:
+) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
+                          SequenceQKVParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1

     custom_op: Optional[Union[
         MLPColumnParallelOp,
-        DenseOptimMergedColumnParallelOp,
-        DenseOptimQKVParallelOp,
+        SequenceMergedColumnParallelOp,
+        SequenceQKVParallelOp,
     ]] = None
     if "gate_up_proj" in prefix and mlp_tp_enable():
         custom_op = MLPColumnParallelOp(layer)
-    elif "gate_up_proj" in prefix and dense_optim_enable():
-        custom_op = DenseOptimMergedColumnParallelOp(layer)
-    elif dense_optim_enable():
-        custom_op = DenseOptimQKVParallelOp(layer, prefix)
+    elif "gate_up_proj" in prefix and enable_sp():
+        custom_op = SequenceMergedColumnParallelOp(layer)
+    elif enable_sp():
+        custom_op = SequenceQKVParallelOp(layer, prefix)

     if custom_op is not None:
         return custom_op, custom_op.tp_rank, custom_op.tp_size
@@ -435,21 +434,21 @@ def get_row_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                           MatmulAllreduceRowParallelOp,
-                          DenseOptimRowParallelOp]], int, int]:
+                          SequenceRowParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1

     custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                               MatmulAllreduceRowParallelOp,
-                              DenseOptimRowParallelOp]] = None
+                              SequenceRowParallelOp]] = None
     if "down_proj" in prefix and mlp_tp_enable():
         custom_op = MLPRowParallelOp(layer)
     elif "o_proj" in prefix and oproj_tp_enable():
         custom_op = OProjRowParallelOp(layer)
     elif matmul_allreduce_enable():
         custom_op = MatmulAllreduceRowParallelOp(layer)
-    elif dense_optim_enable():
-        custom_op = DenseOptimRowParallelOp(layer, prefix)
+    elif enable_sp():
+        custom_op = SequenceRowParallelOp(layer, prefix)

     if custom_op is not None:
         return custom_op, custom_op.tp_rank, custom_op.tp_size
@@ -133,11 +133,15 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
         """
         self.replace_allreduce = replace_allreduce
         self.enable_shared_expert_dp = enable_shared_expert_dp
+        forward_context = get_forward_context()
+        mc2_mask = forward_context.mc2_mask
+        if self.tp_size > 1:
+            # Also slice mc2_mask
+            split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
+            mc2_mask = split_mc2_mask[self.tp_rank]

         if not self.replace_allreduce:
             self.num_tokens, _ = hidden_states.shape
-            forward_context = get_forward_context()
-            mc2_mask = forward_context.mc2_mask
             target_pad_length = forward_context.padded_num_tokens
             pad_size = target_pad_length - self.num_tokens
@@ -149,23 +153,16 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
                                       (0, 0, 0, pad_size))

         # Slice across TP ranks
-        if self.tp_size > 1:
-            if not self.enable_shared_expert_dp:
-                split_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                split_router_logits = torch.tensor_split(router_logits,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = split_hidden_states[self.tp_rank]
-                router_logits = split_router_logits[self.tp_rank]
-                self.split_hidden_states = split_hidden_states  # Save for finalize
-
-            # Also slice mc2_mask
-            split_mc2_mask = torch.tensor_split(mc2_mask,
-                                                self.tp_size,
-                                                dim=0)
-            mc2_mask = split_mc2_mask[self.tp_rank]
+        if self.tp_size > 1 and not self.enable_shared_expert_dp:
+            split_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.tp_size,
+                                                     dim=0)
+            split_router_logits = torch.tensor_split(router_logits,
+                                                     self.tp_size,
+                                                     dim=0)
+            hidden_states = split_hidden_states[self.tp_rank]
+            router_logits = split_router_logits[self.tp_rank]
+            self.split_hidden_states = split_hidden_states  # Save for finalize

         return hidden_states, router_logits, mc2_mask
@@ -20,10 +20,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
         return residual

     if x.size(0) != residual.size(0):
-        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
-        assert flashcomm_v1_enabled is True, (
-            "Currently, this situation only occurs "
-            "when flashcomm_v1 is enabled")
+        sp_enabled = forward_context.sp_enabled
+        assert sp_enabled is True, ("Currently, this situation only occurs "
+                                    "when sp is enabled")
         pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -41,8 +40,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
     except AssertionError:
         return x

-    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
-    if flashcomm_v1_enabled and label:
+    sp_enabled = forward_context.sp_enabled
+    if sp_enabled and label:
         x = tensor_model_parallel_all_gather(x, 0)
         pad_size = forward_context.pad_size
         if pad_size > 0:
@@ -56,8 +55,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)

-    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
-    if flashcomm_v1_enabled:
+    sp_enabled = forward_context.sp_enabled
+    if sp_enabled:
         pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))
@@ -1,120 +0,0 @@
-import torch
-from torch.nn import functional as F
-from vllm.distributed import (get_tensor_model_parallel_world_size,
-                              get_tp_group, tensor_model_parallel_all_gather,
-                              tensor_model_parallel_reduce_scatter)
-from vllm.forward_context import get_forward_context
-
-from vllm_ascend.platform import NPUPlatform
-
-
-class MetadataForPadding:
-
-    def __init__(self,
-                 padding_flag=False,
-                 lengths_sum_padding=0,
-                 lengths_sum_unpadding=0,
-                 pad_size=0,
-                 not_dummy_and_is_prefill=False):
-        self.padding_flag = padding_flag
-        self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
-
-        self.lengths_sum_padding = lengths_sum_padding
-        self.lengths_sum_unpadding = lengths_sum_unpadding
-        self.pad_size = pad_size
-
-        self.tp_size = get_tp_group().world_size
-        self.tp_rank_in_group = get_tp_group().rank_in_group
-
-        assert self.lengths_sum_padding % self.tp_size == 0
-        self.slice_size = self.lengths_sum_padding // self.tp_size
-
-        self.mc2_mask = torch.zeros(
-            self.lengths_sum_padding,
-            dtype=torch.bool,
-            device=NPUPlatform.device_type,
-        )
-        self.mc2_mask[:lengths_sum_unpadding] = True
-
-    def padding_aligned_reduce_scatter(self,
-                                       data: torch.Tensor) -> torch.Tensor:
-        if self.padding_flag:
-            pad_size = self.pad_size
-            padded_data = F.pad(data, (0, 0, 0, pad_size))
-        else:
-            padded_data = data
-        padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
-            padded_data, 0)
-
-        return padded_data_reduce_scatter
-
-    def allgather_unpadding_aligned(self,
-                                    padded_data: torch.Tensor) -> torch.Tensor:
-        padded_data_allgather = tensor_model_parallel_all_gather(
-            padded_data, 0)
-        if self.padding_flag:
-            lengths_sum_unpadding = self.lengths_sum_unpadding
-            unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
-        else:
-            unpadding_data = padded_data_allgather
-        return unpadding_data
-
-    def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
-
-        padded_data = F.pad(data, (0, 0, 0, self.pad_size))
-        start = self.tp_rank_in_group * self.slice_size
-        end = start + self.slice_size
-        slice_data = padded_data[start:end]
-
-        return slice_data
-
-    def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
-        if self.padding_flag:
-            pad_size = self.pad_size
-            padded_data = F.pad(data, (0, 0, 0, pad_size))
-        else:
-            padded_data = data
-        # padded_data = data
-        padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0)
-
-        padded_data_reduce_scatter = padded_data[self.tp_rank_in_group]
-
-        return padded_data_reduce_scatter
-
-
-def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
-    if not enable_sequence_parallelism:
-        return MetadataForPadding(padding_flag=False,
-                                  not_dummy_and_is_prefill=False)
-
-    is_perifll = 0
-    attn_metadata = get_forward_context().attn_metadata
-    tp_size = get_tensor_model_parallel_world_size()
-    if attn_metadata is not None:
-        if hasattr(attn_metadata,
-                   'is_only_prefill') and attn_metadata.is_only_prefill:
-            is_perifll = 1
-        if hasattr(attn_metadata,
-                   'num_prefills') and attn_metadata.num_prefills > 0:
-            is_perifll = 1
-
-    if is_perifll:
-        lengths_sum_unpadding = input_ids.shape[0]
-        lengths_sum_padding = (
-            (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
-        if lengths_sum_unpadding == lengths_sum_padding:
-            padding_flag = False
-        else:
-            padding_flag = True
-        pad_size = lengths_sum_padding - lengths_sum_unpadding
-        _metadata_for_padding = MetadataForPadding(
-            lengths_sum_unpadding=lengths_sum_unpadding,
-            lengths_sum_padding=lengths_sum_padding,
-            padding_flag=padding_flag,
-            pad_size=pad_size,
-            not_dummy_and_is_prefill=True)
-
-        return _metadata_for_padding
-
-    return MetadataForPadding(padding_flag=False,
-                              not_dummy_and_is_prefill=False)