[PromptLogprobs][V1] Support prompt logprobs to fix ceval accuracy in V1 (#1483)
### What this PR does / why we need it?
Support prompt logprobs in V1. This also enables lm_eval to test accuracy on V1.

### Does this PR introduce _any_ user-facing change?
Support prompt logprobs output.

### How was this patch tested?
CI passed with the accuracy test. Tested with lm_eval, which uses prompt logprobs as output to measure accuracy:

```bash
VLLM_USE_V1=1 lm_eval \
  --model vllm \
  --model_args pretrained=Qwen/Qwen2.5-7B-Instruct,max_model_len=4096,block_size=4 \
  --tasks ceval-valid_computer_network \
  --batch_size 8
```

After this PR, the accuracy test results of `Qwen/Qwen2.5-7B-Instruct` on V1 are:

```bash
|            Tasks            |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-----------------------------|------:|------|-----:|--------|---|-----:|---|-----:|
|ceval-valid_computer_network |      2|none  |     0|acc     |↑  |0.7368|±  |0.1038|
|                             |       |none  |     0|acc_norm|↑  |0.7368|±  |0.1038|
```

Closes: https://github.com/vllm-project/vllm-ascend/issues/1043

Signed-off-by: MengqingCao <cmq0113@163.com>
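Prompt logprobs are the log-probabilities the model assigns to the prompt tokens themselves, which is what lm_eval's loglikelihood-scored tasks such as ceval consume. For context, here is a minimal sketch of requesting them through vLLM's offline API; the model name and the `prompt_logprobs=1` value are only illustrative:

```python
import os

# The fix in this PR targets the V1 engine path.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", max_model_len=4096)

# prompt_logprobs=1 requests the top-1 logprob at every prompt position;
# the logprob of the actual prompt token is always included as well.
params = SamplingParams(max_tokens=1, prompt_logprobs=1)

outputs = llm.generate(["The capital of France is Paris."], params)
for out in outputs:
    # The first entry is None: the first prompt token has no preceding context.
    for pos, lp in enumerate(out.prompt_logprobs or []):
        if lp is not None:
            print(pos, {tok_id: round(l.logprob, 3) for tok_id, l in lp.items()})
```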
```diff
@@ -53,7 +53,8 @@ from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
```
```diff
@@ -1506,6 +1507,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         logprobs_lists = logprobs_tensors.tolists() \
             if logprobs_tensors is not None else None
 
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output,
+        )
+
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
```
```diff
@@ -1540,7 +1547,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 sampled_token_ids=valid_sampled_token_ids,
                 spec_token_ids=spec_token_ids,
                 logprobs=logprobs_lists,
-                prompt_logprobs_dict={},
+                prompt_logprobs_dict=prompt_logprobs_dict,
             )
         else:
             model_runner_output = ModelRunnerOutput(
```
```diff
@@ -1549,7 +1556,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 sampled_token_ids=valid_sampled_token_ids,
                 spec_token_ids=spec_token_ids,
                 logprobs=logprobs_lists,
-                prompt_logprobs_dict={},
+                prompt_logprobs_dict=prompt_logprobs_dict,
                 pooler_output=[],
             )
 
```
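With this change, `prompt_logprobs_dict` maps each request id to the `LogprobsTensors` buffers that the new `_get_prompt_logprobs_dict` method (next hunk) fills in slice by slice and that `ModelRunnerOutput` carries back to the engine. As a rough orientation only, a simplified mirror of that layout could look like the sketch below; the field names follow the attributes used in this diff, but the exact vLLM definition may differ:

```python
from dataclasses import dataclass
import torch


@dataclass
class PromptLogprobsBuffers:
    """Illustrative stand-in for vllm.v1.outputs.LogprobsTensors.

    For a prompt of P tokens and num_prompt_logprobs = k, each buffer has
    P - 1 rows (no logprobs for the first token) and k + 1 columns
    (the actual next prompt token plus the top-k alternatives).
    """
    logprob_token_ids: torch.Tensor     # int64, shape [P - 1, k + 1]
    logprobs: torch.Tensor              # float32, shape [P - 1, k + 1]
    selected_token_ranks: torch.Tensor  # int64, shape [P - 1]


def empty_cpu(num_positions: int, num_cols: int) -> PromptLogprobsBuffers:
    # Mirrors the LogprobsTensors.empty_cpu(num_prompt_tokens - 1,
    # num_prompt_logprobs + 1) call in the next hunk.
    return PromptLogprobsBuffers(
        logprob_token_ids=torch.empty(num_positions, num_cols, dtype=torch.int64),
        logprobs=torch.empty(num_positions, num_cols, dtype=torch.float32),
        selected_token_ranks=torch.empty(num_positions, dtype=torch.int64),
    )
```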
```diff
@@ -2149,6 +2156,101 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
+    def _get_prompt_logprobs_dict(
+        self,
+        hidden_states: torch.Tensor,
+        scheduler_output: "SchedulerOutput",
+    ) -> dict[str, Optional[LogprobsTensors]]:
+        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        if not num_prompt_logprobs_dict:
+            return {}
+
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
+        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
+
+        # Since prompt logprobs are a rare feature, prioritize simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
+
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            num_prompt_tokens = len(request.prompt_token_ids)
+            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
+                self.device, non_blocking=True)
+
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
+            # Determine number of logits to retrieve.
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
+
+            # Get the logits corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill),
+            # then there is prompt logprob generated for each index.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc_np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset:offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states, None)
+
+            # Get the "target" tokens for each index. For prompt at index i,
+            # the token at prompt index i+1 is the "sampled" token we want
+            # to gather the logprob for.
+            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
+
+            # Compute prompt logprobs.
+            logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                logprobs, num_prompt_logprobs, tgt_token_ids)
+
+            # Transfer NPU->CPU async.
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
+
+        # Remove requests that have completed prefill from the batch
+        # num_prompt_logprobs_dict.
+        for req_id in completed_prefill_reqs:
+            del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
+
+        # Must synchronize the non-blocking NPU->CPU transfers.
+        if prompt_logprobs_dict:
+            torch.npu.synchronize()
+
+        return prompt_logprobs_dict
+
     def init_torchair_graph_batch_sizes(self):
         start_graph_batch_size = 4
         tp_size = get_tensor_model_parallel_world_size()
```
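The method above reuses the existing V1 sampler helpers (`compute_logprobs`, `gather_logprobs`) and only adds the chunked-prefill bookkeeping and the asynchronous NPU->CPU copies around them. As a rough, device-agnostic sketch of what one chunk computes, the helper below is illustrative only (its name, shapes, and column layout are assumptions that follow the `num_prompt_logprobs + 1` width used above, not the actual sampler implementation):

```python
import torch


def prompt_logprobs_chunk(logits: torch.Tensor, tgt_token_ids: torch.Tensor,
                          num_prompt_logprobs: int):
    """Illustrative stand-in for compute_logprobs + gather_logprobs on one chunk.

    logits:        [num_logits, vocab_size] for prompt positions start_idx..start_idx+num_logits-1
    tgt_token_ids: [num_logits] int64, the "next" prompt token at each position
    Returns token ids and logprobs of shape [num_logits, k + 1] plus target-token ranks.
    """
    logprobs = torch.log_softmax(logits, dim=-1)

    # Logprob and rank of the actual next prompt token (rank 1 = most likely).
    tgt_logprobs = logprobs.gather(-1, tgt_token_ids.unsqueeze(-1))
    ranks = (logprobs > tgt_logprobs).sum(dim=-1) + 1

    # Top-k alternatives, with the target token prepended in column 0,
    # mirroring the (num_prompt_logprobs + 1) column width used above.
    topk_logprobs, topk_ids = logprobs.topk(num_prompt_logprobs, dim=-1)
    token_ids = torch.cat([tgt_token_ids.unsqueeze(-1), topk_ids], dim=-1)
    out_logprobs = torch.cat([tgt_logprobs, topk_logprobs], dim=-1)
    return token_ids, out_logprobs, ranks


# Example: 4 prompt positions over a toy vocabulary of 10 tokens, top-2 alternatives.
ids, lps, ranks = prompt_logprobs_chunk(torch.randn(4, 10),
                                        torch.tensor([3, 1, 7, 2]), 2)
```

Issuing every chunk's copies with `non_blocking=True` and calling `torch.npu.synchronize()` once at the end lets the NPU->CPU transfers for all requests in the batch overlap, instead of paying a device synchronization per request.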