[PromptLogprobs][V1] Support prompt logprobs to fix ceval accuracy in V1 (#1483)
### What this PR does / why we need it? Support prompt logprobs in V1. This also enable lm_eval to test accuracy on V1 ### Does this PR introduce _any_ user-facing change? support prompt logprobs output ### How was this patch tested? CI passed with accuracy test. Using lm_eval, which use prompt logprobs as output to test accuracy, to test: ```python VLLM_USE_V1=1 lm_eval \ --model vllm \ --model_args pretrained=Qwen/Qwen2.5-7B-Instruct,max_model_len=4096,block_size=4 \ --tasks ceval-valid_computer_network \ --batch_size 8 ``` After this pr, the accuracy test results of `Qwen/Qwen2.5-7B-Instruct` on V1 is: ```bash | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr| |----------------------------|------:|------|-----:|--------|---|-----:|---|-----:| |ceval-valid_computer_network| 2|none | 0|acc |↑ |0.7368|± |0.1038| | | |none | 0|acc_norm|↑ |0.7368|± |0.1038| ``` Closes: https://github.com/vllm-project/vllm-ascend/issues/1043 Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -53,7 +53,8 @@ from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
|
|||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||||
|
ModelRunnerOutput)
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.sampler import Sampler
|
from vllm.v1.sample.sampler import Sampler
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
@@ -1506,6 +1507,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
logprobs_lists = logprobs_tensors.tolists() \
|
logprobs_lists = logprobs_tensors.tolists() \
|
||||||
if logprobs_tensors is not None else None
|
if logprobs_tensors is not None else None
|
||||||
|
|
||||||
|
# Compute prompt logprobs if needed.
|
||||||
|
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
||||||
|
hidden_states[:num_scheduled_tokens],
|
||||||
|
scheduler_output,
|
||||||
|
)
|
||||||
|
|
||||||
# Get the valid generated tokens.
|
# Get the valid generated tokens.
|
||||||
sampled_token_ids = sampler_output.sampled_token_ids
|
sampled_token_ids = sampler_output.sampled_token_ids
|
||||||
max_gen_len = sampled_token_ids.shape[-1]
|
max_gen_len = sampled_token_ids.shape[-1]
|
||||||
@@ -1540,7 +1547,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
sampled_token_ids=valid_sampled_token_ids,
|
sampled_token_ids=valid_sampled_token_ids,
|
||||||
spec_token_ids=spec_token_ids,
|
spec_token_ids=spec_token_ids,
|
||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
@@ -1549,7 +1556,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
sampled_token_ids=valid_sampled_token_ids,
|
sampled_token_ids=valid_sampled_token_ids,
|
||||||
spec_token_ids=spec_token_ids,
|
spec_token_ids=spec_token_ids,
|
||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2149,6 +2156,101 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
spec_token_ids = draft_token_ids.tolist()
|
spec_token_ids = draft_token_ids.tolist()
|
||||||
return spec_token_ids
|
return spec_token_ids
|
||||||
|
|
||||||
|
def _get_prompt_logprobs_dict(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> dict[str, Optional[LogprobsTensors]]:
|
||||||
|
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
|
||||||
|
if not num_prompt_logprobs_dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
|
||||||
|
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
|
||||||
|
|
||||||
|
# Since prompt logprobs are a rare feature, prioritize simple,
|
||||||
|
# maintainable loop over optimal performance.
|
||||||
|
completed_prefill_reqs = []
|
||||||
|
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
|
||||||
|
|
||||||
|
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||||
|
|
||||||
|
# Get metadata for this request.
|
||||||
|
request = self.requests[req_id]
|
||||||
|
num_prompt_tokens = len(request.prompt_token_ids)
|
||||||
|
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
|
||||||
|
self.device, non_blocking=True)
|
||||||
|
|
||||||
|
# Set up target LogprobsTensors object.
|
||||||
|
logprobs_tensors = in_progress_dict.get(req_id)
|
||||||
|
if not logprobs_tensors:
|
||||||
|
# Create empty logprobs CPU tensors for the entire prompt.
|
||||||
|
# If chunked, we'll copy in slice by slice.
|
||||||
|
logprobs_tensors = LogprobsTensors.empty_cpu(
|
||||||
|
num_prompt_tokens - 1, num_prompt_logprobs + 1)
|
||||||
|
in_progress_dict[req_id] = logprobs_tensors
|
||||||
|
|
||||||
|
# Determine number of logits to retrieve.
|
||||||
|
start_idx = request.num_computed_tokens
|
||||||
|
start_tok = start_idx + 1
|
||||||
|
num_remaining_tokens = num_prompt_tokens - start_tok
|
||||||
|
if num_tokens <= num_remaining_tokens:
|
||||||
|
# This is a chunk, more tokens remain.
|
||||||
|
# In the == case, there are no more prompt logprobs to produce
|
||||||
|
# but we want to defer returning them to the next step where we
|
||||||
|
# have new generated tokens to return.
|
||||||
|
num_logits = num_tokens
|
||||||
|
else:
|
||||||
|
# This is the last chunk of prompt tokens to return.
|
||||||
|
num_logits = num_remaining_tokens
|
||||||
|
completed_prefill_reqs.append(req_id)
|
||||||
|
prompt_logprobs_dict[req_id] = logprobs_tensors
|
||||||
|
|
||||||
|
if num_logits <= 0:
|
||||||
|
# This can happen for the final chunk if we prefilled exactly
|
||||||
|
# (num_prompt_tokens - 1) tokens for this request in the prior
|
||||||
|
# step. There are no more prompt logprobs to produce.
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get the logits corresponding to this req's prompt tokens.
|
||||||
|
# If this is a partial request (i.e. chunked prefill),
|
||||||
|
# then there is prompt logprob generated for each index.
|
||||||
|
req_idx = self.input_batch.req_id_to_index[req_id]
|
||||||
|
offset = self.query_start_loc_np[req_idx].item()
|
||||||
|
prompt_hidden_states = hidden_states[offset:offset + num_logits]
|
||||||
|
logits = self.model.compute_logits(prompt_hidden_states, None)
|
||||||
|
|
||||||
|
# Get the "target" tokens for each index. For prompt at index i,
|
||||||
|
# the token at prompt index i+1 is the "sampled" token we want
|
||||||
|
# to gather the logprob for.
|
||||||
|
tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
|
||||||
|
|
||||||
|
# Compute prompt logprobs.
|
||||||
|
logprobs = self.sampler.compute_logprobs(logits)
|
||||||
|
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
|
||||||
|
logprobs, num_prompt_logprobs, tgt_token_ids)
|
||||||
|
|
||||||
|
# Transfer NPU->CPU async.
|
||||||
|
chunk_slice = slice(start_idx, start_idx + num_logits)
|
||||||
|
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
|
||||||
|
token_ids, non_blocking=True)
|
||||||
|
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
|
||||||
|
non_blocking=True)
|
||||||
|
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
|
||||||
|
ranks, non_blocking=True)
|
||||||
|
|
||||||
|
# Remove requests that have completed prefill from the batch
|
||||||
|
# num_prompt_logprobs_dict.
|
||||||
|
for req_id in completed_prefill_reqs:
|
||||||
|
del num_prompt_logprobs_dict[req_id]
|
||||||
|
del in_progress_dict[req_id]
|
||||||
|
|
||||||
|
# Must synchronize the non-blocking NPU->CPU transfers.
|
||||||
|
if prompt_logprobs_dict:
|
||||||
|
torch.npu.synchronize()
|
||||||
|
|
||||||
|
return prompt_logprobs_dict
|
||||||
|
|
||||||
def init_torchair_graph_batch_sizes(self):
|
def init_torchair_graph_batch_sizes(self):
|
||||||
start_graph_batch_size = 4
|
start_graph_batch_size = 4
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|||||||
Reference in New Issue
Block a user