[bugfix] Fixed an accuracy problem of gdn layer in graph (#6822)
### What this PR does / why we need it?
There will be random outputs if we run a model with GDN attention in graph
mode:
```python
prompts = [
"1. Who are you?",
]
sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40, max_tokens=32)
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, top_k=40, max_tokens=5)
llm = LLM(model="/home/model/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
distributed_executor_backend="mp",
gpu_memory_utilization=0.7,
speculative_config={
"method": "qwen3_next_mtp",
"num_speculative_tokens": 3,
},
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_capture_sizes": [8],
},
max_model_len=4096,
enable_prefix_caching=False)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"{output.prompt_token_ids=}")
print(f"{output.outputs[0].token_ids=}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
Before applying this change, the output was:
```text
output.prompt_token_ids=[16, 13, 10479, 525, 498, 30]
output.outputs[0].token_ids=[3555, 323, 279, 1112, 279]
Prompt: '1. Who are you?', Generated text: ' What and the... the'
```
After applying this change, the output is:
```text
output.prompt_token_ids=[16, 13, 10479, 525, 498, 30]
output.outputs[0].token_ids=[3555, 374, 697, 829, 30]
Prompt: '1. Who are you?', Generated text: ' What is your name?'
```
**Why does this change solve the problem?**
Now, `query_start_loc` is padded because of `fia`.
But, for `gdn-attention`, padded version of `query_start_loc` will cause
accuracy problem.
So, we need an unpadded version of `query_start_loc`, named
`gdn_query_start_loc`, and use it in `gdn-attention`; with this change it works correctly.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
As described above.
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -112,6 +112,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.utils import (
|
||||
check_gdn_layer,
|
||||
enable_flash_comm_v1,
|
||||
enable_sp,
|
||||
is_drafter_moe_model,
|
||||
@@ -229,6 +230,17 @@ class NPUModelRunner(GPUModelRunner):
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# Now, query_start_loc is padded.
|
||||
# But gdn needs an unpadded one.
|
||||
# gdn_query_start_loc is an unpadded version of query_start_loc.
|
||||
# TODO delete it if fia's check is removed.
|
||||
self._has_gdn = check_gdn_layer(vllm_config)
|
||||
if self._has_gdn:
|
||||
self.gdn_query_start_loc = self._make_buffer(
|
||||
self.max_num_reqs + 1, # type: ignore[has-type]
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
vllm_config.scheduler_config.max_num_batched_tokens -= max_pcp_pad_tokens
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
@@ -677,6 +689,16 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
|
||||
self.query_start_loc.copy_to_gpu()
|
||||
|
||||
# Now, query_start_loc is padded.
|
||||
# But gdn needs an unpadded one.
|
||||
# gdn_query_start_loc is an unpadded version of query_start_loc.
|
||||
# TODO delete it if fia's check is removed.
|
||||
if self._has_gdn:
|
||||
self.gdn_query_start_loc.np[0] = 0
|
||||
self.gdn_query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
|
||||
self.gdn_query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
|
||||
self.gdn_query_start_loc.copy_to_gpu()
|
||||
|
||||
self.seq_lens.np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
|
||||
self.seq_lens.copy_to_gpu()
|
||||
|
||||
@@ -2019,6 +2041,18 @@ class NPUModelRunner(GPUModelRunner):
|
||||
kv_cache_group.kv_cache_spec,
|
||||
num_reqs_padded,
|
||||
)
|
||||
|
||||
# Now, query_start_loc is padded.
|
||||
# But gdn needs an unpadded one.
|
||||
# gdn_query_start_loc is an unpadded version of query_start_loc.
|
||||
# TODO delete it if fia's check is removed.
|
||||
if self._has_gdn:
|
||||
attn_group = self.attn_groups[kv_cache_gid][0]
|
||||
builder = attn_group.get_metadata_builder(0)
|
||||
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
|
||||
cm.query_start_loc_cpu = self.gdn_query_start_loc.cpu[: num_reqs_padded + 1]
|
||||
cm.query_start_loc = self.gdn_query_start_loc.gpu[: num_reqs_padded + 1]
|
||||
|
||||
if kv_cache_gid > 0:
|
||||
cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
|
||||
Reference in New Issue
Block a user