[long_seq_optim] BSND to TND and FA_UPDATE replacement (#3778)
### What this PR does / why we need it?
We optimized long-sequence performance in two ways. First, we changed the input data format for attention computation: instead of feeding the kernels the original BSND layout, we removed the TND/BSND conversion logic and adopt the TND format directly. The TND inputs can be reused as-is, which shortens the data path; the round trip through BSND was an unnecessary processing step (a sketch of the removed conversion follows below). Second, we replaced the output update, previously assembled from several small concatenated operators, with the `npu_attention_update` fusion operator to improve performance.
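
For illustration, here is a minimal sketch of the pad/unpad round trip that the removed `_pack_tnd_2_bsnd` / `_unpack_bsnd_2_tnd` helpers performed (helper bodies taken from the deleted code in the diff; the toy shapes are assumptions):

```python
from typing import List

import torch
import torch.nn.functional as F


def pack_tnd_2_bsnd(tensor_tnd: torch.Tensor,
                    lengths: List[int]) -> torch.Tensor:
    # TND -> BSND: split the token-packed tensor per sequence and
    # zero-pad every sequence up to the batch's max length.
    max_len = max(lengths)
    splits = torch.split(tensor_tnd, lengths, dim=0)
    padded = [F.pad(s, (0, 0, 0, 0, 0, max_len - s.shape[0])) for s in splits]
    return torch.stack(padded, dim=0)


def unpack_bsnd_2_tnd(tensor_bsnd: torch.Tensor,
                      lengths: List[int]) -> torch.Tensor:
    # BSND -> TND: drop the padding and re-pack along the token dimension.
    return torch.cat([tensor_bsnd[i, :l] for i, l in enumerate(lengths)],
                     dim=0)


# Two sequences of 3 and 5 tokens, 4 heads, head_dim 8 (toy sizes).
lengths = [3, 5]
q_tnd = torch.randn(sum(lengths), 4, 8)   # [T, N, D] = [8, 4, 8]
q_bsnd = pack_tnd_2_bsnd(q_tnd, lengths)  # [B, S, N, D] = [2, 5, 4, 8], partly padding
assert torch.equal(unpack_bsnd_2_tnd(q_bsnd, lengths), q_tnd)
```

Every attention call paid for this split/pad/stack and the inverse slice/cat; with `input_layout="TND"` plus `actual_seq_lengths`, the kernel consumes the packed tensor directly.

For the second change, the reference merge (from the removed `_update_out_and_lse` helper) is `LSE_final = log(sum(exp(LSE_i)))` and `O_final = sum(exp(LSE_i - LSE_final) * O_i)`. Below is a runnable sketch of that reference; the fused `torch_npu.npu_attention_update` call in the trailing comment follows the usage in this diff, and its exact operator semantics are an assumption here:

```python
import torch


def update_out_and_lse(out_list: torch.Tensor, lse_list: torch.Tensor):
    # out_list: [P, T, N, D]; lse_list: [P, T, N, 1] for P partial results.
    lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
    out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, dim=0)
    return out_final, lse_final


# Merge two partial results (e.g. the masked and unmasked halves), toy sizes.
outs = torch.randn(2, 8, 4, 16)
lses = torch.randn(2, 8, 4, 1)
merged, _ = update_out_and_lse(outs, lses)

# The fused path instead flattens each partial (out, lse) pair and makes one
# operator call, per this diff (update_type = 0, result viewed back to [T, N, D]):
#   output, _ = torch_npu.npu_attention_update(attn_lse_list, attn_out_list, 0)
```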
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4
---------
Signed-off-by: pichangping <1337510399@qq.com>
@@ -23,7 +23,6 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -318,6 +317,18 @@ class AscendAttentionMetadataBuilder:
         pcp_metadata = None
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         if common_long_seq_metadata is not None:
+            attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
+            head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
+            tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
+            pcp_size = get_prefill_context_model_parallel_world_size(
+            ) if prefill_context_parallel_enable() else 1
+            if pcp_size > 1:
+                attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
+                                                 dim=0).tolist()
+                head_attn_nomask_seqlens = torch.cumsum(
+                    head_attn_nomask_seqlens[1], dim=0).tolist()
+                tail_attn_nomask_seqlens = torch.cumsum(
+                    tail_attn_nomask_seqlens[1], dim=0).tolist()
             pcp_metadata = AscendPCPMetadata(
                 q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
                 q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -329,12 +340,9 @@ class AscendAttentionMetadataBuilder:
                 kv_with_q_tail_nomask_idx_tensor,
                 kv_with_q_tail_mask_idx=common_long_seq_metadata.
                 kv_with_q_tail_mask_idx_tensor,
-                attn_mask_seqlens=common_long_seq_metadata.
-                attn_mask_seqlens,
-                head_attn_nomask_seqlens=common_long_seq_metadata.
-                head_attn_nomask_seqlens,
-                tail_attn_nomask_seqlens=common_long_seq_metadata.
-                tail_attn_nomask_seqlens,
+                attn_mask_seqlens=attn_mask_seqlens,
+                head_attn_nomask_seqlens=head_attn_nomask_seqlens,
+                tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
                 q_full_idx=common_long_seq_metadata.q_full_idx,
                 pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
         prefill_metadata = AscendMetadataForPrefill(
@@ -726,28 +734,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 out=output)
         return output

-    def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
-                         lengths: List[int]) -> torch.Tensor:
-        max_len = max(lengths)
-        splits = torch.split(tensor_tnd, lengths, dim=0)
-
-        padded = []
-        for s in splits:
-            pad_len = max_len - s.shape[0]
-            s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
-            padded.append(s_pad)
-
-        tensor_bsnd = torch.stack(padded, dim=0)
-        return tensor_bsnd
-
-    def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
-                           lengths: List[int]) -> torch.Tensor:
-        slices = []
-        for i, length in enumerate(lengths):
-            slices.append(tensor_bsnd[i, :length])
-        tensor_tnd = torch.cat(slices, dim=0)
-        return tensor_tnd
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -757,17 +743,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                         v_mask: torch.Tensor,
                                         kv_seqlens_mask: List[int],
                                         mask: torch.Tensor) -> torch.Tensor:
-        q = self._pack_tnd_2_bsnd(q, q_seqlens)
-
         # nomask Attention
         if k_nomask is not None:
             attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                 q,
-                self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
-                self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
+                k_nomask,
+                v_nomask,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BSND",
+                input_layout="TND",
                 atten_mask=None,
                 scale=self.scale,
                 sparse_mode=0,
@@ -776,38 +760,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 softmax_lse_flag=True,
                 actual_seq_lengths_kv=kv_seqlens_nomask,
                 actual_seq_lengths=q_seqlens)
-            attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
-                                                      q_seqlens)
-            # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
-            attn_lse_nomask = self._unpack_bsnd_2_tnd(
-                attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)

         # mask Attention
         attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
             q,
-            self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
-            self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
+            k_mask,
+            v_mask,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
             atten_mask=mask,
             scale=self.scale,
-            sparse_mode=0,
+            sparse_mode=3,
             antiquant_mode=0,
             antiquant_scale=None,
             softmax_lse_flag=True,
             actual_seq_lengths_kv=kv_seqlens_mask,
             actual_seq_lengths=q_seqlens)
-        attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
-        attn_lse_mask = self._unpack_bsnd_2_tnd(
-            attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)

         # update
         output = attn_out_mask
         if k_nomask is not None:
-            output, _ = self._update_out_and_lse(
-                torch.stack([attn_out_nomask, attn_out_mask], dim=0),
-                torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
+            T = attn_out_mask.shape[0]
+            N = attn_out_mask.shape[1]
+            D = attn_out_mask.shape[2]
+
+            attn_out_mask, attn_lse_mask = self._out_lse_reshape(
+                attn_out_mask, attn_lse_mask)
+            attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
+                attn_out_nomask, attn_lse_nomask)
+            attn_out_mask = attn_out_mask.to(torch.float32)
+            attn_out_nomask = attn_out_nomask.to(torch.float32)
+            attn_lse_mask = attn_lse_mask.to(torch.float32)
+            attn_lse_nomask = attn_lse_nomask.to(torch.float32)
+
+            attn_output = [attn_out_nomask, attn_out_mask]
+            attn_lse = [attn_lse_nomask, attn_lse_mask]
+            update_type = 0
+            output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
+                                                       update_type)
+            output = output.view(T, N, D)

         return output

@@ -832,15 +824,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
         # 1. Attention calculation in the first half of Q in load balancing
         output_head = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_head_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
             v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
-            kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=head_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)

         # 2. the Attention calculation in the latter half of Q in load balancing
@@ -848,13 +840,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
         # pcp_rank1: Q2*KV0~KV1 + Q2*KV2
         output_tail = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_tail_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
             v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
-            kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=tail_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)

         # 3. Combine the output of the first half and second half.
@@ -863,20 +855,36 @@ class AscendAttentionBackendImpl(AttentionImpl):
             torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
         return output

-    def _update_out_and_lse(self, out_list: torch.Tensor,
-                            lse_list: torch.Tensor) -> torch.Tensor:
-        """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
-        Args:
-            out_list: shape = [N, batch_size, num_heads, head_size]
-            lse_list: shape = [N, batch_size, num_heads, 1]
-        Returns:
-            out_final: shape = [batch_size, num_heads, head_size]
-            lse_final: shape = [batch_size, num_heads, 1]
-        """
-        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
-        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
-                              dim=0)
-        return out_final, lse_final
+    def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> torch.Tensor:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        update_type = 0
+
+        batch = attn_out_lse_list[0].shape[0]
+        num_heads = attn_out_lse_list[0].shape[1]
+        head_dim = attn_out_lse_list[0].shape[2] - 1
+
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for i in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(i, [self.head_size, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+
+        attn_out, attn_lse = torch_npu.npu_attention_update(
+            attn_lse_split_cp, attn_out_split_cp, update_type)
+        attn_out = attn_out.view(batch, num_heads, head_dim)
+
+        return attn_out

     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
@@ -889,9 +897,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
         else:
             num_heads = self.num_heads

-        # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
-        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -902,7 +907,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            "TND",
             'atten_mask':
             None,
             'scale':
@@ -917,9 +922,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             attn_metadata.block_tables,
             'block_size':
             self.key_cache.shape[1],
-            "actual_seq_lengths_kv":
-            attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+            'actual_seq_lengths_kv':
+            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.
+                                               num_decode_tokens]
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
@@ -935,16 +942,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    q_nope, k_nope, value, **common_kwargs)
+                    query, k_nope, value, **common_kwargs)
                 update_graph_params_workspaces(num_tokens,
                                                weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(q_nope)
+            attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=q_nope.device)
+                                   device=query.device)

             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                (weak_ref_tensors(query), weak_ref_tensors(k_nope),
                  weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
                  self.key_cache.shape[1], attn_metadata.decode_meta.
@@ -954,7 +961,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                  self.pcp_rank, self.dcp_rank, self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
+                query,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -964,14 +971,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)

-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
-        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
+        attn_out_lse_list = []
         if self.dcp_size > 1:
+            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -980,35 +985,28 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                     group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))

-            attn_out_lse_split_dcp = torch.stack(
-                attn_out_lse_split_on_seq,
-                dim=0)  # [dcp, batch_size, num_heads, head_size+1]
-            # Update out&lse
-            attn_out_split_dcp, attn_lse_split_dcp = torch.split(
-                attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
-            attn_out, attn_lse = self._update_out_and_lse(
-                attn_out_split_dcp, attn_lse_split_dcp)
         if self.pcp_size > 1:
-            # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-            # 3. AllGather out&lse within CP group
+            # AllGather out&lse within CP group
             attn_out_lse_list = [
                 torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
             ]
             dist.all_gather(attn_out_lse_list,
                             attn_out_lse,
                             group=self.pcp_group)
-            # 4. Update out&lse
-            attn_out_lse_allgather = torch.stack(
-                attn_out_lse_list,
-                dim=0)  # [pcp, batch_size, num_heads, head_size+1]
-            attn_out_allgather, attn_lse_allgather = torch.split(
-                attn_out_lse_allgather, [self.head_size, 1], dim=-1)
-            attn_out, _ = self._update_out_and_lse(attn_out_allgather,
-                                                   attn_lse_allgather)
+            if self.dcp_size > 1 and self.pcp_size > 1:
+                attn_out_lse_list_pcp_dcp = []
+                for s in attn_out_lse_list:
+                    attn_out_lse_list_split = list(
+                        torch.chunk(s, self.dcp_size, dim=1))
+                    attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+                attn_out_lse_list = attn_out_lse_list_pcp_dcp
+        # Update out&lse
+        attn_out = self._npu_attention_update(attn_out_lse_list)
         return attn_out

     def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
@@ -1374,13 +1374,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):

         self.input_batch.block_table.compute_slot_mapping(
             req_indices, positions_np)
-        self.input_batch.block_table.commit_slot_mapping(
-            total_num_scheduled_tokens)
         tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
             tokens)
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
         # update total_num_scheduled_tokens
         total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)

         total_num_pcp_pads = sum(self.num_pcp_pads)
         max_num_scheduled_tokens = max(tokens)
@@ -4140,7 +4140,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
-        num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
         num_prefills = num_reqs - num_decodes
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
@@ -4248,9 +4247,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                         device=self.device,
                         dtype=self.dtype), 1)
             else:
-                max_seq_len = max(seq_lens, default=0)
                 pcp_prefill_mask = torch.triu(
-                    torch.full((num_prefills, max_seq_len, max_seq_len),
+                    torch.full((2048, 2048),
                                True,
                                device=self.device,
                                dtype=torch.bool), 1)