xc-llm-ascend/vllm_ascend/spec_decode/mtp_proposer.py
Mengqing Cao 8abe517870 [Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432)
### What this PR does / why we need it?
Adapt deepseek-v3.2 to vllm 0.11.0, removing the patch that is no longer needed.

The final goal is to remove all the patches and align the code architecture with vLLM, so we need to do the following work in the next PRs.
TODO:
- [x] remove patch on attention spec
- [ ] refactor the kvcache creation logic

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
1. CI passed with the existing tests.
2. Tests pass with deepseek-v3.2-exp.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
2025-10-15 17:48:58 +08:00

668 lines
31 KiB
Python

import types
import torch
import torch.nn as nn
import torchair
from torchair import patch_for_hcom
from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
TorchairCommonAttentionMetadata)
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
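# Slot id used to mask out KV-cache writes (e.g. for positions beyond the max model length).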
PADDING_SLOT_ID = -1
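# Proposer that runs the DeepSeek MTP draft model for speculative decoding on Ascend,
# supporting eager, ACL graph and torchair graph execution.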
class MtpProposer(Proposer):
def __init__(
self,
vllm_config: VllmConfig,
device,
runner,
):
self.name = SpecDcodeType.MTP
self.vllm_config = vllm_config
self.device = device
self.runner = runner
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
# persistent buffers for graph
self.input_ids = torch.zeros(self.runner.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.positions = torch.zeros(self.runner.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.hidden_states = torch.zeros(
(self.runner.max_num_tokens,
vllm_config.model_config.get_hidden_size()),
dtype=self.runner.dtype,
device=self.device)
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
self.enable_shared_expert_dp = get_ascend_config(
).enable_shared_expert_dp
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
device=self.runner.device,
dtype=torch.int32)
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
target_device = self.vllm_config.device_config.device
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
if self.torchair_graph_enabled or (
self.enable_shared_expert_dp
and self.vllm_config.model_config.use_mla):
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else:
self.model = CustomDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names)
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)
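# Dummy forward pass used for profiling and graph warm-up/capture; no real requests are involved.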
@torch.inference_mode()
def dummy_run(self,
num_tokens: int,
with_prefill: bool = False,
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp=None) -> None:
if not self.torchair_graph_enabled:
# TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(num_tokens,
with_prefill, False)
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
is_running_torchair = self.torchair_graph_enabled and \
not with_prefill
if is_running_torchair:
skip_attn = False
if skip_attn:
attn_metadata = None
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,
num_actual_tokens=1,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
decode_token_per_req=self.runner.decode_token_per_req,
)
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
common_attn_metadata)
input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
for _ in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states)
if with_prefill:
break
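# Called by the model runner after target-model sampling: picks the next accepted
# token per request, prepares drafter inputs, and returns the proposed speculative token ids.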
def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.runner.input_batch.req_ids[i]
req_state = self.runner.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
accepted_token_indices = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, accepted_token_indices, target_token_ids, \
target_positions, target_hidden_states, target_slot_mapping = self._prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
self.runner.input_ids[:num_scheduled_tokens],
positions[:num_scheduled_tokens],
hidden_states[:num_scheduled_tokens],
attn_metadata.slot_mapping[:num_scheduled_tokens],
is_torchair_graph=self.runner._build_drafter_prepare_inputs_torchair_param(),
)
draft_token_ids = self._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
token_indices=accepted_token_indices)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
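# Drop positions belonging to rejected draft tokens so the drafter only sees accepted
# tokens; in torchair graph mode the padded layout is kept and only the last accepted
# index per request is computed.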
def _prepare_inputs(
self,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
token_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
slot_mapping: torch.Tensor,
is_torchair_graph: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
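# Worked example (illustrative numbers): with query lens [3, 2] and
# num_rejected_tokens [1, 0], cu_num_tokens is [0, 2, 4] and
# token_indices is [0, 1, 3, 4].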
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
if is_torchair_graph:
cu_num_tokens = cu_target_query_lens
relative_index = query_len_per_req - num_rejected_tokens - 1
token_indices = cu_num_tokens[:-1] + relative_index
# The seq len of each batch is padded to 1 + num_speculative_tokens, so the input layout matches the main model.
target_token_ids = token_ids
target_positions = positions
target_hidden_states = hidden_states
target_slot_mapping = slot_mapping
else:
cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.zeros(
num_tokens,
dtype=torch.int32,
device=cu_num_tokens.device,
)
BLOCK_SIZE = 1024
self._prepare_input_kernel(
token_indices,
cu_target_query_lens,
cu_num_tokens,
block_size=BLOCK_SIZE,
)
target_token_ids = token_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = slot_mapping[token_indices]
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
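# Core speculative loop: shift the target token ids by one, run the MTP model for
# num_speculative_tokens steps, and take the argmax logit as the draft token at each step.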
def _propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
token_indices=None) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
if token_indices is not None and self.torchair_graph_enabled:
last_token_indices = token_indices
self.input_ids[last_token_indices] = next_token_ids
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
# FIXME: reorder_batch() needs to be called before build()
# because fields of attn_metadata_builder needs to be updated.
# However, currently reorder_batch() takes input_batch and
# scheduler_output as arguments, we should probably refactor
# the method to use new data structures which are independent
# from input_batch and scheduler_output.
# self.runner.attn_metadata_builder.reorder_batch(
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )
is_running_torchair = self.torchair_graph_enabled and \
not self.runner.with_prefill
if is_running_torchair:
# Torchair graph mode: padding is the same as for the main model
num_input_tokens = self.runner.graph_pad_size
elif (self.runner.use_aclgraph
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
# ACL graph mode: pad the token count up to a captured graph size
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
# Eager mode, no padding needed
num_input_tokens = num_tokens
seq_lens = target_positions[last_token_indices] + 1
seq_lens = seq_lens.int()
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=cu_num_tokens[:batch_size + 1],
query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
seq_lens_cpu=seq_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
graph_pad_size=self.runner.graph_pad_size,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=None,
seq_lens=None)
if not self.torchair_graph_enabled:
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_mtp = builder.build(0, common_attn_metadata,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
else:
attn_metadata = self.runner.attn_metadata_builder.build(
0, common_attn_metadata, self.runner.get_model())
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if not self.torchair_graph_enabled:
# Non-torchair mode needs to update num_tokens_across_dp
# TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(
num_input_tokens, self.runner.with_prefill, False)
else:
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
moe_comm_type = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
aclgraph_runtime_mode, batch_descriptor = \
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
aclgraph_runtime_mode=aclgraph_runtime_mode,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
else:
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
last_token_indices = nn.functional.pad(
last_token_indices,
(0, max_num_reqs_across_dp - num_indices))
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if step == 0:
draft_token_ids_list = [draft_token_ids]
else:
draft_token_ids_list.append(draft_token_ids)
# Prepare inputs for the next MTP step.
# mtp > 1: during prefill skip the remaining steps (repeating the last draft token);
# during decode stop after the last step.
if with_prefill and self.torchair_graph_enabled:
for _ in range(self.num_speculative_tokens - 1):
draft_token_ids_list.append(draft_token_ids)
if step == self.num_speculative_tokens - 1 or with_prefill:
break
if not self.torchair_graph_enabled:
attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
else:
attn_metadata_i = attn_metadata
if step == 0:
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
attn_metadata_i.slot_mapping.fill_(-1)
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size]
if attn_metadata_i.num_decode_tokens != 0:
attn_metadata_i.num_decode_tokens = batch_size
if is_running_torchair:
attn_metadata_i.num_actual_tokens = batch_size
attn_metadata_i.query_lens = [1] * batch_size
input_ids = draft_token_ids_list[-1].int()
positions += 1
if not self.torchair_graph_enabled:
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
1:batch_size + 1].tolist()
attn_metadata_i.decode.cos = builder.cos_cache[
positions].unsqueeze(1).unsqueeze(2)
attn_metadata_i.decode.sin = builder.sin_cache[
positions].unsqueeze(1).unsqueeze(2)
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata_i.seq_lens[:batch_size] += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
attn_metadata_i.seq_lens.device, non_blocking=True)
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len_cpu, 1)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping += 1
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:hidden_states.shape[0]] = hidden_states
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
if attn_metadata_i.prefill is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
)
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.prefill.max_seq_lens += 1
attn_metadata_i.prefill.max_seq_lens = min(
attn_metadata_i.prefill.max_seq_lens,
self.runner.model_config.max_model_len)
if attn_metadata_i.decode is not None:
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
)
attn_metadata_i.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1
attn_metadata_i.decode.max_seq_lens = min(
attn_metadata_i.decode.max_seq_lens,
self.runner.model_config.max_model_len)
# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
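# Return a torchair-compiled model for the given padded batch size, compiling lazily
# and optionally reusing the on-disk graph cache.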
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.runner.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
patch_for_hcom()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.runner.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not self.use_sparse,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent invalidation of the
# compilation cache caused by dynamo retracing.
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=not self.use_sparse,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
# TODO: Using torch instead of triton may result in poor performance
def _prepare_input_kernel(self, out_ptr: torch.Tensor,
cu_query_lens: torch.Tensor,
cu_num_tokens: torch.Tensor, block_size: int):
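# Pure-torch equivalent of the triton prepare_input kernel mentioned above: for each
# request, write the indices of its accepted target tokens into out_ptr.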
device = cu_query_lens.device
dtype = out_ptr.dtype
offsets = torch.arange(block_size, device=device, dtype=dtype)
start_pos = cu_num_tokens[:-1]
end_pos = cu_num_tokens[1:]
num_tokens = end_pos - start_pos
global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1))
values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))
mask = (offsets.view(1, -1) < num_tokens.view(-1, 1))
global_indices_flat = global_indices[mask]
values_flat = values[mask]
out_ptr[global_indices_flat] = values_flat