support FULL graph mode for GQA (#3970)

### What this PR does / why we need it?
The library currently supports only the FullDecodeOnly graph mode, which
enables full graph execution during the decode phase. This PR extends
support to allow full graph execution in both the prefill and decode
phases, referred to as FULL graph mode.
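
For illustration, a minimal usage sketch (hypothetical model name; the exact
config key and accepted values depend on the vLLM / vllm-ascend release in use):

```python
from vllm import LLM

# Hedged sketch: request full-graph capture for both prefill and decode.
# Before this change only the decode-only variant was usable on this backend.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",               # illustrative model
    compilation_config={"cudagraph_mode": "FULL"},  # vs. "FULL_DECODE_ONLY"
)
print(llm.generate("Hello")[0].outputs[0].text)
```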

- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Author: XiaoxinWang
Date: 2025-11-17 10:50:35 +08:00 (committed by GitHub)
Parent: c334114f69 · Commit: e38ef2c434
11 changed files with 328 additions and 296 deletions


@@ -331,6 +331,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.attn_groups: list[list[AttentionGroup]] = []
        self.encoder_cache: Dict[str, torch.Tensor] = {}
        self.attn_mask = None
        self.fia_attn_mask = None
        self.attn_state = None
        self.requests: Dict[str, CachedRequestState] = {}
        self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -1030,6 +1031,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        else:
            return None

    def _make_fia_attention_mask(self) -> torch.Tensor:
        if self.attn_mask_builder is None:
            raise ValueError("Attn mask builder is None")
        return self.attn_mask_builder.get_splitfuse_attn_mask()

    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
        for index, req_id in enumerate(self.input_batch.req_ids):
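
As a rough idea of what a split-fuse style mask can look like, here is a hedged
sketch; the real `AttentionMaskBuilder.get_splitfuse_attn_mask()` may differ in
size, dtype, device handling, and caching:

```python
import torch

def get_splitfuse_attn_mask_sketch(max_len: int = 2048,
                                   device: str = "cpu") -> torch.Tensor:
    # Strictly upper-triangular entries mark the positions a query token must
    # not attend to, i.e. a causal mask usable for both prefill and decode.
    mask = torch.triu(torch.ones(max_len, max_len), diagonal=1)
    return mask.to(torch.int8).to(device)
```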
@@ -1667,6 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                   position=positions_cpu,
                                                   attn_state=attn_state)
        self.fia_attn_mask = self._make_fia_attention_mask()
        self.attn_state = attn_state  # type: ignore

        self.with_prefill = with_prefill
@@ -1899,6 +1906,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            positions=self.positions,
            attn_mask=self.attn_mask,
            fia_attn_mask=self.fia_attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
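
The new field simply rides along on the shared attention metadata; a
trimmed-down sketch of that plumbing (the real metadata class in vllm-ascend
carries many more fields):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class CommonAttentionMetadataSketch:
    # Stand-in for the metadata object built by the runner; field names
    # mirror the keyword arguments passed above.
    positions: Optional[torch.Tensor] = None
    attn_mask: Optional[torch.Tensor] = None
    fia_attn_mask: Optional[torch.Tensor] = None  # added by this PR
    spec_attn_mask: Optional[torch.Tensor] = None
```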
@@ -2756,13 +2764,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)
        query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
            self.device).to(torch.int32)
        self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
        self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
        self.query_start_loc_cpu[1:num_reqs +
                                 1] = torch.Tensor(cu_num_tokens)
        self.query_lens = torch.from_numpy(num_scheduled_tokens)
        assigned_mask_dim = 2048
        self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
                                                   assigned_mask_dim),
                                        diagonal=1).to(torch.int8).to(
                                            self.device)
        num_computed_tokens_cpu = (
            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
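
For reference, a strictly upper-triangular int8 mask of this shape encodes
standard causality. A tiny illustration of how such a mask is typically
consumed (the fused attention kernel on the NPU may accept the int8 form
directly):

```python
import torch

# Illustration only (not part of the diff): positions marked 1 are the ones
# a query must not attend to; the diagonal and everything below stay visible.
dim = 4  # stand-in for assigned_mask_dim = 2048
mask = torch.triu(torch.ones(dim, dim), diagonal=1).to(torch.int8)
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int8)

# A generic consumer converts it into an additive -inf bias before softmax.
scores = torch.randn(dim, dim)
scores = scores.masked_fill(mask.bool(), float("-inf"))
```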
@@ -2805,6 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            positions=self.positions,
            attn_mask=self.attn_mask,
            fia_attn_mask=self.fia_attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            max_query_len=max_query_len,
@@ -3978,10 +3992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            graph_support = None
            if hasattr(builder, 'aclgraph_support'):
                graph_support = builder.aclgraph_support.value
                builder_aclgraph = builder.aclgraph_support
            else:
                graph_support = builder.cudagraph_support.value
                builder_aclgraph = builder.cudagraph_support
            if graph_support < min_ag_support.value:
                min_ag_support = builder.aclgraph_support
                min_ag_support = builder_aclgraph
                min_ag_builder_name = builder.__class__.__name__

        # This is an imitation of compilation_config.splitting_ops_contain_attention()
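
The hasattr() branch above normalizes the attribute name (aclgraph_support on
Ascend builders, cudagraph_support on upstream vLLM builders) before taking the
minimum support level across all builders. A self-contained sketch of that
reduction, with illustrative enum and helper names:

```python
from enum import Enum

class GraphSupportSketch(Enum):
    # Illustrative stand-in for the support levels exposed by metadata builders.
    NEVER = 0
    UNIFORM_BATCH = 1
    ALWAYS = 2

def weakest_graph_support(builders):
    """Return the weakest graph-capture support across builders and its source."""
    min_support, min_builder_name = GraphSupportSketch.ALWAYS, None
    for builder in builders:
        # Ascend builders expose `aclgraph_support`; upstream vLLM builders
        # expose `cudagraph_support` -- mirror the hasattr() check above.
        support = (getattr(builder, "aclgraph_support", None)
                   or getattr(builder, "cudagraph_support", None))
        if support is None:
            continue
        if support.value < min_support.value:
            min_support = support
            min_builder_name = builder.__class__.__name__
    return min_support, min_builder_name
```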