[v0.11.0-dev][bugfix] Fix a bug in wrongly set npu_stream (#4106)
### What this PR does / why we need it? This pr fixes a bug introduced in #3985, which set wrong npu_stream (possibly by mistakes in cherry-pick). I correct it and make `update_attn_params` consistent to main branch. ### Does this PR introduce _any_ user-facing change? No. Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -213,26 +213,24 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
|||||||
) = param
|
) = param
|
||||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||||
|
|
||||||
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
|
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
|
||||||
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
|
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
|
||||||
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
|
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
|
||||||
# might encounter a bigger workspace, while currently we use max_model_len to
|
# might encounter a bigger workspace, while currently we use max_model_len to
|
||||||
# calculate max workspace in capturing. So additional get_workspace is added
|
# calculate max workspace in capturing. So additional get_workspace is added
|
||||||
# here to avoid such bugs.
|
# here to avoid such bugs.
|
||||||
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
|
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
|
||||||
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
|
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
|
||||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||||
query=query,
|
query=query,
|
||||||
key_cache=key_cache,
|
key_cache=key_cache,
|
||||||
value_cache=value_cache,
|
value_cache=value_cache,
|
||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
scale_value=scale,
|
scale_value=scale,
|
||||||
block_table=block_table,
|
block_table=block_table,
|
||||||
context_lens=seq_lens,
|
context_lens=seq_lens,
|
||||||
out=output)
|
out=output)
|
||||||
|
|
||||||
with torch.npu.stream(update_stream):
|
|
||||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||||
torch_npu._npu_paged_attention(query=query,
|
torch_npu._npu_paged_attention(query=query,
|
||||||
key_cache=key_cache,
|
key_cache=key_cache,
|
||||||
|
|||||||
Reference in New Issue
Block a user