import torch
from collections import defaultdict
from typing import Optional

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput

# Tracks, per request id, how many output tokens have been validated so far
# in zero-overhead mode.
requests_valid_token_len: dict[str, int] = {}


def check_stop(request: Request,
               max_model_len: int,
               pooler_output: Optional[torch.Tensor] = None,
               use_valid_token_len: bool = False) -> bool:
    if use_valid_token_len:
        if request.request_id not in requests_valid_token_len:
            # First time we see this request: nothing has been validated
            # yet, so it cannot have hit a stop condition.
            requests_valid_token_len[request.request_id] = 0
            return False
        valid_output_len = requests_valid_token_len[request.request_id]
    else:
        valid_output_len = request.num_output_tokens

    valid_num_tokens = request.num_prompt_tokens + valid_output_len
    if (valid_num_tokens >= max_model_len
            or valid_output_len >= request.max_tokens):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    if request.pooling_params:
        if pooler_output is not None:
            request.status = RequestStatus.FINISHED_STOPPED
            return True
        return False

    sampling_params = request.sampling_params
    assert sampling_params is not None
    last_token_id = request.output_token_ids[valid_output_len - 1]
    if (not sampling_params.ignore_eos
            and last_token_id == request.eos_token_id):
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    if last_token_id in (sampling_params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token_id
        return True

    return False


def zero_overhead_update_from_output(
        scheduler: Scheduler, scheduler_output: SchedulerOutput,
        model_runner_output: ZeroV1ModelRunnerOutput):
    global requests_valid_token_len
    sampled_token_ids = model_runner_output.sampled_token_ids
    spec_token_ids = model_runner_output.spec_token_ids
    logprobs = model_runner_output.logprobs
    prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens
    pooler_outputs = model_runner_output.pooler_output
    num_nans_in_logits = model_runner_output.num_nans_in_logits

    new_running: list[Request] = []
    outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
    spec_decoding_stats: Optional[SpecDecodingStats] = None

    # Fix the last model output in zero-overhead mode: overwrite the
    # previously appended tokens with the ones actually sampled.
    if model_runner_output.fix_req_ids is not None:
        for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
            if req_id not in scheduler.requests:
                continue
            request = scheduler.requests[req_id]
            generated_token_ids = model_runner_output.fix_sampled_token_ids[
                req_idx]
            if req_id not in requests_valid_token_len:
                requests_valid_token_len[req_id] = 0
            valid_output_len = requests_valid_token_len[req_id]
            # Negative offset from the end of the output list: the tail
            # beyond valid_output_len holds tokens that still need fixing.
            fix_offset = valid_output_len - request.num_output_tokens
            if isinstance(generated_token_ids, int):
                request._output_token_ids[fix_offset] = generated_token_ids
                request._all_token_ids[fix_offset] = generated_token_ids
                requests_valid_token_len[req_id] += 1
                generated_token_ids = [generated_token_ids]
            else:
                valid_output_end = (valid_output_len +
                                    len(generated_token_ids) -
                                    request.num_output_tokens)
                if valid_output_end == 0:
                    request._output_token_ids[fix_offset:] = \
                        generated_token_ids
                    request._all_token_ids[fix_offset:] = generated_token_ids
                else:
                    request._output_token_ids[
                        fix_offset:valid_output_end] = generated_token_ids
                    request._all_token_ids[
                        fix_offset:valid_output_end] = generated_token_ids
                requests_valid_token_len[req_id] += len(generated_token_ids)
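
            # From this point the request's recorded outputs match what the
            # model actually sampled, so the stop checks below use the
            # tracked valid length (use_valid_token_len=True) rather than
            # request.num_output_tokens, which may still count unfixed
            # tokens.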
            stopped = False
            new_logprobs = None
            new_token_ids = generated_token_ids
            kv_transfer_params = None

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
            for num_new, output_token_id in enumerate(new_token_ids, 1):
                stopped = check_stop(request,
                                     scheduler.max_model_len,
                                     use_valid_token_len=True)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)
                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                    break

            pooler_output = None
            if pooler_outputs:
                pooler_output = pooler_outputs[req_idx]
                stopped = check_stop(request,
                                     scheduler.max_model_len,
                                     pooler_output,
                                     use_valid_token_len=True)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)

            # Extract sample logprobs if needed.
            if (request.sampling_params is not None
                    and request.sampling_params.logprobs is not None
                    and logprobs):
                # NOTE: once we support N tokens per step (spec decode),
                # the outer lists can be of length > 1.
                new_logprobs = logprobs.slice(req_idx, req_idx + 1)

            if new_token_ids and scheduler.structured_output_manager.should_advance(
                    request):
                # NOTE: structured_output_request should not be None if
                # use_structured_output; we checked above, so it is safe
                # to ignore the type warning.
                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                    req_id, new_token_ids)

            if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                request.num_nans_in_logits = num_nans_in_logits[req_id]

            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
            if (new_token_ids or pooler_output is not None
                    or kv_transfer_params):
                # Add EngineCoreOutput for this Request.
                outputs[request.client_index].append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_logprobs=new_logprobs,
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        pooling_output=pooler_output,
                        stop_reason=request.stop_reason,
                        events=request.take_events(),
                        kv_transfer_params=kv_transfer_params,
                        num_cached_tokens=request.num_cached_tokens,
                    ))
            else:
                # Invariant: EngineCore returns no partial prefill outputs.
                assert not prompt_logprobs_tensors

    # Fix the last draft (spec decode) output in zero-overhead mode.
    if model_runner_output.fix_draft_req_ids is not None:
        for req_idx, req_id in enumerate(
                model_runner_output.fix_draft_req_ids):
            if req_id not in scheduler.requests:
                continue
            request = scheduler.requests[req_id]
            # Add newly generated spec token ids to the request.
            if model_runner_output.fix_draft_tokens_ids is not None:
                if scheduler.structured_output_manager.should_advance(
                        request):
                    metadata = request.structured_output_request
                    # Needs to happen after new_token_ids are accepted.
                    request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                        model_runner_output.fix_draft_tokens_ids[req_idx])
                else:
                    request.spec_token_ids = \
                        model_runner_output.fix_draft_tokens_ids[req_idx]

    # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
    # loop can be a performance bottleneck. We should do our best to avoid
    # expensive operations inside the loop.
    for request in scheduler.running:
        req_id = request.request_id
        if request.is_finished():
            # Drop the valid-length bookkeeping for finished requests.
            if req_id in requests_valid_token_len:
                requests_valid_token_len.pop(req_id)
            continue

        num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
        if num_tokens_scheduled == 0:
            # The request was not scheduled in this step.
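            # Leave it in the running queue untouched; it contributes no
            # output to this step's EngineCoreOutputs.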
            new_running.append(request)
            continue

        req_index = model_runner_output.req_id_to_index[req_id]
        generated_token_ids = (sampled_token_ids[req_index]
                               if sampled_token_ids else [])

        scheduled_spec_token_ids = (
            scheduler_output.scheduled_spec_decode_tokens.get(req_id))
        if scheduled_spec_token_ids:
            # num_computed_tokens represents the number of tokens
            # processed in the current step, considering scheduled
            # tokens and rejections. If some tokens are rejected,
            # num_computed_tokens is decreased by the number of rejected
            # tokens, which is given by:
            # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
            num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                                   len(generated_token_ids))
            request.num_computed_tokens -= num_tokens_rejected
            spec_decoding_stats = scheduler.make_spec_decoding_stats(
                spec_decoding_stats,
                num_draft_tokens=len(scheduled_spec_token_ids),
                num_accepted_tokens=len(generated_token_ids) - 1)

        # NOTE(woosuk): This has to be executed after updating
        # `request.num_computed_tokens`.
        if request.has_encoder_inputs:
            scheduler._free_encoder_inputs(request)

        stopped = False
        new_logprobs = None
        new_token_ids = generated_token_ids
        kv_transfer_params = None

        # Append generated tokens and check for stop. Note that if
        # a request is still being prefilled, we expect the model runner
        # to return empty token ids for the request.
        for num_new, output_token_id in enumerate(new_token_ids, 1):
            request.append_output_token_ids(output_token_id)

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
            if model_runner_output.is_output_valid:
                stopped = check_stop(request, scheduler.max_model_len)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)
                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                    break

        pooler_output = None
        if pooler_outputs:
            if model_runner_output.is_output_valid:
                pooler_output = pooler_outputs[req_index]
                stopped = check_stop(request, scheduler.max_model_len,
                                     pooler_output)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)

        # Extract sample logprobs if needed.
        if (request.sampling_params is not None
                and request.sampling_params.logprobs is not None
                and logprobs):
            # NOTE: once we support N tokens per step (spec decode),
            # the outer lists can be of length > 1.
            new_logprobs = logprobs.slice(req_index, req_index + 1)

        if new_token_ids and scheduler.structured_output_manager.should_advance(
                request):
            # NOTE: structured_output_request should not be None if
            # use_structured_output; we checked above, so it is safe
            # to ignore the type warning.
            request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                req_id, new_token_ids)

        if num_nans_in_logits is not None and req_id in num_nans_in_logits:
            request.num_nans_in_logits = num_nans_in_logits[req_id]

        # Add newly generated spec token ids (from the model runner output)
        # to the request.
        if spec_token_ids is not None:
            if scheduler.structured_output_manager.should_advance(request):
                metadata = request.structured_output_request
                # Needs to happen after new_token_ids are accepted.
                request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                    spec_token_ids[req_index])
            else:
                request.spec_token_ids = spec_token_ids[req_index]

        if model_runner_output.is_output_valid:
            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
            if (new_token_ids or pooler_output is not None
                    or kv_transfer_params):
                # Add EngineCoreOutput for this Request.
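                # Only steps whose output is valid emit results here;
                # invalid (overlapped) steps are reconciled by the fix-up
                # passes above.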
                outputs[request.client_index].append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_logprobs=new_logprobs,
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        pooling_output=pooler_output,
                        stop_reason=request.stop_reason,
                        events=request.take_events(),
                        kv_transfer_params=kv_transfer_params,
                        num_cached_tokens=request.num_cached_tokens,
                    ))

        if stopped:
            # Finished: drop the valid-length bookkeeping for this request.
            if req_id in requests_valid_token_len:
                requests_valid_token_len.pop(req_id)
        else:
            new_running.append(request)

    scheduler.running = new_running

    # KV Connector: update state for finished KV Transfers.
    scheduler._update_from_kv_xfer_finished(model_runner_output)

    # Create EngineCoreOutputs for all clients that have requests with
    # outputs in this step.
    engine_core_outputs = {
        client_index: EngineCoreOutputs(outputs=outs)
        for client_index, outs in outputs.items()
    }

    finished_req_ids = scheduler.finished_req_ids_dict
    if finished_req_ids:
        # Include ids of requests that finished since last outputs
        # were sent.
        for client_index, finished_set in finished_req_ids.items():
            # Set finished request set in EngineCoreOutputs for this client.
            if (eco := engine_core_outputs.get(client_index)) is not None:
                eco.finished_requests = finished_set
            else:
                engine_core_outputs[client_index] = EngineCoreOutputs(
                    finished_requests=finished_set)
        finished_req_ids.clear()

    if engine_core_outputs:
        # Return stats to only one of the front-ends.
        next(iter(engine_core_outputs.values())).scheduler_stats = (
            scheduler.make_stats(spec_decoding_stats))

    return engine_core_outputs


def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Schedule, execute, and make output.

    Returns a tuple of outputs and a flag indicating whether the model
    was executed.
    """

    # Check for any requests remaining in the scheduler - unfinished,
    # or finished and not yet removed from the batch.
    if not core.scheduler.has_requests():
        return {}, False
    scheduler_output = core.scheduler.schedule()
    model_output = core.execute_model(scheduler_output)
    if isinstance(model_output, ZeroV1ModelRunnerOutput):
        engine_core_outputs = zero_overhead_update_from_output(
            core.scheduler, scheduler_output, model_output)  # type: ignore
    else:
        engine_core_outputs = core.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

    return (engine_core_outputs,
            scheduler_output.total_num_scheduled_tokens > 0)
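
# A minimal sketch of how engine_core_step might be driven, assuming a `core`
# object that exposes `scheduler` and `execute_model` (as used above).
# `handle_outputs` is a hypothetical callback for routing per-client outputs
# back to the owning front-end; it is not part of vLLM's API.
def _example_busy_loop(core, handle_outputs) -> None:
    while True:
        # Each step schedules, executes, and post-processes one iteration.
        engine_core_outputs, model_executed = engine_core_step(core)
        for client_index, outs in engine_core_outputs.items():
            # Each client index identifies the front-end owning the requests.
            handle_outputs(client_index, outs)
        if not model_executed:
            # No tokens were scheduled this step (e.g. no requests remain).
            break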