support FULL graph mode for GQA (#3970)
### What this PR does / why we need it?
The library currently supports only the FullDecodeOnly graph mode, which enables full graph execution during the decode phase. This PR extends support to full graph execution in both the prefill and decode phases, referred to as FULL graph mode.
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
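
For context, FULL graph mode is typically requested through vLLM's compilation config. The sketch below shows one way a user might ask for it; the `cudagraph_mode` field and the `"FULL"` / `"FULL_DECODE_ONLY"` values reflect recent vLLM releases and should be treated as assumptions to verify against the installed version, and the model name is only an example.

```python
# Hedged sketch: selecting full graph capture via vLLM's compilation config.
# The "cudagraph_mode" field and its values are assumptions based on recent
# vLLM releases; check the documentation of your installed version.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",           # example GQA model, not prescribed by this PR
    compilation_config={
        "cudagraph_mode": "FULL",               # full graph for prefill *and* decode
        # "cudagraph_mode": "FULL_DECODE_ONLY", # previously the only full-graph option
    },
)
print(llm.generate("Hello")[0].outputs[0].text)
```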
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
@@ -331,6 +331,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.attn_groups: list[list[AttentionGroup]] = []
        self.encoder_cache: Dict[str, torch.Tensor] = {}
        self.attn_mask = None
        self.fia_attn_mask = None
        self.attn_state = None
        self.requests: Dict[str, CachedRequestState] = {}
        self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -1030,6 +1031,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        else:
            return None

    def _make_fia_attention_mask(self) -> torch.Tensor:
        if self.attn_mask_builder is None:
            raise ValueError("Attn mask builder is None")
        return self.attn_mask_builder.get_splitfuse_attn_mask()

    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
        for index, req_id in enumerate(self.input_batch.req_ids):
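
The new `_make_fia_attention_mask` helper simply delegates to the runner's `attn_mask_builder`. For readers without the vllm-ascend source at hand, below is a minimal sketch of what a split-fuse style mask builder could look like, assuming the mask is a fixed-size strict upper-triangular int8 tensor (the same construction that appears in the dummy-run hunk further down). The class name and caching strategy are illustrative, not the actual implementation.

```python
# Hedged sketch of a split-fuse attention mask builder. Assumes the mask is a
# cached [max_seq_len, max_seq_len] int8 tensor with 1 in the strict upper
# triangle (positions to mask out). Names other than get_splitfuse_attn_mask
# are illustrative.
from typing import Optional

import torch


class AttentionMaskBuilder:  # hypothetical name

    def __init__(self, max_seq_len: int, device: torch.device):
        self._max_seq_len = max_seq_len
        self._device = device
        self._splitfuse_mask: Optional[torch.Tensor] = None

    def get_splitfuse_attn_mask(self) -> torch.Tensor:
        # Lazily build and cache the causal mask once per device.
        if self._splitfuse_mask is None:
            ones = torch.ones(self._max_seq_len, self._max_seq_len)
            self._splitfuse_mask = torch.triu(ones, diagonal=1).to(
                torch.int8).to(self._device)
        return self._splitfuse_mask


# Usage, mirroring _make_fia_attention_mask():
# builder = AttentionMaskBuilder(max_seq_len=2048, device=torch.device("npu"))
# fia_attn_mask = builder.get_splitfuse_attn_mask()
```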
@@ -1667,6 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                   position=positions_cpu,
                                                   attn_state=attn_state)
        self.fia_attn_mask = self._make_fia_attention_mask()
        self.attn_state = attn_state  # type: ignore

        self.with_prefill = with_prefill
@@ -1899,6 +1906,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            positions=self.positions,
            attn_mask=self.attn_mask,
            fia_attn_mask=self.fia_attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
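
The `fia_attn_mask` now travels alongside the existing masks in the per-step attention metadata. A hedged sketch of such a metadata container follows; only the keyword names visible in this hunk come from the PR, while the dataclass name, field types, and defaults are assumptions.

```python
# Hedged sketch of an attention-metadata container gaining the new field.
# Only the keyword names visible in the hunk above are taken from the PR;
# everything else here is illustrative.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CommonAttentionMetadataSketch:  # hypothetical name
    num_computed_tokens_cpu: torch.Tensor
    positions: torch.Tensor
    attn_mask: Optional[torch.Tensor] = None
    fia_attn_mask: Optional[torch.Tensor] = None  # new: mask for fused-infer-attention
    spec_attn_mask: Optional[torch.Tensor] = None
    attn_state: Optional[object] = None
    is_only_prefill: bool = False
```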
@@ -2756,13 +2764,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):

        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)
        query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
            self.device).to(torch.int32)
        self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor

        self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
        self.query_start_loc_cpu[1:num_reqs +
                                 1] = torch.Tensor(cu_num_tokens)
        self.query_lens = torch.from_numpy(num_scheduled_tokens)

        assigned_mask_dim = 2048
        self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
                                                   assigned_mask_dim),
                                        diagonal=1).to(torch.int8).to(
                                            self.device)

        num_computed_tokens_cpu = (
            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
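
For the dummy (graph-capture) run, `query_start_loc` is the exclusive prefix sum of the per-request token counts, and the FIA mask is a fixed 2048x2048 strict upper-triangular int8 tensor. Below is a self-contained sketch of that bookkeeping; the standalone `get_cumsum_and_arange` function is an illustrative stand-in for the runner's private `_get_cumsum_and_arange` method.

```python
# Hedged, self-contained sketch of the dummy-run bookkeeping in this hunk.
import numpy as np
import torch


def get_cumsum_and_arange(num_tokens: np.ndarray):
    """Return (cumulative token counts, per-token offsets restarting at 0 per request)."""
    cu_num_tokens = np.cumsum(num_tokens)
    # e.g. [4, 2, 3] -> offsets [0, 0, 0, 0, 4, 4, 6, 6, 6]
    cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
    arange = np.arange(cu_num_tokens[-1]) - cumsums_offsets
    return cu_num_tokens, arange


num_scheduled_tokens = np.array([4, 2, 3])          # 3 dummy requests
cu_num_tokens, arange = get_cumsum_and_arange(num_scheduled_tokens)

query_start_loc = torch.zeros(len(num_scheduled_tokens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.from_numpy(cu_num_tokens).to(torch.int32)
# query_start_loc -> tensor([0, 4, 6, 9], dtype=torch.int32)

assigned_mask_dim = 2048                            # fixed capture-time mask size
fia_attn_mask = torch.triu(
    torch.ones(assigned_mask_dim, assigned_mask_dim), diagonal=1).to(torch.int8)
```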
@@ -2805,6 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            positions=self.positions,
            attn_mask=self.attn_mask,
            fia_attn_mask=self.fia_attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            max_query_len=max_query_len,
@@ -3978,10 +3992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            graph_support = None
            if hasattr(builder, 'aclgraph_support'):
                graph_support = builder.aclgraph_support.value
                builder_aclgraph = builder.aclgraph_support
            else:
                graph_support = builder.cudagraph_support.value
                builder_aclgraph = builder.cudagraph_support
            if graph_support < min_ag_support.value:
                min_ag_support = builder.aclgraph_support
                min_ag_support = builder_aclgraph
                min_ag_builder_name = builder.__class__.__name__

        # This is an imitation of compilation_config.splitting_ops_contain_attention()
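
This hunk papers over the naming difference between vllm-ascend metadata builders (`aclgraph_support`) and upstream vLLM builders (`cudagraph_support`), tracking the weakest graph support level seen across attention backends. A hedged, self-contained sketch of the same fallback pattern follows; the enum values and builder classes are illustrative stand-ins, not the real types.

```python
# Hedged sketch of the compatibility shim: prefer aclgraph_support, fall back
# to cudagraph_support, and keep the minimum support level across builders.
import enum


class AttentionCGSupport(enum.IntEnum):  # hypothetical stand-in
    NEVER = 0
    UNIFORM_BATCH = 1
    ALWAYS = 2


class AscendBuilder:                      # exposes the Ascend-style attribute
    aclgraph_support = AttentionCGSupport.ALWAYS


class UpstreamBuilder:                    # exposes the upstream-style attribute
    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH


min_ag_support = AttentionCGSupport.ALWAYS
min_ag_builder_name = ""

for builder in (AscendBuilder(), UpstreamBuilder()):
    builder_aclgraph = getattr(builder, "aclgraph_support",
                               getattr(builder, "cudagraph_support", None))
    if builder_aclgraph is not None and builder_aclgraph.value < min_ag_support.value:
        min_ag_support = builder_aclgraph
        min_ag_builder_name = builder.__class__.__name__

print(min_ag_support, min_ag_builder_name)
# AttentionCGSupport.UNIFORM_BATCH UpstreamBuilder
```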