[Fix] Fix DP-related padding logic (#2582)
### What this PR does / why we need it?
The determination of attention state, padding, and other forward
metadata has been moved to an earlier stage within the input preparation
process. This change enables us to utilize a single all-reduce
operation, maximizing synchronization efficiency as early as possible.
The logic for synchronizing metadata—such as the number of tokens,
prefill status, and DBO status—across data parallel (DP) ranks has now
been unified and simplified.
For performance improvements, the all-reduce operation has been switched
from the `gloo` backend to the `npu` backend, which results in an
reduction of several milliseconds per step (**approximately 10%
performance gain for TPOT!**).
Additionally, the multi-DP server hang issue has been resolved, ensuring
no more hangs occur when `num_requests < dp_size`. Alas, a relief.
Finally, the miscalculated memory usage issue has been addressed by
removing the unnecessary `DummyCommImpl`, allowing the system to use the
real communication method when determining available memory.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Maybe we should add an test case for multi-DP online server?
@MengqingCao
- vLLM version: v0.10.1.1
- vLLM main:
c5d004aaaf
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -94,43 +94,6 @@ class MoECommMethod(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class DummyCommImpl(MoECommMethod):
|
||||
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Dummy prepare method that does nothing."""
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""Dummy finalize method that does nothing."""
|
||||
return hidden_states
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""Dummy implementation, make sure the output shapes are correct."""
|
||||
top_k_num = topk_ids.shape[1]
|
||||
permuted_hidden_states = hidden_states.repeat_interleave(top_k_num,
|
||||
dim=0)
|
||||
expert_tokens = torch.zeros((num_experts, ),
|
||||
dtype=torch.int64,
|
||||
device=hidden_states.device)
|
||||
group_list_type = 0
|
||||
return permuted_hidden_states, expert_tokens, group_list_type
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
"""Dummy implementation that does nothing."""
|
||||
pass
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
|
||||
"""This implementation is the same as NativeAllGatherCommImpl,
|
||||
but uses NPU-specific ops for better performance.
|
||||
|
||||
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||
DummyCommImpl,
|
||||
MC2CommImpl,
|
||||
MoECommMethod)
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
@@ -230,7 +229,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.moe_config.ep_group = get_ep_group()
|
||||
self.moe_config.mc2_group = get_mc2_group()
|
||||
|
||||
for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}:
|
||||
for method in {AllGatherCommImpl, MC2CommImpl}:
|
||||
setattr(
|
||||
self, method.__name__.lower(),
|
||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
||||
@@ -241,8 +240,11 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl":
|
||||
|
||||
# TODO: Can we refactor this logic to model_runner?
|
||||
if not self.moe_config.use_ep:
|
||||
moe_comm_method_name = "allgathercommimpl"
|
||||
|
||||
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
|
||||
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
|
||||
@@ -70,7 +70,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
register_torchair_model()
|
||||
torchair_quant_method_register()
|
||||
|
||||
def _get_forward_metadata_across_dp_and_pad(
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
"""Override from NPUModelRunner to pad num_tokens"""
|
||||
@@ -81,8 +81,17 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
|
||||
num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
|
||||
num_tokens, with_prefill, enable_dbo)
|
||||
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
num_tokens_across_dp[self.dp_rank] = num_tokens
|
||||
num_tokens_across_dp[-2] = int(with_prefill)
|
||||
num_tokens_across_dp[-1] = int(not enable_dbo)
|
||||
dist.all_reduce(num_tokens_across_dp,
|
||||
group=get_dp_group().device_group)
|
||||
with_prefill = bool(num_tokens_across_dp[-2])
|
||||
enable_dbo = not bool(num_tokens_across_dp[-1])
|
||||
num_tokens_across_dp = num_tokens_across_dp[:-2]
|
||||
|
||||
if not with_prefill:
|
||||
max_num_token = num_tokens_across_dp.max().item()
|
||||
|
||||
@@ -43,8 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
||||
get_tp_group,
|
||||
is_global_first_rank)
|
||||
from vllm.forward_context import (BatchDescriptor, DPMetadata,
|
||||
get_forward_context)
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
@@ -373,10 +372,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.moe_comm_method = "mc2"
|
||||
self.fallback_moe_comm_method = "allgather"
|
||||
self.dummy_moe_comm_method = "dummy"
|
||||
|
||||
def _use_aclgraph(self) -> bool:
|
||||
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
||||
|
||||
@@ -594,32 +589,43 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Refresh batch metadata with any pending updates.
|
||||
self.input_batch.refresh_metadata()
|
||||
|
||||
def _get_forward_metadata_across_dp(
|
||||
self, num_tokens: int, with_prefill: bool,
|
||||
enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
|
||||
|
||||
# Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
|
||||
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
num_tokens_across_dp[self.dp_rank] = num_tokens
|
||||
num_tokens_across_dp[-2] = int(with_prefill)
|
||||
num_tokens_across_dp[-1] = int(not enable_dbo)
|
||||
dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
|
||||
with_prefill = bool(num_tokens_across_dp[-2])
|
||||
enable_dbo = not bool(num_tokens_across_dp[-1])
|
||||
num_tokens_across_dp = num_tokens_across_dp[:-2]
|
||||
return num_tokens_across_dp, with_prefill, enable_dbo
|
||||
|
||||
def _get_forward_metadata_across_dp_and_pad(
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
if self.dp_size == 1:
|
||||
if self.dp_size == 1 or self.vllm_config.model_config.enforce_eager:
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
|
||||
num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
|
||||
num_tokens, with_prefill, enable_dbo)
|
||||
return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
|
||||
# Sync num_tokens, with_prefill, enable_dbo across dp ranks
|
||||
num_tokens_tensor = torch.tensor([
|
||||
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
|
||||
flags_tensor = torch.tensor(
|
||||
[int(with_prefill), int(not enable_dbo)],
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
|
||||
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
|
||||
|
||||
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
|
||||
|
||||
# Unpack the results
|
||||
num_tokens_across_dp = packed_tensor[:-2]
|
||||
synced_flags = packed_tensor[-2:]
|
||||
|
||||
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
|
||||
global_with_prefill = bool(synced_flags[0])
|
||||
global_enable_dbo = not bool(synced_flags[1])
|
||||
|
||||
# Create a tensor for num_tokens_after_padding
|
||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
|
||||
self.dp_size,
|
||||
device="npu",
|
||||
dtype=torch.int32)
|
||||
|
||||
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
|
||||
|
||||
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
|
||||
attn_state: AscendAttentionState,
|
||||
@@ -1025,32 +1031,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
return mm_embeds
|
||||
|
||||
def get_dp_padding(self,
|
||||
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
|
||||
"""This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`.
|
||||
Please note that vLLM may refactor or modify this function over time,
|
||||
at present, we are using the version introduced in PR #18935.
|
||||
"""
|
||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# For DP: Don't pad when setting enforce_eager.
|
||||
# This lets us set enforce_eager on the prefiller in a P/D setup and
|
||||
# still use ACL graphs (enabled by this padding) on the decoder.
|
||||
|
||||
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
|
||||
# Early exit.
|
||||
return 0, None
|
||||
|
||||
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
||||
num_tokens, dp_size, dp_rank)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
|
||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
|
||||
dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -1060,24 +1040,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
|
||||
Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
# Check input valid
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
if (self.use_aclgraph and total_num_scheduled_tokens
|
||||
<= self.aclgraph_batch_sizes[-1]):
|
||||
# Add padding to the batch size.
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
total_num_scheduled_tokens)
|
||||
else:
|
||||
# Eager mode.
|
||||
num_input_tokens = total_num_scheduled_tokens
|
||||
|
||||
# Padding for DP
|
||||
num_pad, num_tokens_across_dp_native = self.get_dp_padding(
|
||||
num_input_tokens)
|
||||
num_input_tokens += num_pad
|
||||
|
||||
self.attn_metadata_builder.reorder_batch(self.input_batch,
|
||||
scheduler_output)
|
||||
@@ -1098,6 +1064,41 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
|
||||
num_tokens)
|
||||
|
||||
if (self.use_aclgraph and total_num_scheduled_tokens
|
||||
<= self.aclgraph_batch_sizes[-1]):
|
||||
# Add padding to the batch size.
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
total_num_scheduled_tokens)
|
||||
else:
|
||||
# Eager mode.
|
||||
num_input_tokens = total_num_scheduled_tokens
|
||||
|
||||
# Get the attention state.
|
||||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens)
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
# Determine if it's a splitfuse batch
|
||||
with_prefill = attn_state not in [
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||
attn_state,
|
||||
total_num_scheduled_tokens)
|
||||
|
||||
# Get info across DP ranks.
|
||||
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
|
||||
# Otherwise, it's just max_tokens_across_dp_cpu
|
||||
(maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
|
||||
with_prefill, enable_dbo)
|
||||
|
||||
if self.use_aclgraph:
|
||||
# When using TorchAir with DP, we have other plans for padding
|
||||
num_input_tokens = maybe_padded_num_tokens
|
||||
|
||||
# Hot-Swap lora model
|
||||
if self.lora_config:
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
@@ -1166,20 +1167,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.seq_lens[num_reqs:].fill_(0)
|
||||
self.query_start_loc[num_reqs + 1:].fill_(-1)
|
||||
|
||||
with_prefill = attn_state not in [
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
|
||||
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||
attn_state,
|
||||
total_num_scheduled_tokens)
|
||||
|
||||
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
|
||||
total_num_scheduled_tokens, with_prefill, enable_dbo)
|
||||
self.with_prefill = with_prefill
|
||||
self.num_tokens_across_dp = num_tokens_across_dp
|
||||
self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp)
|
||||
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
||||
@@ -1247,7 +1237,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions = self.positions[:num_input_tokens]
|
||||
input_ids, positions = self._update_input_ids_and_positions(
|
||||
input_ids, positions, num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp)
|
||||
maybe_padded_num_tokens)
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
@@ -1262,14 +1252,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
# NOTE: Currently this padding logic is really messy,
|
||||
# MC2 may not be available in eager mode
|
||||
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
|
||||
if self.use_aclgraph:
|
||||
num_tokens_across_dp = num_tokens_across_dp_native
|
||||
else:
|
||||
num_input_tokens = padded_num_tokens_across_dp
|
||||
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if not use_spec_decode:
|
||||
@@ -1297,12 +1279,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return (attn_metadata, positions, num_scheduled_tokens,
|
||||
num_input_tokens, num_tokens_across_dp,
|
||||
padded_num_tokens_across_dp, logits_indices,
|
||||
spec_decode_metadata, input_ids, inputs_embeds,
|
||||
intermediate_tensors)
|
||||
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
|
||||
input_ids, inputs_embeds, intermediate_tensors)
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
padded_num_tokens_across_dp,
|
||||
maybe_padded_num_tokens,
|
||||
input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds):
|
||||
@@ -1345,7 +1326,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||
num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp):
|
||||
maybe_padded_num_tokens):
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_input_tokens]
|
||||
return input_ids, positions
|
||||
@@ -1632,6 +1613,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_connector_output=kv_connector_output,
|
||||
)
|
||||
|
||||
def _select_moe_comm_method(self, num_tokens: int) -> str:
|
||||
return ("mc2"
|
||||
if num_tokens <= self.mc2_tokens_capacity else "allgather")
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -1649,15 +1634,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
return self.kv_connector_no_forward(scheduler_output)
|
||||
(attn_metadata, positions, num_scheduled_tokens_np,
|
||||
num_input_tokens, num_tokens_across_dp,
|
||||
padded_num_tokens_across_dp, logits_indices, spec_decode_metadata,
|
||||
input_ids, inputs_embeds,
|
||||
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
|
||||
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
|
||||
intermediate_tensors) = (self._prepare_inputs(
|
||||
scheduler_output, intermediate_tensors))
|
||||
|
||||
moe_comm_method = (self.moe_comm_method
|
||||
if num_input_tokens <= self.mc2_tokens_capacity else
|
||||
self.fallback_moe_comm_method)
|
||||
moe_comm_method = self._select_moe_comm_method(num_input_tokens)
|
||||
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=False)
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
@@ -1680,9 +1663,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
|
||||
hidden_states = self._generate_process_reqs_hidden_states(
|
||||
attn_metadata, self.with_prefill,
|
||||
padded_num_tokens_across_dp, input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
attn_metadata, self.with_prefill, maybe_padded_num_tokens,
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds)
|
||||
|
||||
self.maybe_wait_for_kv_save()
|
||||
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
||||
@@ -1988,7 +1970,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens: int,
|
||||
with_prefill: bool = False,
|
||||
is_torchair_compile: bool = False,
|
||||
moe_comm_method: str = "dummy",
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
force_attention: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
@@ -2003,13 +1984,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
|
||||
# Padding for DP
|
||||
num_pad, num_tokens_across_dp_native = self.get_dp_padding(num_tokens)
|
||||
# num_tokens += num_pad ## Uncomment this after TorchAir is removed
|
||||
|
||||
# Padding for DP (for TorchAir)
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self._get_forward_metadata_across_dp_and_pad(
|
||||
num_tokens, with_prefill, False)
|
||||
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
|
||||
|
||||
moe_comm_method = self._select_moe_comm_method(num_tokens)
|
||||
|
||||
# If cudagraph_mode.decode_mode() == FULL and
|
||||
# cudagraph_mode.seperate_routine(). This means that we are using
|
||||
@@ -2518,12 +2496,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self._dummy_run(num_tokens,
|
||||
aclgraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
force_attention=force_attention,
|
||||
uniform_decode=uniform_decode,
|
||||
moe_comm_method=self.moe_comm_method)
|
||||
uniform_decode=uniform_decode)
|
||||
self._dummy_run(num_tokens,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
uniform_decode=uniform_decode,
|
||||
moe_comm_method=self.moe_comm_method)
|
||||
uniform_decode=uniform_decode)
|
||||
|
||||
def _capture_model(self):
|
||||
if not self.use_aclgraph:
|
||||
|
||||
@@ -194,7 +194,7 @@ class MtpProposer:
|
||||
# torch mode need to update num_tokens_across_dp
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self.runner._get_forward_metadata_across_dp_and_pad(
|
||||
_) = self.runner._sync_metadata_across_dp(
|
||||
num_tokens, self.runner.with_prefill, False)
|
||||
attn_metadata.slot_mapping = target_slot_mapping
|
||||
else:
|
||||
@@ -281,8 +281,8 @@ class MtpProposer:
|
||||
if not self.torchair_graph_enabled:
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self.runner._get_forward_metadata_across_dp_and_pad(
|
||||
num_tokens, with_prefill, False)
|
||||
_) = self.runner._sync_metadata_across_dp(num_tokens,
|
||||
with_prefill, False)
|
||||
is_running_torchair = self.torchair_graph_enabled and \
|
||||
not with_prefill
|
||||
|
||||
|
||||
Reference in New Issue
Block a user