[long_seq_optim] BSND to TND and FA_UPDATE replacement (#3778)

### What this PR does / why we need it?
We optimized long-sequence performance in two ways. First, we changed the input data format for attention computation: instead of converting between the original BSND format and TND, we removed the conversion logic and adopt the TND format directly. The TND input can be reused as-is, which shortens the data flow path; the round-trip through BSND was an unnecessary processing step. Second, we replaced the output update, previously a chain of small concatenated operators, with the npu_attention_update fusion operator to improve performance.
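
For intuition, here is a minimal sketch (not the PR's code) of the two ideas. BSND is (batch, seq_len, num_heads, head_dim) with per-sequence padding, while TND packs all valid tokens into (total_tokens, num_heads, head_dim), so staying in TND skips the pack/unpack round-trip. The helper names and the log-sum-exp merge in the second function are illustrative assumptions, not the fused operator's actual API:

```python
import torch

def bsnd_to_tnd(x: torch.Tensor, seq_lens: list[int]) -> torch.Tensor:
    """Pack padded BSND (B, S, N, D) into TND (T, N, D), dropping padding.
    This is the kind of conversion step the PR removes."""
    return torch.cat([x[b, :n] for b, n in enumerate(seq_lens)], dim=0)

def merge_attn_outputs(o1, lse1, o2, lse2):
    """Unfused reference for the kind of output update npu_attention_update
    fuses: merge two partial attention results via their log-sum-exp.
    (Illustrative only; the fused operator's real signature may differ.)"""
    lse = torch.logaddexp(lse1, lse2)            # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)     # rescale first partial output
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)     # rescale second partial output
    return o1 * w1 + o2 * w2, lse
```

Replacing such a chain of exp/mul/add kernels with one fused update avoids intermediate tensors and kernel-launch overhead, which matters most at long sequence lengths.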
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main:
c9461e05a4

---------

Signed-off-by: pichangping <1337510399@qq.com>
pichangping committed on 2025-10-29 09:33:35 +08:00 (committed by GitHub)
parent e56b0017a3, commit f57bdb09fc
2 changed files with 108 additions and 112 deletions


@@ -1374,13 +1374,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.block_table.compute_slot_mapping(
    req_indices, positions_np)
- self.input_batch.block_table.commit_slot_mapping(
-     total_num_scheduled_tokens)
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
    tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
# update total_num_scheduled_tokens
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
+ self.input_batch.block_table.commit_slot_mapping(
+     total_num_scheduled_tokens)
total_num_pcp_pads = sum(self.num_pcp_pads)
max_num_scheduled_tokens = max(tokens)
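
As context for the reordering above, a toy sketch of why the slot mapping is committed only after the pcp token update recomputes the totals (the numbers and the pad-to-multiple rule are assumptions for illustration):

```python
import numpy as np

tokens = [5, 3, 7]                                # scheduled tokens per request
pcp_size = 4                                      # assumed pcp world size
tokens = [t + (-t) % pcp_size for t in tokens]    # pcp padding -> 8, 4, 8
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())  # 20, not 15
# committing the slot mapping before the update would use the stale total of 15
```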
@@ -4140,7 +4140,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                  >= self.input_batch.num_prompt_tokens[:num_reqs])
- num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
num_prefills = num_reqs - num_decodes
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
    long_seq_metadata = AscendPrefillContextParallelMetadata(
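
The decode/prefill split above treats a request as a decode once its whole prompt has been computed; a standalone numpy restatement with made-up numbers:

```python
import numpy as np

num_computed_tokens_cpu = np.array([16, 4, 9])   # per-request progress
num_prompt_tokens = np.array([16, 8, 9])         # per-request prompt lengths
num_reqs = 3
num_decodes = int(np.sum(num_computed_tokens_cpu[:num_reqs]
                         >= num_prompt_tokens[:num_reqs]))   # 2 decodes
num_prefills = num_reqs - num_decodes                        # 1 prefill
```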
@@ -4248,9 +4247,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                   device=self.device,
                   dtype=self.dtype), 1)
else:
-   max_seq_len = max(seq_lens, default=0)
    pcp_prefill_mask = torch.triu(
-       torch.full((num_prefills, max_seq_len, max_seq_len),
+       torch.full((2048, 2048),
            True,
            device=self.device,
            dtype=torch.bool), 1)
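
The `else` branch now builds a fixed 2048x2048 upper-triangular bool mask instead of a per-batch (num_prefills, max_seq_len, max_seq_len) one, so the mask's memory no longer grows with batch size or sequence length; fused attention kernels can apply such a fixed causal mask in a tiled fashion. A standalone sketch (CPU device for illustration):

```python
import torch

# True above the diagonal marks the future positions to mask out.
pcp_prefill_mask = torch.triu(
    torch.full((2048, 2048), True, dtype=torch.bool), diagonal=1)
assert pcp_prefill_mask[0, 1].item() and not pcp_prefill_mask[1, 0].item()
```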