From 5f4391652f4a62d791fa4b9dfa3fc9d802d5a250 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Sat, 28 Jun 2025 09:38:52 +0800 Subject: [PATCH] [PromptLogprobs][V1] Support prompt logprobs to fix ceval accuracy in V1 (#1483) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? Support prompt logprobs in V1. This also enable lm_eval to test accuracy on V1 ### Does this PR introduce _any_ user-facing change? support prompt logprobs output ### How was this patch tested? CI passed with accuracy test. Using lm_eval, which use prompt logprobs as output to test accuracy, to test: ```python VLLM_USE_V1=1 lm_eval \ --model vllm \ --model_args pretrained=Qwen/Qwen2.5-7B-Instruct,max_model_len=4096,block_size=4 \ --tasks ceval-valid_computer_network \ --batch_size 8 ``` After this pr, the accuracy test results of `Qwen/Qwen2.5-7B-Instruct` on V1 is: ```bash | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr| |----------------------------|------:|------|-----:|--------|---|-----:|---|-----:| |ceval-valid_computer_network| 2|none | 0|acc |↑ |0.7368|± |0.1038| | | |none | 0|acc_norm|↑ |0.7368|± |0.1038| ``` Closes: https://github.com/vllm-project/vllm-ascend/issues/1043 Signed-off-by: MengqingCao --- vllm_ascend/worker/model_runner_v1.py | 108 +++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 5a95147..f919595 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -53,7 +53,8 @@ from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, + ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -1506,6 +1507,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): logprobs_lists = logprobs_tensors.tolists() \ if logprobs_tensors is not None else None + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output, + ) + # Get the valid generated tokens. sampled_token_ids = sampler_output.sampled_token_ids max_gen_len = sampled_token_ids.shape[-1] @@ -1540,7 +1547,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): sampled_token_ids=valid_sampled_token_ids, spec_token_ids=spec_token_ids, logprobs=logprobs_lists, - prompt_logprobs_dict={}, + prompt_logprobs_dict=prompt_logprobs_dict, ) else: model_runner_output = ModelRunnerOutput( @@ -1549,7 +1556,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): sampled_token_ids=valid_sampled_token_ids, spec_token_ids=spec_token_ids, logprobs=logprobs_lists, - prompt_logprobs_dict={}, + prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], ) @@ -2149,6 +2156,101 @@ class NPUModelRunner(LoRAModelRunnerMixin): spec_token_ids = draft_token_ids.tolist() return spec_token_ids + def _get_prompt_logprobs_dict( + self, + hidden_states: torch.Tensor, + scheduler_output: "SchedulerOutput", + ) -> dict[str, Optional[LogprobsTensors]]: + num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + + # Get metadata for this request. + request = self.requests[req_id] + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True) + + # Set up target LogprobsTensors object. + logprobs_tensors = in_progress_dict.get(req_id) + if not logprobs_tensors: + # Create empty logprobs CPU tensors for the entire prompt. + # If chunked, we'll copy in slice by slice. + logprobs_tensors = LogprobsTensors.empty_cpu( + num_prompt_tokens - 1, num_prompt_logprobs + 1) + in_progress_dict[req_id] = logprobs_tensors + + # Determine number of logits to retrieve. + start_idx = request.num_computed_tokens + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + prompt_logprobs_dict[req_id] = logprobs_tensors + + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc_np[req_idx].item() + prompt_hidden_states = hidden_states[offset:offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states, None) + + # Get the "target" tokens for each index. For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids) + + # Transfer NPU->CPU async. + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_( + token_ids, non_blocking=True) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, + non_blocking=True) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_( + ranks, non_blocking=True) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + del in_progress_dict[req_id] + + # Must synchronize the non-blocking NPU->CPU transfers. + if prompt_logprobs_dict: + torch.npu.synchronize() + + return prompt_logprobs_dict + def init_torchair_graph_batch_sizes(self): start_graph_batch_size = 4 tp_size = get_tensor_model_parallel_world_size()