The main purpose of this PR is to remove the multicast-related code from MoE communication.
Reasons:
1. In scenarios such as A2 dual-system back-to-back networking, multicast performs worse
than all_gather: in the e2e test, throughput was 3 tps before the modification and 10 tps
after it.
2. We usually enable the SP feature, which is consistent with the current logic.
3. The advantage of broadcast communication is that it does not suffer from uneven DP load
and does not require the prefill ACL graph to be enabled; however, the prefill ACL graph is
supported now.
So we see no need to keep multicast as an option in MoE communication (a minimal
illustrative sketch of the simplified choice follows the performance numbers below).
Performance benefits are as follows:
Without enable_flashcomm1, TTFT remains relatively stable at around 43000 ms, roughly
15000 ms faster than before the modification.
With enable_flashcomm1, there is no difference: TTFT remains relatively stable at around
29000 ms.
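For illustration only, a minimal sketch of what the MoE communication selection reduces to once the broadcast/multicast branch is gone. All names here (`MoECommType`, `select_moe_comm_method`, the token threshold) are hypothetical and are not the actual vllm-ascend implementation:

```python
from enum import Enum


class MoECommType(Enum):
    # Hypothetical enum for illustration; the multicast/broadcast member that
    # previously existed is simply no longer a choice.
    ALLGATHER = "allgather"
    MC2 = "mc2"


def select_moe_comm_method(num_tokens: int,
                           mc2_token_capacity: int = 512) -> MoECommType:
    """Hypothetical sketch: with multicast removed, the choice reduces to
    MC2 for small decode-sized batches and all_gather for everything else."""
    if num_tokens <= mc2_token_capacity:
        return MoECommType.MC2
    return MoECommType.ALLGATHER
```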
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: weijinqian0 <1184188277@qq.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
import types

import torch
import torch.nn as nn
import torchair
from torchair import patch_for_hcom
from vllm.config import (CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
    process_weights_after_loading
from vllm.utils.torch_utils import set_default_torch_dtype
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode import MtpProposer
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
    TorchairDeepSeekMTP
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
                                        TorchairCommonAttentionMetadata)
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable

PADDING_SLOT_ID = -1


class TorchairMtpProposer(MtpProposer):

    def __init__(
        self,
        vllm_config: VllmConfig,
        device,
        runner,
    ):
        super().__init__(vllm_config, device, runner)
        self.torchair_compiled_model = None  # type: ignore
        self.torchair_compiled_models = {}  # type: ignore
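    # Build the TorchairDeepSeekMTP draft model on the target device, identify
    # the draft model's (single) attention layer, then load its weights and
    # run the post-loading weight processing.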
    def load_model(self, model) -> None:
        loader = get_model_loader(self.vllm_config.load_config)

        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config,
                                        AttentionLayerBase).keys())
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        target_device = self.vllm_config.device_config.device

        with set_default_torch_dtype(
                draft_model_config.dtype), set_current_vllm_config(
                    self.vllm_config):

            self.model = TorchairDeepSeekMTP(
                vllm_config=self.vllm_config).to(target_device)

        draft_attn_layer_names = (get_layers_from_vllm_config(
            self.vllm_config, AttentionLayerBase).keys() -
                                  target_attn_layer_names)

        assert len(draft_attn_layer_names) == 1
        self.attn_layer_name = list(draft_attn_layer_names)

        self.model.load_weights(
            loader.get_all_weights(
                self.vllm_config.speculative_config.draft_model_config,
                self.model))
        process_weights_after_loading(self.model, draft_model_config,
                                      target_device)
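    # Dummy forward pass of the draft model (profiling / warm-up runs): the
    # torchair-compiled model is exercised on the decode path, the eager model
    # when with_prefill is set.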
    @torch.inference_mode()
    def dummy_run(self,
                  num_tokens: int,
                  with_prefill: bool = False,
                  skip_attn: bool = False,
                  num_reqs: int = 0,
                  num_tokens_across_dp=None,
                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                  batch_descriptor=None) -> None:
        moe_comm_type = self.runner._select_moe_comm_method(num_tokens)

        if not with_prefill:
            skip_attn = False
        if skip_attn:
            attn_metadata = None
        else:
            common_attn_metadata = TorchairCommonAttentionMetadata(
                num_reqs=num_reqs,
                num_actual_tokens=1,
                actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                attn_mask=self.runner.attn_mask,
                spec_attn_mask=self.runner.spec_attn_mask,
                decode_token_per_req=self.runner.decode_token_per_req,
            )
            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
                common_attn_metadata)

        input_ids = self.input_ids[:num_tokens]
        positions = self.positions[:num_tokens]
        previous_hidden_states = self.hidden_states[:num_tokens]
        for _ in range(self.num_speculative_tokens):
            with set_ascend_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    with_prefill=with_prefill,
                    num_tokens_across_dp=num_tokens_across_dp,
                    reserved_mc2_mask=self.runner.reserved_mc2_mask,
                    moe_comm_type=moe_comm_type,
                    in_profile_run=self.runner.in_profile_run,
                    num_actual_tokens=0):
                if not with_prefill:
                    assert attn_metadata is not None
                    torch._dynamo.mark_static(input_ids)
                    torch._dynamo.mark_static(positions)
                    torch._dynamo.mark_static(previous_hidden_states)
                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
                    torch._dynamo.mark_static(
                        attn_metadata.decode.input_positions)
                    if hasattr(attn_metadata.decode, "sin"):
                        torch._dynamo.mark_static(attn_metadata.decode.sin)
                        torch._dynamo.mark_static(attn_metadata.decode.cos)
                    torch._dynamo.mark_static(get_forward_context().mc2_mask)
                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
                    torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
                    torchair_compiled_model = self._get_torchair_lazy_compiled_model(
                        num_tokens)
                    torchair_compiled_model(
                        input_ids=input_ids,
                        positions=positions,
                        hidden_states=previous_hidden_states,
                        inputs_embeds=None,
                        intermediate_tensors=None,
                        attn_metadata=attn_metadata,
                        kv_caches=self.runner.kv_caches[-1:],
                        spec_step_idx=0)
                else:
                    self.model(input_ids=input_ids,
                               positions=positions,
                               hidden_states=previous_hidden_states)
            if with_prefill:
                break
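    # Build the draft-model inputs from the newly sampled tokens (and from the
    # spec-decode metadata when draft tokens were verified this step), then
    # propose and return the next draft token ids.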
    def generate_token_ids(self,
                           valid_sampled_token_ids: list[list[int]],
                           sampling_metadata: SamplingMetadata = None,
                           scheduler_output: SchedulerOutput = None,
                           spec_decode_metadata: SpecDecodeMetadata = None,
                           positions: torch.Tensor = None,
                           num_scheduled_tokens: int = 0,
                           hidden_states: torch.Tensor = None,
                           attn_metadata=None,
                           aux_hidden_states: torch.Tensor = None):
        if attn_metadata is not None and isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(valid_sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = self.runner.input_batch.req_ids[i]
                req_state = self.runner.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.device)
        accepted_token_indices = None
        if spec_decode_metadata is None:
            # input_ids can be None for multimodal models.
            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
            target_positions = positions[:num_scheduled_tokens]
            target_hidden_states = hidden_states[:num_scheduled_tokens]
            target_slot_mapping = attn_metadata.slot_mapping
            cu_num_tokens = attn_metadata.query_start_loc
        else:
            # TODO(woosuk): Refactor this.
            num_draft_tokens = spec_decode_metadata.num_draft_tokens
            num_rejected_tokens = [
                n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ]
            num_rejected_tokens = torch.tensor(
                num_rejected_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            cu_num_tokens, accepted_token_indices, target_token_ids, \
                target_positions, target_hidden_states, target_slot_mapping = self._torchair_prepare_inputs(
                    attn_metadata.query_start_loc,
                    num_rejected_tokens,
                    self.runner.input_ids[:num_scheduled_tokens],
                    positions[:num_scheduled_tokens],
                    hidden_states[:num_scheduled_tokens],
                    attn_metadata.slot_mapping[:num_scheduled_tokens],
                )

        draft_token_ids = self._propose_torchair(
            target_token_ids=target_token_ids,
            target_positions=target_positions,
            target_hidden_states=target_hidden_states,
            target_slot_mapping=target_slot_mapping,
            next_token_ids=next_token_ids,
            cu_num_tokens=cu_num_tokens,
            block_table=attn_metadata.block_tables,
            sampling_metadata=sampling_metadata,
            token_indices=accepted_token_indices)
        spec_token_ids = draft_token_ids.tolist()
        return spec_token_ids
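    # For each request, locate the index of its last accepted token
    # (cu_num_tokens[:-1] + query_len - num_rejected - 1). Under torchair the
    # inputs are already padded to 1 + num_speculative_tokens per request, so
    # token ids, positions, hidden states and slot mapping pass through
    # unchanged.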
    def _torchair_prepare_inputs(
        self,
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        token_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]

        cu_num_tokens = cu_target_query_lens
        relative_index = query_len_per_req - num_rejected_tokens - 1
        token_indices = cu_num_tokens[:-1] + relative_index
        # The seq len of each batch is padded to 1 + num_speculative_tokens,
        # thus the input is the same as the main model's.
        target_token_ids = token_ids
        target_positions = positions
        target_hidden_states = hidden_states
        target_slot_mapping = slot_mapping

        return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
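    # Propose up to num_speculative_tokens draft tokens. Each step runs the
    # torchair-compiled draft model (or the eager model when prefill is
    # present), takes the argmax of the logits as the draft token, then shifts
    # input ids, positions and slot mapping for the next step.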
    def _propose_torchair(
            self,
            # [num_tokens]
            target_token_ids: torch.Tensor,
            # [num_tokens]
            target_positions: torch.Tensor,
            # [num_tokens, hidden_size]
            target_hidden_states: torch.Tensor,
            # [num_tokens]
            target_slot_mapping: torch.Tensor,
            # [batch_size]
            next_token_ids: torch.Tensor,
            # [batch_size + 1] starting with 0
            cu_num_tokens: torch.Tensor,
            # [batch_size, max_num_blocks_per_req]
            block_table: torch.Tensor,
            sampling_metadata: SamplingMetadata,
            token_indices=None) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        if token_indices is not None:
            last_token_indices = token_indices

        self.input_ids[last_token_indices] = next_token_ids

        query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
        max_query_len = query_lens.max().item()

        # FIXME: reorder_batch() needs to be called before build()
        # because fields of attn_metadata_builder needs to be updated.
        # However, currently reorder_batch() takes input_batch and
        # scheduler_output as arguments, we should probably refactor
        # the method to use new data structures which are independent
        # from input_batch and scheduler_output.
        # self.runner.attn_metadata_builder.reorder_batch(
        #     input_batch=self.runner.input_batch,
        #     scheduler_output=self.runner.scheduler_output,
        # )

        if not self.runner.with_prefill:
            # Torchair graph mode, padding is same as the main model
            num_input_tokens = self.runner.graph_pad_size
        elif (self.runner.use_aclgraph
              and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
            # Acl graph mode, add padding to the batch size
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            # Eager mode, no padding needed
            num_input_tokens = num_tokens

        seq_lens = target_positions[last_token_indices] + 1
        seq_lens = seq_lens.int()
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=cu_num_tokens[:batch_size + 1],
            query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
            seq_lens_cpu=seq_lens.cpu(),
            num_reqs=batch_size,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
            block_table_tensor=self.runner.input_batch.block_table[0].
            get_device_tensor(),
            slot_mapping=target_slot_mapping,
            positions=target_positions,
            attn_mask=self.runner.attn_mask,
            spec_attn_mask=self.runner.spec_attn_mask,
            attn_state=self.runner.attn_state,
            graph_pad_size=self.runner.graph_pad_size,
            decode_token_per_req=self.runner.decode_token_per_req,
            num_computed_tokens_cpu=None,
            seq_lens=None)

        attn_metadata = self.runner.attn_metadata_builder.build(
            0, common_attn_metadata, self.runner.get_model())

        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        # torchair mode can reuse self.runner.num_tokens_across_dp
        num_tokens_across_dp = self.runner.num_tokens_across_dp
        with_prefill = self.runner.with_prefill

        moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                           uniform_decode=False)
        aclgraph_runtime_mode, batch_descriptor = \
            self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)

        for step in range(self.num_speculative_tokens):
            with set_ascend_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_input_tokens,
                    with_prefill=with_prefill,
                    num_tokens_across_dp=num_tokens_across_dp,
                    reserved_mc2_mask=self.runner.reserved_mc2_mask,
                    moe_comm_type=moe_comm_type,
                    aclgraph_runtime_mode=aclgraph_runtime_mode,
                    in_profile_run=self.runner.in_profile_run,
                    num_actual_tokens=num_tokens):
                with ProfileExecuteDuration().capture_async('mtp_forward'):
                    model_kwargs = {}
                    model_kwargs["attn_metadata"] = attn_metadata

                    model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
                    if not self.runner.with_prefill:
                        torchair_compiled_model = self._get_torchair_lazy_compiled_model(
                            num_input_tokens)
                        hidden_states = torchair_compiled_model(
                            input_ids=self.input_ids[:num_input_tokens],
                            positions=self.positions[:num_input_tokens],
                            hidden_states=self.hidden_states[:num_input_tokens],
                            inputs_embeds=None,
                            intermediate_tensors=None,
                            spec_step_idx=0,
                            **model_kwargs)
                    else:
                        hidden_states = self.model(
                            input_ids=self.input_ids[:num_input_tokens],
                            positions=self.positions[:num_input_tokens],
                            hidden_states=self.hidden_states[:num_input_tokens])

            num_indices = last_token_indices.shape[0]
            if lmhead_tp_enable():
                if not self.runner.with_prefill:
                    max_num_reqs_across_dp = num_input_tokens
                else:
                    max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
                last_token_indices = nn.functional.pad(
                    last_token_indices,
                    (0, max_num_reqs_across_dp - num_indices))

            sample_hidden_states = hidden_states[last_token_indices]
            logits = self.model.compute_logits(sample_hidden_states)
            if lmhead_tp_enable() and num_indices < logits.shape[0]:
                logits = logits[:num_indices]
            draft_token_ids = logits.argmax(dim=-1)

            if self.num_speculative_tokens == 1:
                # [batch_size, 1]
                return draft_token_ids.view(-1, 1)

            if step == 0:
                draft_token_ids_list = [draft_token_ids]
            else:
                draft_token_ids_list.append(draft_token_ids)

            # prepare next mtp inputs
            # mtp>1: prefill skip or decode skip last loop
            if with_prefill:
                for _ in range(self.num_speculative_tokens - 1):
                    draft_token_ids_list.append(draft_token_ids)
            if step == self.num_speculative_tokens - 1 or with_prefill:
                break

            attn_metadata_i = attn_metadata

            if step == 0:
                positions = target_positions[last_token_indices]
                hidden_states = hidden_states[last_token_indices]
                slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
                attn_metadata_i.slot_mapping.fill_(-1)
                attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
                last_token_indices = self.arange[:batch_size]
                if attn_metadata_i.num_decode_tokens != 0:
                    attn_metadata_i.num_decode_tokens = batch_size
                if not self.runner.with_prefill:
                    attn_metadata_i.num_actual_tokens = batch_size
                    attn_metadata_i.query_lens = [1] * batch_size

            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)
            # Increment the sequence lengths.
            attn_metadata_i.seq_lens[:batch_size] += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            exceeds_max_model_len_cpu = exceeds_max_model_len.to(
                attn_metadata_i.seq_lens.device, non_blocking=True)
            attn_metadata_i.seq_lens[:batch_size].masked_fill_(
                exceeds_max_model_len_cpu, 1)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping += 1
            slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:hidden_states.shape[0]] = hidden_states
            attn_metadata_i.slot_mapping[:batch_size] = slot_mapping

            if attn_metadata_i.prefill is not None:
                attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
                attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist()
                attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
                attn_metadata_i.prefill.input_positions = self.positions[:num_input_tokens]
                attn_metadata_i.prefill.max_seq_lens += 1
                attn_metadata_i.prefill.max_seq_lens = min(
                    attn_metadata_i.prefill.max_seq_lens,
                    self.runner.model_config.max_model_len)
            if attn_metadata_i.decode is not None:
                attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
                attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist()
                attn_metadata_i.decode.input_positions = self.positions[:num_input_tokens]
                attn_metadata_i.decode.max_seq_lens += 1
                attn_metadata_i.decode.max_seq_lens = min(
                    attn_metadata_i.decode.max_seq_lens,
                    self.runner.model_config.max_model_len)

        # mtp>1: [batch_size, k]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
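    # Lazily compile the draft model with torchair for the requested graph
    # batch size. With use_cached_npu_graph, one compiled model per batch size
    # is cached on disk via torchair.inference.cache_compile; otherwise a
    # single torch.compile'd model is reused.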
    def _get_torchair_lazy_compiled_model(self, batch_size: int):
        if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
                -1]:
            raise ValueError(
                f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
            )

        compiled_model = self.torchair_compiled_models.get(
            batch_size
        ) if self.runner.use_cached_npu_graph else self.torchair_compiled_model

        if compiled_model:
            return compiled_model

        patch_for_hcom()
        config = torchair.CompilerConfig()
        config.experimental_config.frozen_parameter = True
        config.experimental_config.tiling_schedule_optimize = True
        config.experimental_config.enable_view_optimize = \
            get_ascend_config().torchair_graph_config.enable_view_optimize
        torch.npu.set_compile_mode(jit_compile=False)
        if not self.runner.use_cached_npu_graph:
            npu_backend = torchair.get_npu_backend(compiler_config=config)
            self.torchair_compiled_model = torch.compile(
                self.model,
                dynamic=not self.use_sparse,
                fullgraph=True,
                backend=npu_backend)
            return self.torchair_compiled_model
        else:
            # Generate a new forward proxy code object to prevent the invalidation of
            # compilation cache caused by dynamo retracing
            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
            forward_fn = self.model.forward
            code = forward_fn.__code__
            # Mark code object with a new proxy name
            modified_code = code.replace(co_name=forward_proxy_name, )

            modified_func = types.FunctionType(modified_code,
                                               forward_fn.__globals__,
                                               name=forward_proxy_name,
                                               argdefs=forward_fn.__defaults__)

            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
                self.model, nn.Module)
            self.torchair_compiled_models[
                batch_size] = torchair.inference.cache_compile(
                    self.model.__dict__[forward_proxy_name],
                    dynamic=not self.use_sparse,
                    fullgraph=True,
                    cache_dir=TORCHAIR_CACHE_DIR,
                    config=config,
                    ge_cache=False)
            return self.torchair_compiled_models[batch_size]