[v0.11.0-dev][bugfix] Add branch for stream up-lifting in update_attn_params (#4437)
### What this PR does / why we need it? #3985 moved stream context initialization before for-loops to improve performance. However, we find that this might cause a potential accuracy drop when used with pd disaggregation. Thus we partly revert this change when using pd disaggregation, and we shall fix this bug in the future. ### Does this PR introduce _any_ user-facing change? No. --------- Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -189,11 +189,20 @@ class ACLGraphWrapper:
|
||||
return entry.output
|
||||
|
||||
|
||||
def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
def update_attn_params(update_stream,
|
||||
forward_context,
|
||||
runtime_shape,
|
||||
kv_transfer_config=None):
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
with torch.npu.stream(update_stream):
|
||||
|
||||
# NOTE(Angazenn): By moving the npu-stream context ahead,
|
||||
# (see https://github.com/vllm-project/vllm-ascend/pull/3985)
|
||||
# we can reduce host overhead introduced by stream initialization.
|
||||
# However, we find that this might cause potential accuracy problems
|
||||
# with pd-disaggregation. Therefore, this optimization is only enabled
|
||||
# without pd-disaggregation. We are working to solve this problem
|
||||
# directly in the future.
|
||||
if kv_transfer_config is not None:
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
@@ -215,10 +224,9 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
|
||||
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
|
||||
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
|
||||
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
|
||||
# might encounter a bigger workspace, while currently we use max_model_len to
|
||||
# calculate max workspace in capturing. So additional get_workspace is added
|
||||
# here to avoid such bugs.
|
||||
# in torch_npu. On some cases, _npu_paged_attention requires different workspace
|
||||
# among various seq_lens. So additional get_workspace is added here
|
||||
# to avoid such bugs.
|
||||
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
|
||||
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
@@ -231,20 +239,67 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
else:
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
|
||||
@@ -1598,7 +1598,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.speculative_config)
|
||||
else:
|
||||
update_attn_params(self.update_stream, forward_context,
|
||||
maybe_padded_num_tokens)
|
||||
maybe_padded_num_tokens,
|
||||
self.vllm_config.kv_transfer_config)
|
||||
|
||||
if get_forward_context().sp_enabled:
|
||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||
@@ -2359,7 +2360,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens, self.speculative_config)
|
||||
else:
|
||||
update_attn_params(self.update_stream, forward_context,
|
||||
num_tokens)
|
||||
num_tokens,
|
||||
self.vllm_config.kv_transfer_config)
|
||||
|
||||
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
||||
hidden_states, _ = hidden_states
|
||||
|
||||
Reference in New Issue
Block a user