forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
359
vllm/zero_overhead/v1/core.py
Normal file
@@ -0,0 +1,359 @@
import torch

from collections import defaultdict
from typing import Optional

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput

# Tracks, per request id, how many output tokens have been validated on the
# host so far. In zero-overhead mode the GPU can run one step ahead of the
# CPU, so this counter may lag request.num_output_tokens.
requests_valid_token_len = {}


def check_stop(request: Request,
               max_model_len: int,
               pooler_output: Optional[torch.Tensor] = None,
               use_valid_token_len: bool = False) -> bool:
    if use_valid_token_len:
        if request.request_id not in requests_valid_token_len:
            requests_valid_token_len[request.request_id] = 0
            return False
        valid_output_len = requests_valid_token_len[request.request_id]
    else:
        valid_output_len = request.num_output_tokens
    valid_num_tokens = request.num_prompt_tokens + valid_output_len
    if (valid_num_tokens >= max_model_len
            or valid_output_len >= request.max_tokens):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    if request.pooling_params:
        if pooler_output is not None:
            request.status = RequestStatus.FINISHED_STOPPED
            return True
        return False

    sampling_params = request.sampling_params
    assert sampling_params is not None
    last_token_id = request.output_token_ids[valid_output_len - 1]
    if (not sampling_params.ignore_eos
            and last_token_id == request.eos_token_id):
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    if last_token_id in (sampling_params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token_id
        return True
    return False


def zero_overhead_update_from_output(
        scheduler: Scheduler,
        scheduler_output: SchedulerOutput,
        model_runner_output: ZeroV1ModelRunnerOutput
) -> dict[int, EngineCoreOutputs]:
    global requests_valid_token_len
    sampled_token_ids = model_runner_output.sampled_token_ids
    spec_token_ids = model_runner_output.spec_token_ids
    logprobs = model_runner_output.logprobs
    prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens
    pooler_outputs = model_runner_output.pooler_output
    num_nans_in_logits = model_runner_output.num_nans_in_logits

    new_running: list[Request] = []
    outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
    spec_decoding_stats: Optional[SpecDecodingStats] = None

    # Patch in the sampled tokens from the previous (deferred) model output.
    if model_runner_output.fix_req_ids is not None:
        for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
            if req_id not in scheduler.requests:
                continue
            request = scheduler.requests[req_id]
            generated_token_ids = model_runner_output.fix_sampled_token_ids[
                req_idx]
            if req_id not in requests_valid_token_len:
                requests_valid_token_len[req_id] = 0
            valid_output_len = requests_valid_token_len[req_id]
            fix_offset = valid_output_len - request.num_output_tokens
            if isinstance(generated_token_ids, int):
                request._output_token_ids[fix_offset] = generated_token_ids
                request._all_token_ids[fix_offset] = generated_token_ids
                requests_valid_token_len[req_id] += 1
                generated_token_ids = [generated_token_ids]
            else:
                valid_output_end = (valid_output_len +
                                    len(generated_token_ids) -
                                    request.num_output_tokens)
                if valid_output_end == 0:
                    request._output_token_ids[fix_offset:] = \
                        generated_token_ids
                    request._all_token_ids[fix_offset:] = generated_token_ids
                else:
                    request._output_token_ids[fix_offset:valid_output_end] = \
                        generated_token_ids
                    request._all_token_ids[fix_offset:valid_output_end] = \
                        generated_token_ids
                requests_valid_token_len[req_id] += len(generated_token_ids)

            stopped = False
            new_logprobs = None
            new_token_ids = generated_token_ids
            kv_transfer_params = None

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
            for num_new, output_token_id in enumerate(new_token_ids, 1):
                stopped = check_stop(request, scheduler.max_model_len,
                                     use_valid_token_len=True)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)
                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                    break

            pooler_output = None
            if pooler_outputs:
                pooler_output = pooler_outputs[req_idx]
                stopped = check_stop(request, scheduler.max_model_len,
                                     pooler_output,
                                     use_valid_token_len=True)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)

            # Extract sample logprobs if needed.
            if request.sampling_params is not None \
                    and request.sampling_params.logprobs is not None \
                    and logprobs:
                # NOTE: once we support N tokens per step (spec decode),
                # the outer lists can be of length > 1.
                new_logprobs = logprobs.slice(req_idx, req_idx + 1)

            if new_token_ids and \
                    scheduler.structured_output_manager.should_advance(
                        request):
                # NOTE: structured_output_request should not be None if
                # use_structured_output; we have checked above, so it is
                # safe to ignore the type warning.
                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                    req_id, new_token_ids)

            # spec_token_ids comes from the model runner output.
            if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                request.num_nans_in_logits = num_nans_in_logits[req_id]

            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
            if new_token_ids or pooler_output is not None \
                    or kv_transfer_params:

                # Add EngineCoreOutput for this Request.
                outputs[request.client_index].append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_logprobs=new_logprobs,
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        pooling_output=pooler_output,
                        stop_reason=request.stop_reason,
                        events=request.take_events(),
                        kv_transfer_params=kv_transfer_params,
                        num_cached_tokens=request.num_cached_tokens,
                    ))

            else:
                # Invariant: EngineCore returns no partial prefill outputs.
                assert not prompt_logprobs_tensors

    # Patch in the draft tokens from the previous (deferred) model output.
    if model_runner_output.fix_draft_req_ids is not None:
        for req_idx, req_id in enumerate(
                model_runner_output.fix_draft_req_ids):
            if req_id not in scheduler.requests:
                continue
            request = scheduler.requests[req_id]

            # Add newly generated spec token ids to the request.
            if model_runner_output.fix_draft_tokens_ids is not None:
                if scheduler.structured_output_manager.should_advance(request):
                    metadata = request.structured_output_request
                    # Needs to happen after new_token_ids are accepted.
                    request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                        model_runner_output.fix_draft_tokens_ids[req_idx])
                else:
                    request.spec_token_ids = \
                        model_runner_output.fix_draft_tokens_ids[req_idx]

    # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
    # loop can be a performance bottleneck. We should do our best to avoid
    # expensive operations inside the loop.
    for request in scheduler.running:
        req_id = request.request_id
        if request.is_finished():
            if req_id in requests_valid_token_len:
                requests_valid_token_len.pop(req_id)
            continue
        num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
        if num_tokens_scheduled == 0:
            # The request was not scheduled in this step.
            new_running.append(request)
            continue

        req_index = model_runner_output.req_id_to_index[req_id]
        generated_token_ids = sampled_token_ids[
            req_index] if sampled_token_ids else []

        scheduled_spec_token_ids = (
            scheduler_output.scheduled_spec_decode_tokens.get(req_id))
        if scheduled_spec_token_ids:
            # num_computed_tokens represents the number of tokens
            # processed in the current step, considering scheduled
            # tokens and rejections. If some tokens are rejected,
            # num_computed_tokens is decreased by the number of rejected
            # tokens, which is given by:
            # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
            num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                                   len(generated_token_ids))
            request.num_computed_tokens -= num_tokens_rejected
            spec_decoding_stats = scheduler.make_spec_decoding_stats(
                spec_decoding_stats,
                num_draft_tokens=len(scheduled_spec_token_ids),
                num_accepted_tokens=len(generated_token_ids) - 1)

        # NOTE(woosuk): This has to be executed after updating
        # `request.num_computed_tokens`.
        if request.has_encoder_inputs:
            scheduler._free_encoder_inputs(request)

        stopped = False
        new_logprobs = None
        new_token_ids = generated_token_ids
        kv_transfer_params = None

        # Append generated tokens and check for stop. Note that if
        # a request is still being prefilled, we expect the model runner
        # to return empty token ids for the request.
        for num_new, output_token_id in enumerate(new_token_ids, 1):
            request.append_output_token_ids(output_token_id)

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
            if model_runner_output.is_output_valid:
                stopped = check_stop(request, scheduler.max_model_len,
                                     use_valid_token_len=False)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)
                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                    break

        pooler_output = None
        if pooler_outputs:
            if model_runner_output.is_output_valid:
                pooler_output = pooler_outputs[req_index]
                stopped = check_stop(request, scheduler.max_model_len,
                                     pooler_output,
                                     use_valid_token_len=False)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)

        # Extract sample logprobs if needed.
        if request.sampling_params is not None \
                and request.sampling_params.logprobs is not None and logprobs:
            # NOTE: once we support N tokens per step (spec decode),
            # the outer lists can be of length > 1.
            new_logprobs = logprobs.slice(req_index, req_index + 1)

        if new_token_ids and \
                scheduler.structured_output_manager.should_advance(request):
            # NOTE: structured_output_request should not be None if
            # use_structured_output; we have checked above, so it is
            # safe to ignore the type warning.
            request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                req_id, new_token_ids)

        # spec_token_ids comes from the model runner output.
        if num_nans_in_logits is not None and req_id in num_nans_in_logits:
            request.num_nans_in_logits = num_nans_in_logits[req_id]

        # Add newly generated spec token ids to the request.
        if spec_token_ids is not None:
            if scheduler.structured_output_manager.should_advance(request):
                metadata = request.structured_output_request
                # Needs to happen after new_token_ids are accepted.
                request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                    spec_token_ids[req_index])
            else:
                request.spec_token_ids = spec_token_ids[req_index]

        if model_runner_output.is_output_valid:
            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
            if new_token_ids or pooler_output is not None \
                    or kv_transfer_params:

                # Add EngineCoreOutput for this Request.
                outputs[request.client_index].append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_logprobs=new_logprobs,
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        pooling_output=pooler_output,
                        stop_reason=request.stop_reason,
                        events=request.take_events(),
                        kv_transfer_params=kv_transfer_params,
                        num_cached_tokens=request.num_cached_tokens,
                    ))
        if stopped:
            if req_id in requests_valid_token_len:
                requests_valid_token_len.pop(req_id)
        else:
            new_running.append(request)

    scheduler.running = new_running

    # KV Connector: update state for finished KV Transfers.
    scheduler._update_from_kv_xfer_finished(model_runner_output)

    # Create EngineCoreOutputs for all clients that have requests with
    # outputs in this step.
    engine_core_outputs = {
        client_index: EngineCoreOutputs(outputs=outs)
        for client_index, outs in outputs.items()
    }

    finished_req_ids = scheduler.finished_req_ids_dict
    if finished_req_ids:
        # Include ids of requests that finished since last outputs
        # were sent.
        for client_index, finished_set in finished_req_ids.items():
            # Set finished request set in EngineCoreOutputs for this client.
            if (eco := engine_core_outputs.get(client_index)) is not None:
                eco.finished_requests = finished_set
            else:
                engine_core_outputs[client_index] = EngineCoreOutputs(
                    finished_requests=finished_set)
        finished_req_ids.clear()

    if engine_core_outputs:
        # Return stats to only one of the front-ends.
        next(iter(engine_core_outputs.values())).scheduler_stats = (
            scheduler.make_stats(spec_decoding_stats))

    return engine_core_outputs


def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Schedule, execute, and make output.

    Returns a tuple of outputs and a flag indicating whether the model
    was executed.
    """

    # Check for any requests remaining in the scheduler - unfinished,
    # or finished and not yet removed from the batch.
    if not core.scheduler.has_requests():
        return {}, False
    scheduler_output = core.scheduler.schedule()
    model_output = core.execute_model(scheduler_output)
    if isinstance(model_output, ZeroV1ModelRunnerOutput):
        engine_core_outputs = zero_overhead_update_from_output(
            core.scheduler, scheduler_output, model_output)  # type: ignore
    else:
        engine_core_outputs = core.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

    return (engine_core_outputs,
            scheduler_output.total_num_scheduled_tokens > 0)
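The deferred-fix bookkeeping above is easier to see in isolation. Below is a minimal stand-alone sketch (plain Python; ToyRequest, PLACEHOLDER, and the request id are illustrative, not vLLM types or part of the files in this commit) of how `requests_valid_token_len` lags the placeholder outputs by one step and how the `fix_req_ids` branch patches the real token ids in place:

# Step N: the sampler result is still on the GPU, so the scheduler only
# sees a placeholder id; the valid-length counter is NOT advanced yet.
PLACEHOLDER = 1  # core.py emits np.ones(...) placeholders this step

class ToyRequest:
    def __init__(self):
        self.output_token_ids = []

valid_token_len = {}          # mirrors requests_valid_token_len above
req = ToyRequest()
req.output_token_ids.append(PLACEHOLDER)
valid_token_len.setdefault("req-0", 0)

# Step N+1: the host copy has landed; overwrite the placeholder in place
# and advance the counter, exactly like the fix_req_ids branch above.
real_ids = [42]
offset = valid_token_len["req-0"] - len(req.output_token_ids)  # == -1
req.output_token_ids[offset:] = real_ids
valid_token_len["req-0"] += len(real_ids)

assert req.output_token_ids == [42] and valid_token_len["req-0"] == 1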
317
vllm/zero_overhead/v1/eagle.py
Normal file
@@ -0,0 +1,317 @@
import torch

from vllm.forward_context import set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer


class V1ZeroEagleProposer(EagleProposer):

    def __init__(self, vllm_config, device, runner=None):
        super().__init__(vllm_config, device, runner)
        self.spec_scheduler_max_num_tokens = 0

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        # [batch_size]
        sampling_metadata: SamplingMetadata,
        decoding: bool = False,
    ) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()

        if self.method in ["eagle", "eagle3"]:
            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
            max_seq_len = seq_lens.max().item()
            max_num_tokens = (cu_num_tokens[1:] -
                              cu_num_tokens[:-1]).max().item()
            attn_metadata = FlashAttentionMetadata(
                num_actual_tokens=num_tokens,
                max_query_len=max_num_tokens,
                query_start_loc=cu_num_tokens,
                max_seq_len=max_seq_len,
                seq_lens=seq_lens,
                block_table=block_table,
                slot_mapping=target_slot_mapping,
                # TODO(woosuk): Support cascade attention.
                use_cascade=False,
                common_prefix_len=0,
                cu_prefix_query_lens=None,
                prefix_kv_lens=None,
                suffix_kv_lens=None,
            )
        elif self.method == "deepseek_mtp":
            max_query_len = self.spec_scheduler_max_num_tokens
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=cu_num_tokens,
                seq_lens=seq_lens,
                num_reqs=batch_size,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                slot_mapping=target_slot_mapping,
                spec_layer_decoding=decoding,
            )

            assert self.runner is not None

            # FIXME: need to consider multiple kv_cache_groups.
            attn_metadata = self.runner.attn_metadata_builders[0].build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        if self.use_cuda_graph and \
                num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # Copy inputs to the persistent buffers for CUDA graphs.
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        if (decoding and self.use_full_cuda_graph
                and num_tokens <= self.cudagraph_batch_sizes[-1]):
            assert self.attn_metadata_cudagraph
            if self.method in ["eagle", "eagle3"]:
                self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
                    attn_metadata.seq_lens)
                self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                self.attn_metadata_cudagraph.query_start_loc[
                    :batch_size + 1] = (attn_metadata.query_start_loc)
                self.attn_metadata_cudagraph.block_table[:batch_size] = (
                    attn_metadata.block_table)
            elif self.method == "deepseek_mtp":
                self.attn_metadata_cudagraph.num_actual_tokens = (
                    attn_metadata.num_actual_tokens)
                self.attn_metadata_cudagraph.query_start_loc[
                    :batch_size + 1] = (attn_metadata.query_start_loc)
                self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                self.attn_metadata_cudagraph.num_decodes = (
                    attn_metadata.num_decodes)
                self.attn_metadata_cudagraph.num_decode_tokens = (
                    attn_metadata.num_decode_tokens)
                self.attn_metadata_cudagraph.num_prefills = (
                    attn_metadata.num_prefills)

                if attn_metadata.decode is not None:
                    self.attn_metadata_cudagraph.decode.block_table[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.decode.block_table)
                    self.attn_metadata_cudagraph.decode.seq_lens[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.decode.seq_lens)

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens,
                                 skip_cuda_graphs=not decoding):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            if self.method == "deepseek_mtp":
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, the MTP module released by DeepSeek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        positions = target_positions[last_token_indices]

        if self.method == "deepseek_mtp":
            hidden_states = last_hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if self.use_cuda_graph and \
                batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]

        if isinstance(attn_metadata, MLACommonMetadata):
            attn_metadata.num_decodes = batch_size
            attn_metadata.num_decode_tokens = batch_size
            attn_metadata.num_prefills = 0
            block_table = self.runner.attn_metadata_builders[
                0].block_table.get_device_tensor()[:batch_size, ...]
            attn_metadata.decode = self.runner.attn_metadata_builders[
                0]._build_decode(
                    block_table_tensor=block_table,
                    seq_lens=seq_lens,
                )

        for i in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # Casting to int32 is crucial when the eagle model is compiled:
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is
            # complex to remove such requests from the batch, we keep them in
            # the batch but adjust the position ids and slot mappings to avoid
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get an out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            if isinstance(attn_metadata, MLACommonMetadata):
                attn_metadata.decode.seq_lens += 1
            else:
                attn_metadata.seq_lens += 1

            # Increment the sequence lengths.
            attn_metadata.max_seq_len += 1
            # Consider max model length.
            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                            self.max_model_len)

            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # Copy inputs to the persistent buffers for CUDA graphs.
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states

            if (self.use_full_cuda_graph
                    and batch_size <= self.cudagraph_batch_sizes[-1]):
                assert self.attn_metadata_cudagraph
                if self.method in ["eagle", "eagle3"]:
                    self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
                        attn_metadata.seq_lens)
                    self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
                        attn_metadata.slot_mapping)
                    if i == 0:
                        self.attn_metadata_cudagraph.query_start_loc[
                            :batch_size + 1] = (attn_metadata.query_start_loc)
                        self.attn_metadata_cudagraph.block_table[
                            :batch_size] = (attn_metadata.block_table)
                elif self.method == "deepseek_mtp":
                    self.attn_metadata_cudagraph.num_actual_tokens = (
                        attn_metadata.num_actual_tokens)
                    self.attn_metadata_cudagraph.slot_mapping[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.slot_mapping)
                    self.attn_metadata_cudagraph.num_decodes = (
                        attn_metadata.num_decodes)
                    self.attn_metadata_cudagraph.num_decode_tokens = (
                        attn_metadata.num_decode_tokens)
                    self.attn_metadata_cudagraph.num_prefills = (
                        attn_metadata.num_prefills)
                    self.attn_metadata_cudagraph.decode.seq_lens[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.decode.seq_lens)

                    if i == 0:
                        self.attn_metadata_cudagraph.query_start_loc[
                            :batch_size + 1] = (attn_metadata.query_start_loc)
                        self.attn_metadata_cudagraph.decode.block_table[
                            :attn_metadata.num_decode_tokens] = (
                                attn_metadata.decode.block_table)

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
                if self.method == "deepseek_mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = last_hidden_states[:batch_size]
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
                    hidden_states = hidden_states[:batch_size]

            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention.
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)

        return draft_token_ids
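The shift-and-replace trick at the top of propose() is worth seeing concretely. The sketch below is a stand-alone illustration with toy token values, following the [a1, b1, b2, c1, c2, c3] example in the comments above (the numeric ids are arbitrary, not from the commit):

import torch

# Three packed requests: A contributes [a1], B [b1, b2], C [c1, c2, c3].
target_token_ids = torch.tensor([1, 11, 12, 21, 22, 23])  # [a1,b1,b2,c1,c2,c3]
cu_num_tokens = torch.tensor([0, 1, 3, 6])                # per-request offsets
next_token_ids = torch.tensor([2, 13, 24])                # [a2, b3, c4]

input_ids = target_token_ids.clone()
# Shift left by one: [a1,b1,b2,c1,c2,c3] -> [b1,b2,c1,c2,c3,c3]
input_ids[:-1] = target_token_ids[1:]
# Overwrite each request's last slot with its sampled next token:
# -> [a2, b2, b3, c2, c3, c4]
last_token_indices = cu_num_tokens[1:] - 1
input_ids[last_token_indices] = next_token_ids

assert input_ids.tolist() == [2, 12, 13, 22, 23, 24]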
749
vllm/zero_overhead/v1/gpu_model_runner.py
Normal file
@@ -0,0 +1,749 @@
from typing import Any, Optional, Union

import numpy as np
import torch

from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import (
    get_kv_transfer_group, has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.utils import async_tensor_h2d, round_up
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import (
    tbo_split_and_execute_model)


class V1ZeroModelRunner(GPUModelRunner):

    def __init__(self, vllm_config, device):
        super().__init__(vllm_config, device)
        # State carried across steps: the previous step's sampled/draft
        # tokens stay on the GPU, and CUDA events guard their asynchronous
        # copies to the host.
        self.last_sampled_token_ids = None
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        self.last_sampler_event = torch.cuda.Event(enable_timing=False)
        self.last_sampler_host_tokens = None
        self.token_ids_cpu_fix_record = []
        self.last_draft_token_ids = None
        self.last_draft_host_tokens = None
        self.last_draft_event = torch.cuda.Event(enable_timing=False)
        self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
        self.spec_scheduler_max_num_tokens = 0
        if hasattr(self, 'drafter') and isinstance(self.drafter,
                                                   EagleProposer):
            self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
                                               self)

    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str, Any], bool, torch.Tensor,
               Optional[SpecDecodeMetadata], np.ndarray]:
        """
        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            attention_cuda_graphs: whether attention can run in cudagraph,
            logits_indices, spec_decode_metadata, num_scheduled_tokens
        ]
        """
        total_num_scheduled_tokens = \
            scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)
        self.spec_scheduler_max_num_tokens = max_num_scheduled_tokens

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL).
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping for each KV cache group.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            block_size = kv_cache_group_spec.kv_cache_spec.block_size
            block_table: BlockTable = self.input_batch.block_table[
                kv_cache_group_id]
            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
            # where K is the max_num_blocks_per_req and the block size is 2.
            # NOTE(woosuk): We can't simply use `token_indices // block_size`
            # here because M (max_model_len) is not necessarily divisible by
            # block_size.
            block_table_indices = (
                req_indices * block_table.max_num_blocks_per_req +
                positions_np // block_size)
            block_table_cpu = block_table.get_cpu_tensor()
            block_numbers = block_table_cpu.flatten(
            )[block_table_indices].numpy()
            block_offsets = positions_np % block_size
            np.add(
                block_numbers * block_size,
                block_offsets,
                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)

        # Patch this step's input ids on-device with the previous step's
        # sampled/draft tokens, which may not have reached the host yet.
        self.zero_prepare_inputs(scheduler_output, self.input_ids)

        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g., Qwen2-VL).
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions).
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        self.query_start_loc[:num_reqs + 1].copy_(
            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                       non_blocking=True)

        # Fill the unused seq_lens entries with 0.
        self.seq_lens[num_reqs:].fill_(0)
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention require that.
        self.query_start_loc[num_reqs + 1:].fill_(
            self.query_start_loc_cpu[num_reqs].item())

        query_start_loc = self.query_start_loc[:num_reqs + 1]
        seq_lens = self.seq_lens[:num_reqs]

        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
        )

        attn_metadata: dict[str, Any] = {}
        # Prepare the attention metadata for each KV cache group and make
        # layers in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

            # Prepare for cascade attention if enabled & beneficial.
            common_prefix_len = 0
            builder = self.attn_metadata_builders[kv_cache_group_id]
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id],
                    kv_cache_group_spec.kv_cache_spec,
                    builder,
                )

            attn_metadata_i = (builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata,
            ))

            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

        attention_cuda_graphs = all(
            b.can_run_in_cudagraph(common_attn_metadata)
            for b in self.attn_metadata_builders)

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since
            # not all requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-swap LoRA model.
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (attn_metadata, attention_cuda_graphs, logits_indices,
                spec_decode_metadata, num_scheduled_tokens)

    def zero_prepare_inputs(self, scheduler_output, input_ids):
        """Overwrite the placeholder input ids for this step with the
        previous step's draft and sampled token ids, entirely on-device."""
        req_ids = self.input_batch.req_ids
        update_req_indices = []
        input_ids_indices = []
        token_idx = 0
        if self.last_draft_token_ids is not None:
            draft_tokens_num = self.last_draft_token_ids.shape[1]
            for req_id in req_ids:
                if req_id in self.last_sampled_req_ids:
                    req_idx = self.last_sampled_req_ids.index(
                        req_id) * draft_tokens_num
                    for num_idx in range(draft_tokens_num):
                        update_req_indices.append(req_idx + num_idx)
                        input_ids_indices.append(token_idx + num_idx + 1)
                token_idx += draft_tokens_num + 1
            if len(update_req_indices) > 0:
                update_req_indices_tensor = async_tensor_h2d(
                    update_req_indices, torch.int32, self.device, True)
                input_ids_indices_tensor = async_tensor_h2d(
                    input_ids_indices, torch.int32, self.device, True)
                last_draft_token_ids = self.last_draft_token_ids.flatten().to(
                    torch.int)
                input_ids[input_ids_indices_tensor] = last_draft_token_ids[
                    update_req_indices_tensor]

        update_req_indices = []
        input_ids_indices = []
        token_idx = 0
        if self.last_sampled_token_ids is not None:
            sampled_tokens_num = self.last_sampled_token_ids.shape[1]
            for req_id in req_ids:
                if req_id in self.last_sampled_req_ids:
                    req_idx = self.last_sampled_req_ids.index(
                        req_id) * sampled_tokens_num
                    update_req_indices.append(req_idx)
                    input_ids_indices.append(token_idx)
                token_idx += scheduler_output.num_scheduled_tokens[req_id]
            if len(update_req_indices) > 0:
                update_req_indices_tensor = async_tensor_h2d(
                    update_req_indices, torch.int32, self.device, True)
                input_ids_indices_tensor = async_tensor_h2d(
                    input_ids_indices, torch.int32, self.device, True)
                last_sampled_token_ids = self.last_sampled_token_ids.flatten()
                for i in range(sampled_tokens_num):
                    input_ids[input_ids_indices_tensor +
                              i] = last_sampled_token_ids[
                                  update_req_indices_tensor + i]

    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        num_accepted_tokens_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        attn_metadata: dict[str, Any],
    ) -> list[list[int]]:
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(self.drafter, MedusaProposer)
            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
            else:
                indices = []
                offset = 0
                for num_draft, tokens in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    indices.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            spec_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            # Pick, for each row, the last accepted token as the next token.
            row_indices = torch.arange(sampled_token_ids.size(0),
                                       device=sampled_token_ids.device)
            next_token_ids = sampled_token_ids[
                row_indices, num_accepted_tokens_tensor].flatten()
            # At this moment, we assume all eagle layers belong to the same
            # KV cache group, thus using the same attention metadata.
            eagle_attn_metadata = attn_metadata[
                self.drafter.attn_layer_names[0]]

            # NOTE: deepseek_mtp uses MLA, which does not have `block_table`.
            if hasattr(eagle_attn_metadata, "block_table"):
                block_table = eagle_attn_metadata.block_table
            else:
                block_table = None

            spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[:num_scheduled_tokens]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = \
                        hidden_states[:num_scheduled_tokens]
                target_slot_mapping = eagle_attn_metadata.slot_mapping
                cu_num_tokens = eagle_attn_metadata.query_start_loc
            else:
                # TODO(woosuk): Refactor this.
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                    eagle_attn_metadata.query_start_loc,
                    num_accepted_tokens_tensor,
                )
                spec_scheduler_max_num_tokens = 1
                target_token_ids = self.input_ids[token_indices]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[token_indices]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
                target_slot_mapping = eagle_attn_metadata.slot_mapping[
                    token_indices]
            self.drafter.spec_scheduler_max_num_tokens = \
                spec_scheduler_max_num_tokens
            draft_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
                block_table=block_table,
                sampling_metadata=sampling_metadata,
                decoding=spec_decode_metadata is not None,
            )
            # Return placeholder ids; the real draft tokens stay on the GPU
            # and are copied to the host asynchronously, guarded by an event.
            spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
            self.last_draft_token_ids = draft_token_ids
            self.last_draft_host_tokens = draft_token_ids.to(
                'cpu', non_blocking=True)
            self.last_draft_event.record()
        return spec_token_ids

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata,
         num_scheduled_tokens_np) = self._prepare_inputs(scheduler_output)
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            # Pad tokens to a multiple of tensor_parallel_size when
            # collective fusion is enabled for sequence parallelism.
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            if self.compilation_config.pass_config. \
                    enable_sequence_parallelism and tp_size > 1:
                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            else:
                num_input_tokens = num_scheduled_tokens

        # Padding for DP.
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += num_pad

        # _prepare_inputs may reorder the batch, so we must gather multi-
        # modal outputs after that to ensure the correct order.
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model and get_pp_group().is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA Graphs in pure decode.
        # If attention doesn't support CUDA Graphs for this batch, but we
        # compiled with full CUDA graphs, we have to skip them entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
        if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph
                                     or skip_cuda_graphs):
            model_output, finished_sending, finished_recving = \
                tbo_split_and_execute_model(
                    self, attn_metadata, num_input_tokens,
                    num_tokens_across_dp, input_ids, positions,
                    inputs_embeds, scheduler_output, intermediate_tensors,
                    skip_cuda_graphs)
        else:
            # Run the model.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_input_tokens,
                    num_tokens_across_dp=num_tokens_across_dp,
                    skip_cuda_graphs=skip_cuda_graphs,
            ):
                self.maybe_setup_kv_connector(scheduler_output)

                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

                self.maybe_wait_for_kv_save()
                finished_sending, finished_recving = (
                    self.get_finished_kv_transfers(scheduler_output))
        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks.
        # TODO: Support overlapping micro-batches.
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = \
            self.parallel_config.distributed_executor_backend \
            == "external_launcher" and len(get_pp_group().ranks) > 0
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,
                                            all_gather_group=get_tp_group())
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)

            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = \
                get_pp_group().broadcast_tensor_dict(
                    model_output_broadcast_data,
                    src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present.
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on
            # bonus_logits won't affect the original logits tensor.
            assert logits is not None
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we can clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output,
        )

        fix_req_ids = None
        fix_sampled_token_ids = None
        fix_draft_token_ids = None
        fix_draft_req_ids = self.last_sampled_req_ids
        is_output_valid = False
        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if not self.speculative_config:
            # Speculative decoding is not enabled.
            spec_token_ids = None
            fix_draft_req_ids = None
        else:
            sampled_token_ids_cpu = sampled_token_ids.to('cpu',
                                                         non_blocking=True)
            self.spec_sampler_event.record()
            if self.last_draft_host_tokens is not None:
                self.last_draft_event.synchronize()
                fix_draft_token_ids = self.last_draft_host_tokens.tolist()

            # Count accepted tokens per row: the first -1 marks the first
            # rejected draft position.
            mask = (sampled_token_ids == -1)
            mask_int = mask.int()
            first_neg_one_indices = torch.argmax(mask_int, dim=1)
            num_accepted_tokens_tensor = torch.where(
                torch.any(mask, dim=1), first_neg_one_indices,
                sampled_token_ids.size(1)) - 1
            spec_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                num_accepted_tokens_tensor,
                sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                attn_metadata,
            )

        if self.speculative_config:
            self.spec_sampler_event.synchronize()
            if max_gen_len == 1:
                valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
            else:
                # Includes spec decode tokens.
                valid_sampled_token_ids = self.rejection_sampler.parse_output(
                    sampled_token_ids_cpu,
                    self.input_batch.vocab_size,
                )
            self.last_sampler_host_tokens = None
            self.last_sampled_token_ids = None
            is_output_valid = True
        else:
            # No spec decode tokens: defer this step's host copy and patch
            # in the previous step's tokens instead.
            fix_req_ids = self.last_sampled_req_ids
            if self.last_sampler_host_tokens is not None:
                self.last_sampler_event.synchronize()
                fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
                for req_idx, start_idx, end_idx in \
                        self.token_ids_cpu_fix_record:
                    if start_idx == -1:
                        continue
                    req_id = fix_req_ids[req_idx]
                    if req_id in self.input_batch.req_ids:
                        new_req_idx = self.input_batch.req_ids.index(req_id)
                        self.input_batch.token_ids_cpu[
                            new_req_idx,
                            start_idx:end_idx] = fix_sampled_token_ids[req_idx]
                for req_idx, req_id in enumerate(fix_req_ids):
                    if req_id in self.requests:
                        req_state = self.requests[req_id]
                        token_idx = self.last_sampled_token_lens[req_idx]
                        if token_idx == -1:
                            continue
                        fix_len = len(fix_sampled_token_ids[req_idx])
                        req_state.output_token_ids[
                            token_idx:token_idx +
                            fix_len] = fix_sampled_token_ids[req_idx]
            self.last_sampler_host_tokens = sampled_token_ids.to(
                'cpu', non_blocking=True)
            self.last_sampler_event.record()
            self.last_sampled_token_ids = sampled_token_ids
            valid_sampled_token_ids = np.ones(sampled_token_ids.shape,
                                              dtype=int).tolist()

        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        self.token_ids_cpu_fix_record.clear()
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
            req_id = self.input_batch.req_ids[req_idx]
            self.last_sampled_req_ids.append(req_id)
            cache_output_len = -1
            if not sampled_ids:
                self.last_sampled_token_lens.append(-1)
                self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            self.input_batch.token_ids_cpu[req_idx,
                                           start_idx:end_idx] = sampled_ids
            self.token_ids_cpu_fix_record.append(
                [req_idx, start_idx, end_idx])
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx
            if req_id in self.requests:
                req_state = self.requests[req_id]
                cache_output_len = len(req_state.output_token_ids)
                req_state.output_token_ids.extend(sampled_ids)
            self.last_sampled_token_lens.append(cache_output_len)

        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()

        self.eplb_step()

        model_output = ZeroV1ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
            fix_req_ids=fix_req_ids,
            fix_sampled_token_ids=fix_sampled_token_ids,
            fix_draft_tokens_ids=fix_draft_token_ids,
            fix_draft_req_ids=fix_draft_req_ids,
            is_output_valid=is_output_valid,
        )
        return model_output
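The event-guarded asynchronous device-to-host copy that execute_model relies on can be shown in miniature. The sketch below is a stand-alone illustration of that pattern (the variable names are illustrative, not the vLLM API); it mirrors the `to('cpu', non_blocking=True)` / `event.record()` / `event.synchronize()` sequence above:

import torch

if torch.cuda.is_available():
    copy_done = torch.cuda.Event(enable_timing=False)
    sampled = torch.randint(0, 32000, (8, 1), device="cuda")

    # Step N: kick off the copy without blocking the CPU, then record an
    # event on the current stream to mark when the copy has landed.
    host_tokens = sampled.to("cpu", non_blocking=True)
    copy_done.record()

    # ... CPU-side scheduling work for step N+1 overlaps with the copy ...

    # Step N+1: only now are the real values needed. Waiting on the event
    # is cheaper than a full device synchronize, and guarantees the host
    # buffer is complete before it is read.
    copy_done.synchronize()
    token_lists = host_tokens.tolist()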
14
vllm/zero_overhead/v1/outputs.py
Normal file
@@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Optional

from vllm.v1.outputs import ModelRunnerOutput


@dataclass
class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
    # [num_reqs] Request ids whose previous-step sampled tokens are being
    # patched ("fixed") in this step, with the corresponding token ids.
    fix_req_ids: Optional[list[str]] = None
    fix_sampled_token_ids: Optional[list[list[int]]] = None
    # Request ids whose previous-step draft tokens are being patched in.
    fix_draft_req_ids: Optional[list[str]] = None
    fix_draft_tokens_ids: Optional[list[list[int]]] = None
    # False when sampled_token_ids holds placeholders that a later step
    # still needs to fix.
    is_output_valid: bool = True
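One note on the defaults above: `None` is a safe dataclass default (unlike a mutable `[]`, which `dataclasses` rejects at class creation time), but the annotation must be `Optional[...]` for the default to type-check. A minimal stand-alone illustration (the Example class is just for this sketch):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Example:
    # OK: None is immutable, but Optional[...] is required for the
    # annotation to match the default.
    ids: Optional[list[str]] = None
    # A mutable default must go through default_factory instead.
    tokens: list[int] = field(default_factory=list)

e = Example()
assert e.ids is None and e.tokens == []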