init src 0.9.2

vllm/zero_overhead/llm_engine.py (new file, 661 lines)
@@ -0,0 +1,661 @@
|
||||
from functools import partial
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Callable, Dict, List, Mapping, Optional, Type, Union
|
||||
import torch
|
||||
from vllm import envs
|
||||
from vllm.config import DecodingConfig, ObservabilityConfig, VllmConfig
|
||||
from vllm.core.scheduler import ScheduledSequenceGroup
|
||||
from vllm.engine.llm_engine import _LOCAL_LOGGING_INTERVAL_SEC, LLMEngine, SchedulerContext, SchedulerOutputState
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.inputs.data import ProcessorInputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.inputs.registry import InputRegistry
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import ExecuteModelRequest, ParallelSampleSequenceGroup, SequenceGroup, SequenceGroupBase, SequenceGroupMetadata
|
||||
from vllm.tracing import init_tracer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
from vllm.utils import resolve_obj_by_qualname, weak_bind, Counter
|
||||
from vllm.zero_overhead.sampler import SampleRecorder, get_last_sampler
|
||||
from vllm.zero_overhead.sequence import ZeroOverheadSequence
|
||||
from vllm.zero_overhead.stop_check import ZeroOverheadStopChecker
|
||||
from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.profiler.prof import profile
|
||||
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step, zero_overhead_stream
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class ZeroOverheadEngine(LLMEngine):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[ExecutorBase],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
) -> None:
|
||||
if envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
|
||||
"This should not happen. As a workaround, try using "
|
||||
"LLMEngine.from_vllm_config(...) or explicitly set "
|
||||
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config # noqa
|
||||
self.load_config = vllm_config.load_config
|
||||
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
|
||||
)
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
|
||||
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Initializing a V0 LLM engine (v%s) with config: %s, "
|
||||
"use_cached_outputs=%s, ",
|
||||
VLLM_VERSION,
|
||||
vllm_config,
|
||||
use_cached_outputs,
|
||||
)
|
||||
|
||||
self.log_stats = log_stats
|
||||
self.use_cached_outputs = use_cached_outputs
|
||||
self.thread_running = False
|
||||
|
||||
if not self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
self.detokenizer = ZeroOverheadDetokenizer(self.tokenizer)
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
else:
|
||||
self.tokenizer = None
|
||||
self.detokenizer = None
|
||||
tokenizer_group = None
|
||||
|
||||
# Ensure that the function doesn't contain a reference to self,
|
||||
# to avoid engine GC issues
|
||||
def get_tokenizer_for_seq(sequence: ZeroOverheadSequence) -> AnyTokenizer:
|
||||
assert tokenizer_group, ("tokenizer_group cannot be None, "
|
||||
"make sure skip_tokenizer_init is False")
|
||||
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
|
||||
|
||||
self.seq_counter = Counter()
|
||||
self.generation_config_fields = (
|
||||
self.model_config.try_get_generation_config())
|
||||
|
||||
self.input_preprocessor = InputPreprocessor(self.model_config,
|
||||
self.tokenizer,
|
||||
mm_registry)
|
||||
|
||||
self.model_executor = executor_class(vllm_config=vllm_config, )
|
||||
|
||||
if self.model_config.runner_type != "pooling":
|
||||
self._initialize_kv_caches()
|
||||
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
if is_usage_stats_enabled():
|
||||
from vllm.model_executor.model_loader import (
|
||||
get_architecture_class_name)
|
||||
usage_message.report_usage(
|
||||
get_architecture_class_name(self.model_config),
|
||||
usage_context,
|
||||
extra_kvs={
|
||||
# Common configuration
|
||||
"dtype":
|
||||
str(self.model_config.dtype),
|
||||
"tensor_parallel_size":
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
"block_size":
|
||||
self.cache_config.block_size,
|
||||
"gpu_memory_utilization":
|
||||
self.cache_config.gpu_memory_utilization,
|
||||
|
||||
# Quantization
|
||||
"quantization":
|
||||
self.model_config.quantization,
|
||||
"kv_cache_dtype":
|
||||
str(self.cache_config.cache_dtype),
|
||||
|
||||
# Feature flags
|
||||
"enable_lora":
|
||||
bool(self.lora_config),
|
||||
"enable_prompt_adapter":
|
||||
bool(self.prompt_adapter_config),
|
||||
"enable_prefix_caching":
|
||||
self.cache_config.enable_prefix_caching,
|
||||
"enforce_eager":
|
||||
self.model_config.enforce_eager,
|
||||
"disable_custom_all_reduce":
|
||||
self.parallel_config.disable_custom_all_reduce,
|
||||
})
|
||||
|
||||
self.cached_scheduler_outputs = [
|
||||
SchedulerOutputState()
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
self.scheduler_contexts = [
|
||||
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
|
||||
multi_step_stream_outputs)
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
if self.model_config.use_async_output_proc:
|
||||
process_model_outputs = weak_bind(self._process_model_outputs)
|
||||
|
||||
self.async_callbacks = [
|
||||
partial(process_model_outputs,
|
||||
ctx=self.scheduler_contexts[v_id])
|
||||
for v_id in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
else:
|
||||
self.async_callbacks = []
|
||||
|
||||
# Currently used by AsyncLLMEngine to ensure quick append
|
||||
# of request outputs to asyncio queues
|
||||
self.process_request_outputs_callback: Optional[Callable] = None
|
||||
|
||||
# Create the scheduler.
|
||||
# NOTE: the cache_config here has been updated with the numbers of
|
||||
# GPU and CPU blocks, which are profiled in the distributed executor.
|
||||
if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
|
||||
Scheduler = resolve_obj_by_qualname(
|
||||
self.vllm_config.scheduler_config.scheduler_cls)
|
||||
else:
|
||||
Scheduler = self.vllm_config.scheduler_config.scheduler_cls
|
||||
self.scheduler = [
|
||||
Scheduler(
|
||||
self.scheduler_config, self.cache_config, self.lora_config,
|
||||
self.parallel_config.pipeline_parallel_size,
|
||||
self.async_callbacks[v_id]
|
||||
if self.model_config.use_async_output_proc else None)
|
||||
for v_id in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# Metric Logging.
|
||||
if self.log_stats:
|
||||
if stat_loggers is not None:
|
||||
self.stat_loggers = stat_loggers
|
||||
else:
|
||||
# Lazy import for prometheus multiprocessing.
|
||||
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||
# before prometheus_client is imported.
|
||||
# See https://prometheus.github.io/client_python/multiprocess/
|
||||
from vllm.engine.metrics import (LoggingStatLogger,
|
||||
PrometheusStatLogger)
|
||||
|
||||
self.stat_loggers = {
|
||||
"logging":
|
||||
LoggingStatLogger(
|
||||
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
||||
vllm_config=vllm_config),
|
||||
"prometheus":
|
||||
PrometheusStatLogger(
|
||||
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
||||
labels=dict(
|
||||
model_name=self.model_config.served_model_name),
|
||||
vllm_config=vllm_config),
|
||||
}
|
||||
self.stat_loggers["prometheus"].info("cache_config",
|
||||
self.cache_config)
|
||||
|
||||
self.tracer = None
|
||||
if self.observability_config.otlp_traces_endpoint:
|
||||
self.tracer = init_tracer(
|
||||
"vllm.llm_engine",
|
||||
self.observability_config.otlp_traces_endpoint)
|
||||
|
||||
# Create sequence output processor, e.g. for beam search or
|
||||
# speculative decoding.
|
||||
self.output_processor = (
|
||||
SequenceGroupOutputProcessor.create_output_processor(
|
||||
self.scheduler_config,
|
||||
self.detokenizer,
|
||||
self.scheduler,
|
||||
self.seq_counter,
|
||||
get_tokenizer_for_seq,
|
||||
stop_checker=ZeroOverheadStopChecker(
|
||||
self.scheduler_config.max_model_len,
|
||||
get_tokenizer_for_seq,
|
||||
),
|
||||
))
|
||||
self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
|
||||
|
||||
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
|
||||
|
||||
# Flag to set when an input fails to process and the engine should run
|
||||
# the next step without re-scheduling.
|
||||
self._skip_scheduling_next_step = False
|
||||
self.async_d2h = None
|
||||
self.last_record = None
|
||||
self.async_event = torch.cuda.Event(enable_timing=False)
|
||||
self.q_recorder = queue.Queue()
|
||||
self.use_stream = zero_overhead_stream(self.model_executor.device_config.device)
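# Zero-overhead pipelining (machinery set up above): a background thread runs
# schedule() + execute_model() for step N+1 while the main thread post-processes
# step N. Hand-off is a simple producer/consumer pair: sem_m2s wakes the
# background thread once per step, and q_recorder carries the previous step's
# (outputs, sampler record, metadata) back to the main thread. The real sampled
# token ids stay on the GPU; async_d2h holds the non-blocking device-to-host
# copy and async_event marks when it is safe to read on the host.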
|
||||
if not is_zero_no_thread():
|
||||
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
|
||||
self.thread_running = True
|
||||
self.sem_m2s = threading.Semaphore(0) # main to scheduler thread
|
||||
self.zero_thread.start()
|
||||
profile.StartTracer()
|
||||
|
||||
def __del__(self):
|
||||
self.finish_thread()
|
||||
return super().__del__()
|
||||
|
||||
def finish_thread(self):
|
||||
if self.thread_running:
|
||||
self.thread_running = False
|
||||
self.sem_m2s.release()
|
||||
|
||||
|
||||
def thread_zero_overhead(self):
|
||||
logger.info('zero overhead thread started')
|
||||
last_sampler = get_last_sampler()
|
||||
last_sampler.seq_ids.clear()
|
||||
try:
|
||||
with torch.cuda.stream(self.use_stream):
|
||||
while True:
|
||||
self.sem_m2s.acquire()
|
||||
if not self.thread_running:
|
||||
logger.debug("Stopping remote worker execution loop.")
|
||||
self.model_executor.stop_remote_worker_execution_loop()
|
||||
|
||||
break
|
||||
|
||||
virtual_engine = 0
|
||||
# Clear outputs for each new scheduler iteration
|
||||
|
||||
# Schedule iteration
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc
|
||||
) = self.scheduler[virtual_engine].schedule()
|
||||
if self.last_record is not None:
|
||||
last_sampler = self.last_record[1]
|
||||
spec_step = get_spec_step()
|
||||
if spec_step == SpecStepKind.KIND_DEFAULT:
|
||||
if last_sampler.sampled_token_ids_tensor is not None:
|
||||
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
|
||||
else:
|
||||
self.async_d2h = None
|
||||
elif spec_step == SpecStepKind.SCORE_DECODE:
|
||||
self.async_d2h = last_sampler.to('cpu', non_blocking=True)
|
||||
self.async_event.record()
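# The event is recorded on the current stream right after the non-blocking
# D2H copy above, so the consumer can call async_event.synchronize() before
# reading self.async_d2h on the host.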
|
||||
self.q_recorder.put(self.last_record)
|
||||
else:
|
||||
self.q_recorder.put(None)
|
||||
if len(seq_group_metadata_list) == 0:
|
||||
self.last_record = None
|
||||
continue
|
||||
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
assert seq_group_metadata_list is not None
|
||||
assert scheduler_outputs is not None
|
||||
last_sampled_token_ids = \
|
||||
self._get_last_sampled_token_ids(virtual_engine)
|
||||
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
||||
running_queue_size=scheduler_outputs.running_queue_size,
|
||||
finished_requests_ids=finished_requests_ids,
|
||||
# We use ExecuteModelRequest to pass the last sampled_token_ids
|
||||
# to each of the non-last PP stages for in-place prepare_input.
|
||||
last_sampled_token_ids=last_sampled_token_ids)
|
||||
outputs = self.model_executor.execute_model(
|
||||
execute_model_req=execute_model_req)
|
||||
|
||||
for output in outputs:
|
||||
self._advance_to_next_step(
|
||||
output, seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] # copy the list so later scheduler mutations do not affect this record
|
||||
last_sampler = None
|
||||
spec_step = get_spec_step()
|
||||
if spec_step == SpecStepKind.KIND_DEFAULT:
|
||||
last_sampler = get_last_sampler()
|
||||
elif spec_step == SpecStepKind.SCORE_DECODE:
|
||||
last_sampler, _ = get_accepted_token_ids()
|
||||
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step]
|
||||
|
||||
except Exception as e:
|
||||
print(f"thread_zero_overhead error : {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
if not self.thread_running:
|
||||
self.zero_thread.join()
|
||||
self.thread_running = True
|
||||
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
|
||||
self.zero_thread.start()
|
||||
self.sem_m2s.release()
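# Wake the background thread so it schedules and launches the *next* step,
# then block on q_recorder for the *previous* step's results. On the very
# first call nothing has been recorded yet, so step() has to launch twice.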
|
||||
recode_output = self.q_recorder.get()
|
||||
if recode_output is None: # None is for the first step
|
||||
return None
|
||||
virtual_engine = 0
|
||||
ctx = self.scheduler_contexts[virtual_engine]
|
||||
ctx.request_outputs.clear()
|
||||
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step = recode_output
|
||||
ctx.seq_group_metadata_list = seq_group_metadata_list
|
||||
ctx.scheduler_outputs = scheduler_outputs
|
||||
if spec_step == SpecStepKind.KIND_DEFAULT:
|
||||
self.async_event.synchronize()
|
||||
if self.async_d2h is not None:
|
||||
self._fix_last_step(
|
||||
outputs, last_sampler, seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
elif spec_step == SpecStepKind.SCORE_DECODE:
|
||||
self.async_event.synchronize()
|
||||
self._fix_spec_decode_steps(
|
||||
outputs, seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
# is_first_step_output is True only when the num_steps of all
|
||||
# the sequences are 1. When the num_steps > 1,
|
||||
# multi_step_model_runner does the first-step output append.
|
||||
is_first_step_output: bool = False if not seq_group_metadata_list \
|
||||
else seq_group_metadata_list[0].state.num_steps == 1
|
||||
|
||||
# Add results to the output_queue
|
||||
ctx.append_output(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
is_async=True,
|
||||
is_last_step=True,
|
||||
is_first_step_output=is_first_step_output)
|
||||
|
||||
# Check if need to run the usual non-async path
|
||||
#if not allow_async_output_proc:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
|
||||
#profile.ProfRangeAutoPush('has_unfinish')
|
||||
if not self.has_unfinished_requests():
|
||||
# Drain async postprocessor (if exists)
|
||||
if len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
assert len(ctx.output_queue) == 0
|
||||
|
||||
# Stop the execute model loop in parallel workers until there are
|
||||
# more requests to process. This avoids waiting indefinitely in
|
||||
# torch.distributed ops which may otherwise timeout, and unblocks
|
||||
# the RPC thread in the workers so that they can process any other
|
||||
# queued control plane messages, such as add/remove lora adapters.
|
||||
# logger.debug("Stopping remote worker execution loop.")
|
||||
# self.model_executor.stop_remote_worker_execution_loop()
|
||||
self.finish_thread()
|
||||
return ctx.request_outputs
|
||||
|
||||
|
||||
def _fix_last_step(
|
||||
self, output: List[SamplerOutput],
|
||||
last_sampler: SampleRecorder,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
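# The sampler emitted placeholder token ids (0) so that no GPU->CPU sync was
# needed during the previous step. Now that the asynchronous D2H copy of the
# real sampled ids has completed, patch them into both the sampler output and
# the sequence data before output processing runs.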
|
||||
|
||||
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
|
||||
sample_out_list = self.async_d2h.tolist()
|
||||
sample_out_ids = last_sampler.seq_ids
|
||||
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
|
||||
zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
|
||||
if seq_group.is_finished():
|
||||
continue
|
||||
|
||||
if seq_group_metadata.do_sample:
|
||||
sample = sequence_group_outputs.samples[0]
|
||||
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq : ZeroOverheadSequence = seq_group.seqs[0]
|
||||
for token_id, seq_id in zip(sample_out_list, sample_out_ids):
|
||||
if seq.seq_id == seq_id:
|
||||
if type(token_id) is list:
|
||||
sample.output_token = token_id[0]
|
||||
else:
|
||||
sample.output_token = token_id
|
||||
seq.fix_last_token_id(sample.output_token)
|
||||
break
|
||||
|
||||
def _fix_spec_decode_steps(
|
||||
self, output: List[SamplerOutput],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduled_seq_groups: List[ScheduledSequenceGroup]):
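# Speculative-decoding variant of _fix_last_step: the accepted-token tensor
# uses -1 for rejected positions, so every -1 removes one placeholder token
# from the sequence and every real id overwrites the corresponding placeholder.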
|
||||
|
||||
sample_out_list = self.async_d2h.tolist()
|
||||
group_idx = 0
|
||||
for seq_group_metadata, accept_token_ids, scheduled_seq_group in \
|
||||
zip(seq_group_metadata_list, sample_out_list, scheduled_seq_groups):
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
|
||||
if seq_group.is_finished():
|
||||
group_idx += 1
|
||||
continue
|
||||
|
||||
if seq_group_metadata.do_sample:
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq : ZeroOverheadSequence = seq_group.seqs[0]
|
||||
remove_count = 0
|
||||
for token_id in accept_token_ids:
|
||||
if token_id == -1:
|
||||
remove_count += 1
|
||||
else:
|
||||
seq.fix_last_token_id(token_id)
|
||||
seq.remove_last_place_holder(remove_count)
|
||||
group_idx += 1
|
||||
|
||||
def no_thread_step(self):
|
||||
virtual_engine = 0
|
||||
# Clear outputs for each new scheduler iteration
|
||||
|
||||
# Schedule iteration
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc
|
||||
) = self.scheduler[virtual_engine].schedule()
|
||||
if self.last_record is not None:
|
||||
last_sampler = self.last_record[1]
|
||||
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
|
||||
self.async_event.record()
|
||||
self.q_recorder.put(self.last_record)
|
||||
else:
|
||||
self.q_recorder.put(None)
|
||||
if len(seq_group_metadata_list) == 0:
|
||||
self.last_record = None
|
||||
else:
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
assert seq_group_metadata_list is not None
|
||||
assert scheduler_outputs is not None
|
||||
last_sampled_token_ids = \
|
||||
self._get_last_sampled_token_ids(virtual_engine)
|
||||
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
||||
running_queue_size=scheduler_outputs.running_queue_size,
|
||||
finished_requests_ids=finished_requests_ids,
|
||||
# We use ExecuteModelRequest to pass the last sampled_token_ids
|
||||
# to each of the non-last PP stages for in-place prepare_input.
|
||||
last_sampled_token_ids=last_sampled_token_ids)
|
||||
outputs = self.model_executor.execute_model(
|
||||
execute_model_req=execute_model_req)
|
||||
|
||||
if len(outputs) == 1:
|
||||
self._advance_to_next_step(
|
||||
outputs[0], seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] # copy the list so later scheduler mutations do not affect this record
|
||||
last_sampler = get_last_sampler()
|
||||
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs]
|
||||
|
||||
recode_output = self.q_recorder.get()
|
||||
if recode_output is None: # None is for the first step
|
||||
return None
|
||||
virtual_engine = 0
|
||||
ctx = self.scheduler_contexts[virtual_engine]
|
||||
ctx.request_outputs.clear()
|
||||
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = recode_output
|
||||
ctx.seq_group_metadata_list = seq_group_metadata_list
|
||||
ctx.scheduler_outputs = scheduler_outputs
|
||||
self.async_event.synchronize()
|
||||
self._fix_last_step(
|
||||
outputs, last_sampler, seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
# is_first_step_output is True only when the num_steps of all
|
||||
# the sequences are 1. When the num_steps > 1,
|
||||
# multi_step_model_runner does the first-step output append.
|
||||
is_first_step_output: bool = False if not seq_group_metadata_list \
|
||||
else seq_group_metadata_list[0].state.num_steps == 1
|
||||
|
||||
# Add results to the output_queue
|
||||
ctx.append_output(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
is_async=True,
|
||||
is_last_step=True,
|
||||
is_first_step_output=is_first_step_output)
|
||||
|
||||
# Check if need to run the usual non-async path
|
||||
#if not allow_async_output_proc:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
|
||||
#profile.ProfRangeAutoPush('has_unfinish')
|
||||
if not self.has_unfinished_requests():
|
||||
# Drain async postprocessor (if exists)
|
||||
if len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
assert len(ctx.output_queue) == 0
|
||||
|
||||
# Stop the execute model loop in parallel workers until there are
|
||||
# more requests to process. This avoids waiting indefinitely in
|
||||
# torch.distributed ops which may otherwise timeout, and unblocks
|
||||
# the RPC thread in the workers so that they can process any other
|
||||
# queued control plane messages, such as add/remove lora adapters.
|
||||
logger.debug("Stopping remote worker execution loop.")
|
||||
self.model_executor.stop_remote_worker_execution_loop()
|
||||
return ctx.request_outputs
|
||||
|
||||
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
with torch.cuda.stream(self.use_stream):
|
||||
if is_zero_no_thread():
|
||||
out = self.no_thread_step()
|
||||
if out is None: # the first step needs to be launched twice
|
||||
out = self.no_thread_step()
|
||||
else:
|
||||
out = self.zero_overhead_step()
|
||||
if out is None: # the first step needs to be launched twice
|
||||
out = self.zero_overhead_step()
|
||||
return out
|
||||
|
||||
def _add_processed_request(
|
||||
self,
|
||||
request_id: str,
|
||||
processed_inputs: ProcessorInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> Optional[SequenceGroup]:
|
||||
"""Add a processed request to the engine's request pool.
|
||||
return the created sequence group.
|
||||
"""
|
||||
if isinstance(params, SamplingParams) and params.n > 1:
|
||||
ParallelSampleSequenceGroup.add_request(
|
||||
request_id,
|
||||
self,
|
||||
params,
|
||||
processed_inputs=processed_inputs,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
)
|
||||
return None
|
||||
|
||||
self._validate_model_inputs(processed_inputs, lora_request)
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seq_id = next(self.seq_counter)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
|
||||
seq = ZeroOverheadSequence(seq_id, decoder_inputs, block_size, eos_token_id,
|
||||
lora_request, prompt_adapter_request)
|
||||
|
||||
encoder_seq = (None if encoder_inputs is None else ZeroOverheadSequence(
|
||||
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
|
||||
prompt_adapter_request))
|
||||
|
||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||
if isinstance(params, SamplingParams):
|
||||
seq_group = self._create_sequence_group_with_sampling(
|
||||
request_id,
|
||||
seq,
|
||||
params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority)
|
||||
elif isinstance(params, PoolingParams):
|
||||
seq_group = self._create_sequence_group_with_pooling(
|
||||
request_id,
|
||||
seq,
|
||||
params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either SamplingParams or PoolingParams must be provided.")
|
||||
|
||||
# Add the sequence group to the scheduler with least unfinished seqs.
|
||||
costs = [
|
||||
scheduler.get_num_unfinished_seq_groups()
|
||||
for scheduler in self.scheduler
|
||||
]
|
||||
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
|
||||
min_cost_scheduler.add_seq_group(seq_group)
|
||||
|
||||
return seq_group
|
||||
vllm/zero_overhead/model_runner.py (new file, 171 lines)
@@ -0,0 +1,171 @@
|
||||
|
||||
|
||||
import torch
|
||||
import itertools
|
||||
from typing import List, Optional, Set
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.utils import async_tensor_h2d, flatten_2d_lists
|
||||
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
|
||||
from vllm.zero_overhead.sampler import get_last_sampler
|
||||
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_proposal_token_ids, get_spec_last_step, get_spec_step
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def _update_input_tokens(
|
||||
accepted_req_ids,
|
||||
accepted_req_ids_len,
|
||||
accepted_token_ids,
|
||||
accepted_token_len,
|
||||
chidren_req_ids,
|
||||
chidren_req_ids_len,
|
||||
input_tokens,
|
||||
input_tokens_len,
|
||||
input_positions,
|
||||
seq_lens,
|
||||
seq_lens_meta,
|
||||
seq_lens_tensor,
|
||||
slot_mapping,
|
||||
seq_start_loc,
|
||||
context_lens_tensor,
|
||||
):
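# Intended behaviour (sketch, inferred from the call site below): for each
# sequence in the current batch, look up its accepted tokens from the previous
# speculative scoring step and rewrite input_tokens, input_positions, seq_lens,
# slot_mapping and seq_start_loc in place on the GPU, so no host round-trip is
# needed between the scoring step and the next proposal step.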
|
||||
chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
|
||||
accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, chidren_req_ids_len))
|
||||
|
||||
for seq_id_idx in range(chidren_req_ids_len // 2):
|
||||
seq_id = chidren_req_ids_[2 * seq_id_idx]
|
||||
for i in range(accepted_req_ids_len):
|
||||
if seq_id == accepted_req_ids_[i]:
|
||||
accepted_token_ids_ = tl.load(accepted_token_ids + i * accepted_token_len + tl.arange(0, accepted_token_len))
|
||||
accepted_token_counter = 0
|
||||
for j in range(accepted_token_len):
|
||||
if accepted_token_ids_[j] == -1:
|
||||
break
|
||||
accepted_token_counter += 1
|
||||
if accepted_token_counter == accepted_token_len:
|
||||
tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
|
||||
else:
|
||||
tl.store(input_tokens + seq_id_idx * 2, 0)
|
||||
tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
|
||||
input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
|
||||
input_pos[0] = 0
|
||||
input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
|
||||
tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
|
||||
tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
|
||||
input_pos[0] = -1
|
||||
tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
|
||||
input_pos[0] = 1
|
||||
input_pos[1] = input_pos[1] + 1
|
||||
tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
|
||||
tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
|
||||
tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
|
||||
seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
|
||||
seq_start_loc_ = tl.zeros_like(seq_start_loc)
|
||||
for i in range(input_tokens_len):
|
||||
seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
|
||||
tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
|
||||
|
||||
|
||||
|
||||
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
|
||||
def __init__(self, runner, finished_requests_ids = None):
|
||||
super().__init__(runner, finished_requests_ids)
|
||||
self.req_ids = []
|
||||
|
||||
def prepare(self,
|
||||
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||
self.req_ids.clear()
|
||||
return super().prepare(finished_requests_ids)
|
||||
|
||||
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
|
||||
seq_ids = seq_group_metadata.seq_data.keys()
|
||||
n_seqs = len(seq_ids)
|
||||
seq_ids = list(seq_ids)
|
||||
for seq_idx in range(n_seqs):
|
||||
self.req_ids.append(seq_ids[seq_idx])
|
||||
return super().add_seq_group(seq_group_metadata)
|
||||
|
||||
def build(self) -> ModelInputForGPU:
|
||||
model_input = super().build()
|
||||
last_sampler = get_last_sampler()
|
||||
spec_step = get_spec_step()
|
||||
last_step = get_spec_last_step()
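# The previous step's sampled token ids are still on the GPU. Instead of
# waiting for a host copy, gather them directly into this step's input_tokens
# (replacing the placeholder ids), using index tensors that are built on the
# host and transferred with async_tensor_h2d.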
|
||||
if last_sampler is not None:
|
||||
if spec_step == SpecStepKind.KIND_DEFAULT:
|
||||
update_indices = []
|
||||
select_indices = []
|
||||
query_idx = 0
|
||||
for i, seq_id in enumerate(self.req_ids):
|
||||
for j, seq_id_ in enumerate(last_sampler.seq_ids):
|
||||
if seq_id == seq_id_:
|
||||
select_indices.append(j)
|
||||
update_indices.append(query_idx)
|
||||
break
|
||||
query_idx += model_input.query_lens[i]
|
||||
if len(select_indices) > 0 and last_sampler.sampled_token_ids_tensor is not None:
|
||||
select_indices = async_tensor_h2d(select_indices, torch.long,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
update_indices = async_tensor_h2d(update_indices, torch.long,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
|
||||
if spec_step == SpecStepKind.OTHER_PROPOSAL:
|
||||
if last_step == SpecStepKind.OTHER_PROPOSAL: # copy last sampled token ids to input tokens directly.
|
||||
update_indices = [i for i in range(len(self.req_ids))]
|
||||
update_indices = async_tensor_h2d(update_indices, torch.long,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
|
||||
if last_step == SpecStepKind.FIRST_PROPOSAL: # TODO: adjust the number of input tokens to 1 per request.
|
||||
update_indices = [i for i in range(len(self.req_ids))]
|
||||
update_indices = async_tensor_h2d(update_indices, torch.long,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
|
||||
|
||||
if spec_step == SpecStepKind.SCORE_DECODE:
|
||||
proposal_token_ids = get_proposal_token_ids()
|
||||
shape = proposal_token_ids.shape
|
||||
batch_size = shape[0]
|
||||
proposal_len = shape[1]
|
||||
update_indices = []
|
||||
for i in range(batch_size):
|
||||
for j in range(proposal_len):
|
||||
update_indices.append(i * (proposal_len + 1) + j + 1)
|
||||
|
||||
update_indices = async_tensor_h2d(update_indices, torch.long,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
model_input.input_tokens[update_indices] = proposal_token_ids.view(-1)
|
||||
if spec_step == SpecStepKind.FIRST_PROPOSAL:
|
||||
if last_step == SpecStepKind.PREFILL: # TODO: when the last step is prefill, only update the input ids for the last sequence_id.
|
||||
pass
|
||||
if last_step == SpecStepKind.SCORE_DECODE: # TODO: when the last step is score decode, fix input_ids, seq_lens, and input_positions using the accepted token ids
|
||||
accept_token_ids, accept_seq_ids = get_accepted_token_ids()
|
||||
|
||||
chidren_req_ids = async_tensor_h2d(self.req_ids, torch.long,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
grid = [1, 1, 1]
|
||||
_update_input_tokens[grid](
|
||||
accept_seq_ids, accept_seq_ids.shape[0],
|
||||
accept_token_ids, accept_token_ids.shape[1],
|
||||
chidren_req_ids, chidren_req_ids.shape[0],
|
||||
model_input.input_tokens, model_input.input_tokens.shape[0],
|
||||
model_input.input_positions,
|
||||
model_input.seq_lens,
|
||||
model_input.attn_metadata.seq_lens_tensor,
|
||||
model_input.attn_metadata.seq_lens,
|
||||
model_input.attn_metadata.slot_mapping,
|
||||
model_input.attn_metadata.seq_start_loc,
|
||||
model_input.attn_metadata.context_lens_tensor,
|
||||
)
|
||||
|
||||
|
||||
return model_input
|
||||
vllm/zero_overhead/sampler.py (new file, 500 lines)
@@ -0,0 +1,500 @@
|
||||
from importlib.util import find_spec
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.model_executor.layers.sampler import MaybeDeferredSampleResultType, MultinomialSamplesType, SampleMetadataType, \
|
||||
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
|
||||
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, \
|
||||
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs, _multinomial
|
||||
from vllm.model_executor.layers.utils import apply_penalties
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, PromptLogprobs, SampleLogprobs, SequenceOutput
|
||||
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
|
||||
import flashinfer.sampling
|
||||
# yapf: disable
|
||||
from flashinfer.sampling import (
|
||||
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
|
||||
# yapf: enable
|
||||
else:
|
||||
flashinfer_top_k_top_p_sampling = None
|
||||
|
||||
class SampleRecorder:
|
||||
def __init__(self):
|
||||
self.seq_ids: Optional[List[int]] = None
|
||||
self.sampled_token_ids_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
last_sampler = None
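# Module-level recorder for the most recent sampling step. It keeps the
# GPU-side tensor of sampled token ids plus the matching sequence ids so the
# engine can defer the device-to-host copy instead of synchronising inside
# the sampler.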
|
||||
|
||||
def get_last_sampler():
|
||||
return last_sampler
|
||||
|
||||
class ZeroOverheadSampler(Sampler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
"""
|
||||
Single-step scheduling:
|
||||
* Perform GPU-side sampling computation & compute
|
||||
GPU-side logprobs tensor
|
||||
* Pythonize sampling result & logprobs tensor
|
||||
|
||||
Multi-step scheduling:
|
||||
* Perform GPU-side sampling computation & compute
|
||||
GPU-side logprobs tensor
|
||||
* Defer Pythonization of sampling result & logprobs
|
||||
tensor
|
||||
* Encapsulate arguments required for deferred Pythonization
|
||||
in the :class:`SamplerOutput` structure
|
||||
|
||||
Args:
|
||||
logits: (num_tokens, vocab_size).
|
||||
sampling_metadata: Metadata for sampling.
|
||||
"""
|
||||
global last_sampler
|
||||
last_sampler = SampleRecorder()
|
||||
assert logits is not None
|
||||
_, vocab_size = logits.shape
|
||||
|
||||
# Prepare sampling tensors with pinned memory to avoid blocking.
|
||||
if not sampling_metadata.reuse_sampling_tensors:
|
||||
self._init_sampling_tensors(logits, sampling_metadata)
|
||||
elif self._do_penalties:
|
||||
# In this case, the sampling tensors logic depends on
|
||||
# "output_tokens" of a sequence. As a result, we cannot
|
||||
# reuse sampling tensors, since "output_tokens" changes
|
||||
# between decode runs.
|
||||
self._init_sampling_tensors(logits, sampling_metadata)
|
||||
|
||||
assert self._sampling_tensors is not None
|
||||
sampling_tensors = self._sampling_tensors
|
||||
do_penalties = self._do_penalties
|
||||
do_top_p_top_k = self._do_top_p_top_k
|
||||
do_min_p = self._do_min_p
|
||||
|
||||
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
|
||||
|
||||
# Apply presence and frequency penalties.
|
||||
if do_penalties:
|
||||
logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
|
||||
sampling_tensors.output_tokens,
|
||||
sampling_tensors.presence_penalties,
|
||||
sampling_tensors.frequency_penalties,
|
||||
sampling_tensors.repetition_penalties)
|
||||
|
||||
# Use float32 to apply temperature scaling.
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
logits = logits.to(torch.float)
|
||||
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
|
||||
|
||||
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
|
||||
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
|
||||
sampling_tensors.top_ks)
|
||||
|
||||
if do_min_p:
|
||||
logits = _apply_min_p(logits, sampling_tensors.min_ps)
|
||||
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities.
|
||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||
|
||||
# Sample the next tokens.
|
||||
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
|
||||
probs,
|
||||
logprobs,
|
||||
sampling_metadata,
|
||||
sampling_tensors,
|
||||
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
|
||||
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
|
||||
)
|
||||
|
||||
if self.include_gpu_probs_tensor:
|
||||
# Since we will defer sampler result Pythonization,
|
||||
# preserve GPU-side tensors in support of later
|
||||
# deferred pythonization of logprobs
|
||||
assert maybe_sampled_tokens_tensor is not None
|
||||
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
|
||||
else:
|
||||
# Since Pythonization has already happened, don't preserve
|
||||
# GPU-side tensors.
|
||||
on_device_tensors = None
|
||||
|
||||
# Get the logprobs query results.
|
||||
prompt_logprobs = None
|
||||
sample_logprobs = None
|
||||
if not sampling_metadata.skip_sampler_cpu_output:
|
||||
# Pythonize logprobs now (GPU -> CPU); do not defer.
|
||||
assert not isinstance(maybe_deferred_sample_results,
|
||||
SampleResultArgsType)
|
||||
prompt_logprobs, sample_logprobs = get_logprobs(
|
||||
logprobs, sampling_metadata, maybe_deferred_sample_results)
|
||||
|
||||
return _build_sampler_output(
|
||||
maybe_deferred_sample_results,
|
||||
sampling_metadata,
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
on_device_tensors=on_device_tensors,
|
||||
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
|
||||
logits=logits)
|
||||
|
||||
def _greedy_sample(
|
||||
selected_seq_groups: List[SequenceGroupToSample],
|
||||
samples: torch.Tensor,
|
||||
) -> SampleResultType:
|
||||
"""Run greedy sampling on a given samples.
|
||||
|
||||
Args:
|
||||
selected_seq_groups: A list of sequence groups batched.
|
||||
samples: (num_selected_samples,) A tensor of samples. The length of
|
||||
samples could be smaller than selected_seq_groups if
|
||||
seq_group.do_sample is False.
|
||||
Returns:
|
||||
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||
same as the length of selected_seq_groups. If the corresponding
|
||||
seq_group has do_sample=False, tuple contains ([], [])
|
||||
"""
|
||||
sample_idx = 0
|
||||
results: SampleResultType = []
|
||||
for seq_group in selected_seq_groups:
|
||||
if not seq_group.do_sample:
|
||||
results.append(([], []))
|
||||
continue
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
num_parent_seqs = len(seq_ids)
|
||||
assert num_parent_seqs == 1, (
|
||||
"Greedy sampling should have only one seq.")
|
||||
parent_ids = list(range(num_parent_seqs))
|
||||
assert num_parent_seqs == 1 # multiple sequences per sequence group are not supported
|
||||
next_token_ids = [0] # placeholder token id
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
return results
|
||||
|
||||
def _random_sample(
|
||||
selected_seq_groups: List[SequenceGroupToSample],
|
||||
random_samples: torch.Tensor,
|
||||
) -> SampleResultType:
|
||||
"""Run random sampling on a given samples.
|
||||
|
||||
Args:
|
||||
selected_seq_groups: A list of sequence groups batched.
|
||||
random_samples: (num_selected_samples,) A tensor of samples. The
|
||||
length of samples could be smaller than selected_seq_groups if
|
||||
seq_group.do_sample is False.
|
||||
Returns:
|
||||
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||
same as the length of selected_seq_groups. If the corresponding
|
||||
seq_group has do_sample=False, tuple contains ([], [])
|
||||
"""
|
||||
# Find the maximum n value of the prompt phase requests.
|
||||
sample_idx = 0
|
||||
results: SampleResultType = []
|
||||
for seq_group in selected_seq_groups:
|
||||
if not seq_group.do_sample:
|
||||
results.append(([], []))
|
||||
continue
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
sampling_params = seq_group.sampling_params
|
||||
is_prompt = seq_group.is_prompt
|
||||
num_parent_seqs = len(seq_ids)
|
||||
if is_prompt:
|
||||
# Prompt phase.
|
||||
parent_ids = [0] * sampling_params.n
|
||||
assert num_parent_seqs == 1 # multiple sequences per sequence group are not supported
|
||||
next_token_ids = [0] * sampling_params.n # placeholder token ids
|
||||
else:
|
||||
# Generation phase.
|
||||
parent_ids = list(range(num_parent_seqs))
|
||||
assert num_parent_seqs == 1 # multiple sequences per sequence group are not supported
|
||||
next_token_ids = [0] * num_parent_seqs # placeholder token ids
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
return results
|
||||
|
||||
def _sample(
|
||||
probs: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sampling_tensors: SamplingTensors,
|
||||
include_gpu_probs_tensor: bool,
|
||||
modify_greedy_probs: bool,
|
||||
) -> SampleReturnType:
|
||||
"""
|
||||
Args:
|
||||
probs: (num_query_tokens_in_batch, num_vocab)
|
||||
logprobs: (num_query_tokens_in_batch, num_vocab)
|
||||
sampling_metadata: The metadata for a batch for sampling.
|
||||
sampling_tensors: Tensors that include sampling related metadata.
|
||||
|
||||
Returns:
|
||||
(next_token_ids, parent_seq_ids) for each seq group in a batch.
|
||||
If sampling is skipped, it returns ([], [])
|
||||
sampled_token_ids_tensor: A tensor of sampled token ids.
|
||||
"""
|
||||
return _sample_with_torch(
|
||||
probs,
|
||||
logprobs,
|
||||
sampling_metadata,
|
||||
sampling_tensors,
|
||||
include_gpu_probs_tensor=include_gpu_probs_tensor,
|
||||
modify_greedy_probs=modify_greedy_probs,
|
||||
)
|
||||
|
||||
def _sample_with_torch(
|
||||
probs: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sampling_tensors: SamplingTensors,
|
||||
include_gpu_probs_tensor: bool,
|
||||
modify_greedy_probs: bool,
|
||||
) -> SampleReturnType:
|
||||
'''Torch-oriented _sample() implementation.
|
||||
|
||||
Single-step scheduling:
|
||||
* Perform GPU-side sampling computation
|
||||
* Immediately Pythonize sampling result
|
||||
|
||||
Multi-step scheduling:
|
||||
* Perform GPU-side sampling computation
|
||||
* Defer Pythonization & preserve GPU-side
|
||||
tensors required for Pythonization
|
||||
'''
|
||||
|
||||
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
|
||||
t: []
|
||||
for t in SamplingType
|
||||
}
|
||||
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
sampling_params = seq_group.sampling_params
|
||||
sampling_type = sampling_params.sampling_type
|
||||
categorized_seq_group_ids[sampling_type].append(i)
|
||||
sample_results_dict: SampleResultsDictType = {}
|
||||
sample_metadata: SampleMetadataType = {}
|
||||
multinomial_samples: MultinomialSamplesType = {}
|
||||
greedy_samples: Optional[torch.Tensor] = None
|
||||
|
||||
# Create output tensor for sampled token ids.
|
||||
if include_gpu_probs_tensor:
|
||||
sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
|
||||
VLLM_INVALID_TOKEN_ID,
|
||||
dtype=torch.long,
|
||||
device=logprobs.device)
|
||||
else:
|
||||
sampled_token_ids_tensor = None
|
||||
|
||||
# Counterintuitively, having two loops here is actually faster.
|
||||
# The first loop can run without waiting on GPU<->CPU sync.
|
||||
for sampling_type in SamplingType:
|
||||
sample_indices = categorized_sample_indices[sampling_type]
|
||||
num_tokens = len(sample_indices)
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
|
||||
seq_group_id = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
|
||||
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
|
||||
long_sample_indices = sample_indices.long()
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
greedy_samples = torch.argmax(logprobs[long_sample_indices],
|
||||
dim=-1)
|
||||
|
||||
last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
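# Keep the GPU-side sampled ids in the recorder; the Pythonized results below
# only contain placeholder ids and are fixed up later by the engine once the
# asynchronous copy has finished.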
|
||||
|
||||
if sampled_token_ids_tensor is not None:
|
||||
# Store sampled tokens in output tensor.
|
||||
sampled_token_ids_tensor[
|
||||
long_sample_indices] = greedy_samples.unsqueeze(-1)
|
||||
|
||||
if modify_greedy_probs:
|
||||
# If required, modify the probabilities such that sampling from
|
||||
# the modified distribution would always sample the argmax
|
||||
# token id.
|
||||
_modify_greedy_probs_inplace(logprobs, probs,
|
||||
long_sample_indices,
|
||||
greedy_samples)
|
||||
|
||||
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
||||
max_n_in_batch = 1
|
||||
for seq_group in seq_groups:
|
||||
if seq_group.is_prompt:
|
||||
sampling_params = seq_group.sampling_params
|
||||
max_n_in_batch = max(max_n_in_batch, sampling_params.n)
|
||||
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
|
||||
seq_groups)
|
||||
|
||||
if flashinfer_top_k_top_p_sampling is not None:
|
||||
multinomial_samples[
|
||||
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
|
||||
probs[long_sample_indices],
|
||||
sampling_tensors.top_ks[long_sample_indices],
|
||||
sampling_tensors.top_ps[long_sample_indices],
|
||||
max_n_in_batch,
|
||||
seq_groups_arg,
|
||||
)
|
||||
else:
|
||||
multinomial_samples[sampling_type] = _multinomial(
|
||||
probs[long_sample_indices],
|
||||
max_n_in_batch,
|
||||
seq_groups=seq_groups_arg)
|
||||
|
||||
last_sampler.sampled_token_ids_tensor = \
|
||||
multinomial_samples[sampling_type].to(torch.long)
|
||||
|
||||
if sampled_token_ids_tensor is not None:
|
||||
# Store sampled tokens in output tensor.
|
||||
sampled_token_ids_tensor[long_sample_indices] = \
|
||||
multinomial_samples[sampling_type].to(torch.long)
|
||||
|
||||
# Encapsulate arguments for computing Pythonized sampler
|
||||
# results, whether deferred or otherwise.
|
||||
maybe_deferred_args = SampleResultArgsType(
|
||||
sampling_metadata=sampling_metadata,
|
||||
sample_metadata=sample_metadata,
|
||||
multinomial_samples=multinomial_samples,
|
||||
greedy_samples=greedy_samples,
|
||||
sample_results_dict=sample_results_dict)
|
||||
|
||||
if not sampling_metadata.skip_sampler_cpu_output:
|
||||
# GPU<->CPU sync happens here.
|
||||
# This also converts the sampler output to a Python object.
|
||||
# Return Pythonized sampler result & sampled token ids
|
||||
return get_pythonized_sample_results(
|
||||
maybe_deferred_args), sampled_token_ids_tensor
|
||||
else:
|
||||
# Defer sampler result Pythonization; return deferred
|
||||
# Pythonization args & sampled token ids
|
||||
return (
|
||||
maybe_deferred_args,
|
||||
sampled_token_ids_tensor,
|
||||
)
|
||||
|
||||
|
||||
def get_pythonized_sample_results(
|
||||
sample_result_args: SampleResultArgsType) -> SampleResultType:
|
||||
'''This function consumes GPU-side sampler results and computes
|
||||
Pythonized CPU-side sampler results (GPU -> CPU sync.)
|
||||
|
||||
Single-step scheduling: this function is invoked at sampling-time
|
||||
for immediate Pythonization.
|
||||
|
||||
Multi-step scheduling: Pythonization is deferred until after multiple
|
||||
GPU-side steps have been completed.
|
||||
|
||||
Args:
|
||||
sample_result_args: GPU-side inputs to the Pythonization process
|
||||
|
||||
Returns:
|
||||
Pythonized sampler results
|
||||
'''
|
||||
|
||||
(
|
||||
sample_metadata,
|
||||
sampling_metadata,
|
||||
greedy_samples,
|
||||
multinomial_samples,
|
||||
sample_results_dict,
|
||||
) = (
|
||||
sample_result_args.sample_metadata,
|
||||
sample_result_args.sampling_metadata,
|
||||
sample_result_args.greedy_samples,
|
||||
sample_result_args.multinomial_samples,
|
||||
sample_result_args.sample_results_dict,
|
||||
)
|
||||
|
||||
for sampling_type in SamplingType:
|
||||
if sampling_type not in sample_metadata:
|
||||
continue
|
||||
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
sample_results = _greedy_sample(seq_groups, greedy_samples)
|
||||
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
||||
sample_results = _random_sample(seq_groups,
|
||||
multinomial_samples[sampling_type])
|
||||
sample_results_dict.update(zip(seq_group_id, sample_results))
|
||||
|
||||
return [
|
||||
sample_results_dict.get(i, ([], []))
|
||||
for i in range(len(sampling_metadata.seq_groups))
|
||||
]
|
||||
|
||||
def _build_sampler_output(
|
||||
maybe_deferred_sample_results: MaybeDeferredSampleResultType,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
|
||||
sample_logprobs: Optional[List[SampleLogprobs]],
|
||||
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
|
||||
torch.Tensor]],
|
||||
skip_sampler_cpu_output: bool = False,
|
||||
logits: Optional[torch.Tensor] = None
|
||||
) -> SamplerOutput:
|
||||
"""Construct Python objects with the output of sampling.
|
||||
|
||||
Args:
|
||||
on_device_tensors: Tuple containing on-device tensors with the
|
||||
probabilities used in sampling and the sampled token ids. This
|
||||
allows post-processing without copies to CPU/serialization, e.g. in
|
||||
speculative decoding rejection sampling.
|
||||
"""
|
||||
sampler_output: List[CompletionSequenceGroupOutput] = []
|
||||
|
||||
last_sampler.seq_ids = []
|
||||
if skip_sampler_cpu_output:
|
||||
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
|
||||
deferred_sample_results_args = maybe_deferred_sample_results
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert sample_logprobs is not None
|
||||
assert not isinstance(maybe_deferred_sample_results,
|
||||
SampleResultArgsType)
|
||||
assert len(sampling_metadata.seq_groups) \
|
||||
== len(maybe_deferred_sample_results) \
|
||||
== len(prompt_logprobs) \
|
||||
== len(sample_logprobs)
|
||||
deferred_sample_results_args = None
|
||||
|
||||
for (seq_group, sample_result, group_prompt_logprobs,
|
||||
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
|
||||
maybe_deferred_sample_results,
|
||||
prompt_logprobs, sample_logprobs):
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_token_ids, parent_ids = sample_result
|
||||
seq_outputs: List[SequenceOutput] = []
|
||||
for parent_id, next_token_id, logprobs in zip(
|
||||
parent_ids, next_token_ids, group_sample_logprobs):
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id,
|
||||
logprobs))
|
||||
sampler_output.append(
|
||||
CompletionSequenceGroupOutput(seq_outputs,
|
||||
group_prompt_logprobs))
|
||||
if len(seq_outputs) > 0:
|
||||
last_sampler.seq_ids.append(seq_outputs[0].parent_seq_id)
|
||||
|
||||
# If not specified, store None values in SamplerOutput.
|
||||
if on_device_tensors is not None:
|
||||
(sampled_token_probs, logprobs_tensor,
|
||||
sampled_token_ids) = on_device_tensors
|
||||
else:
|
||||
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
|
||||
None)
|
||||
|
||||
return SamplerOutput(
|
||||
outputs=sampler_output,
|
||||
sampled_token_probs=sampled_token_probs,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=logprobs_tensor,
|
||||
deferred_sample_results_args=deferred_sample_results_args,
|
||||
logits=logits)
|
||||
vllm/zero_overhead/sequence.py (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
|
||||
from typing import Union
|
||||
from vllm.sequence import Sequence
|
||||
from typing import Sequence as GenericSequence
|
||||
|
||||
|
||||
class ZeroOverheadSequence(Sequence):
|
||||
def __init__(self, seq_id, inputs, block_size, eos_token_id = None, lora_request = None, prompt_adapter_request = None):
|
||||
super().__init__(seq_id, inputs, block_size, eos_token_id, lora_request, prompt_adapter_request)
|
||||
self.effective_output_len : int = 0
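# Number of output tokens whose real ids have been filled in. Tokens beyond
# this count are placeholders appended optimistically and are either fixed up
# by fix_last_token_id() or dropped by remove_last_place_holder().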
|
||||
|
||||
def fix_last_token_id(self, token_id: int) -> None:
|
||||
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
|
||||
if effect_offset < 0:
|
||||
self.data._output_token_ids[effect_offset] = token_id
|
||||
if len(self.data._new_appended_tokens) >= effect_offset * -1:
|
||||
self.data._new_appended_tokens[effect_offset] = token_id
|
||||
self.data._cached_all_token_ids[effect_offset] = token_id
|
||||
self.effective_output_len += 1
|
||||
|
||||
def remove_last_place_holder(self, count):
|
||||
self.data._output_token_ids = self.data._output_token_ids[:-1 * count]
|
||||
self.data._new_appended_tokens = self.data._new_appended_tokens[:-1 * count]
|
||||
self.data._cached_all_token_ids = self.data._cached_all_token_ids[:-1 * count]
|
||||
self.data._num_computed_tokens -= count
|
||||
|
||||
def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
|
||||
return self.data.output_token_ids[:self.effective_output_len]
|
||||
|
||||
def zero_overhead_get_output_len(self) -> int:
|
||||
return self.effective_output_len
|
||||
|
||||
def zero_overhead_get_last_token_id(self) -> int:
|
||||
if self.effective_output_len == 0:
|
||||
return self.data._prompt_token_ids[-1]
|
||||
return self.data._output_token_ids[self.effective_output_len - 1]
|
||||
|
||||
def zero_overhead_get_len(self) -> int:
|
||||
return self.effective_output_len + len(self.data._prompt_token_ids)
|
||||
|
||||
def get_output_token_ids_to_return(
|
||||
self, delta: bool) -> Union[GenericSequence[int], int]:
|
||||
"""If delta is True, only new tokens since the last call to
|
||||
this method are returned"""
|
||||
if not delta:
|
||||
return self.zero_overhead_get_output_token_ids()
|
||||
|
||||
output_len = self.zero_overhead_get_output_len()
|
||||
|
||||
# Get the number of new tokens
|
||||
num_new_tokens = output_len - self._last_output_token_ids_offset
|
||||
self._last_output_token_ids_offset = output_len
|
||||
|
||||
# Return new tokens
|
||||
if num_new_tokens == 1:
|
||||
# Optimization for single decode token case
|
||||
# (which is what we have most of the time)
|
||||
return self.data._cached_all_token_ids[self.effective_output_len - 1]
|
||||
|
||||
if num_new_tokens == 0:
|
||||
return []
|
||||
|
||||
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
|
||||
return self.data._cached_all_token_ids[-num_new_tokens : effect_offset]
|
||||
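The effective-length bookkeeping used by ZeroOverheadSequence can be pictured with a small standalone sketch; ToyZeroOverheadOutput is a hypothetical stand-in (plain Python, not part of vLLM) that only mirrors the idea behind fix_last_token_id and effective_output_len:

class ToyZeroOverheadOutput:
    """Toy model: placeholder tokens are appended eagerly and fixed later."""

    def __init__(self) -> None:
        self.output_token_ids: list = []
        self.effective_output_len = 0

    def append_placeholder(self, placeholder: int = 0) -> None:
        # Reserve a slot before the real token id is known.
        self.output_token_ids.append(placeholder)

    def fix_last_token_id(self, token_id: int) -> None:
        # Overwrite the oldest unfixed placeholder, then advance the watermark.
        offset = self.effective_output_len - len(self.output_token_ids)
        if offset < 0:
            self.output_token_ids[offset] = token_id
        self.effective_output_len += 1

    def effective_output(self) -> list:
        # Only the validated prefix is visible to callers.
        return self.output_token_ids[:self.effective_output_len]


out = ToyZeroOverheadOutput()
out.append_placeholder()
out.append_placeholder()
out.fix_last_token_id(11)
print(out.effective_output())   # [11]; the second slot is still a placeholder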
141
vllm/zero_overhead/spec_decode/batch_expansion.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from array import array
|
||||
import numpy as np
|
||||
from itertools import chain, count
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
ExecuteModelRequest, SequenceData,
|
||||
SequenceGroupMetadata, get_all_seq_ids)
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
|
||||
from vllm.utils import async_tensor_h2d
|
||||
from vllm.zero_overhead.utils import get_proposal_lens_list, record_proposal_token_ids
|
||||
|
||||
SeqId = int
|
||||
TargetSeqId = int
|
||||
TokenId = int
|
||||
|
||||
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
|
||||
|
||||
|
||||
class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
|
||||
|
||||
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
|
||||
def score_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> SpeculativeScores:
|
||||
"""Score the proposed tokens via the scorer model.
|
||||
|
||||
This converts each input sequence to a set of k+1 target sequences. The
|
||||
target sequences have the unique continuations to be scored and a
|
||||
unique sequence ID that is different from all input sequence ids.
|
||||
|
||||
If a speculative sequence length would exceed the max model length, then
|
||||
no speculation is produced for that sequence.
|
||||
|
||||
Args:
|
||||
execute_model_req: The execution request.
|
||||
proposals: The speculative proposals to score.
|
||||
Returns:
|
||||
SpeculativeScores: The scores of each speculative token, along with
|
||||
which sequences were ignored during scoring.
|
||||
"""
|
||||
|
||||
proposal_lens_list = get_proposal_lens_list()
|
||||
record_proposal_token_ids(proposals.proposal_token_ids)
|
||||
proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape, dtype=int).tolist() # placeholder tokens
|
||||
|
||||
# Filter the list to ignore invalid proposals.
|
||||
proposal_token_ids_list_without_skips = [
|
||||
proposals for proposals in proposal_token_ids_list
|
||||
if VLLM_INVALID_TOKEN_ID not in proposals
|
||||
]
|
||||
|
||||
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||
num_scoring_tokens) = self._expand_batch(
|
||||
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
|
||||
proposal_token_ids_list=proposal_token_ids_list_without_skips,
|
||||
proposal_lens_list=proposal_lens_list,
|
||||
)
|
||||
|
||||
target_sampler_output = self._scorer_worker.execute_model(
|
||||
execute_model_req=execute_model_req.clone(
|
||||
seq_group_metadata_list=target_seq_group_metadata_list))
|
||||
assert len(target_sampler_output) == 1, "expected single-step output"
|
||||
target_sampler_output = target_sampler_output[0]
|
||||
if not non_spec_indices:
|
||||
# All sequence groups in batch have spec decoding enabled
|
||||
return self._contract_batch_all_spec(
|
||||
target_sampler_output=target_sampler_output,
|
||||
proposals=proposals,
|
||||
)
|
||||
else:
|
||||
# Batch has a mix of spec decode enabled and disabled seq groups
|
||||
return self._contract_batch(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
target_sampler_output=target_sampler_output,
|
||||
proposals=proposals,
|
||||
num_scoring_tokens=num_scoring_tokens,
|
||||
non_spec_indices=non_spec_indices,
|
||||
spec_indices=spec_indices,
|
||||
k=execute_model_req.num_lookahead_slots,
|
||||
)
|
||||
|
||||
def _contract_non_speculative(
|
||||
self, scores: SpeculativeScores,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
|
||||
has_prompt_log: bool) -> SpeculativeScores:
|
||||
"""
|
||||
Augment input `scores` with non-speculative requests outputs.
|
||||
This includes decode requests with speculation turned off, as well
|
||||
as prefill requests when `enable_chunked_prefill` is set.
|
||||
For the latter, prefills are further separated into terminal and
|
||||
non-terminal chunks (from which no token is sampled).
|
||||
"""
|
||||
if not non_spec_indices:
|
||||
return scores
|
||||
|
||||
if has_prompt_log:
|
||||
# When prompt_logprobs is enabled, prefills yield output token
|
||||
# (and respective prob) in the last entry (prompt|out):
|
||||
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
|
||||
# With chunked prefill, non-terminal chunks have -1 on each
|
||||
# position: they're still picked, but they're discarded later.
|
||||
seq_meta = seq_group_metadata_list
|
||||
nospec_sizes = torch.tensor([
|
||||
seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
|
||||
for i in non_spec_indices
|
||||
])
|
||||
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
|
||||
else:
|
||||
# In this case only sampled tokens are returned, select all.
|
||||
nospec_sampled_token_idxs = list(
|
||||
range(len(non_spec_outputs.token_ids)))
|
||||
|
||||
nospec_sampled_token_idxs = async_tensor_h2d(nospec_sampled_token_idxs, torch.int32,
|
||||
self._device,
|
||||
True)
|
||||
non_spec_indices = async_tensor_h2d(non_spec_indices, torch.int32,
|
||||
self._device,
|
||||
True)
|
||||
|
||||
scores.token_ids[non_spec_indices, :1] = \
|
||||
non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
|
||||
scores.probs[non_spec_indices, :1, :] = \
|
||||
non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
|
||||
scores.logprobs[non_spec_indices, :1, :] = \
|
||||
non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
|
||||
if scores.hidden_states is not None:
|
||||
assert non_spec_outputs.hidden_states is not None
|
||||
scores.hidden_states[non_spec_indices, :1, :] = \
|
||||
non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
|
||||
return scores
|
||||
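The index arithmetic in _contract_non_speculative when prompt logprobs are enabled can be checked in isolation; the chunk sizes below are made-up example values:

import torch

# Each prefill contributes token_chunk_size rows and each decode one row;
# the sampled token sits at the last row of its chunk.
chunk_sizes = torch.tensor([4, 3, 1, 1])             # two prefills, two decodes
sampled_token_idxs = torch.cumsum(chunk_sizes, 0).add_(-1)
print(sampled_token_idxs)                             # tensor([3, 6, 7, 8])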
137
vllm/zero_overhead/spec_decode/muti_step_worker.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import copy
|
||||
import weakref
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.utils import async_tensor_h2d
|
||||
from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer
|
||||
from vllm.zero_overhead.utils import SpecStepKind, set_spec_step
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.worker.worker_base import DelegateWorkerBase
|
||||
|
||||
class ZeroOverheadMultiStepWorker(MultiStepWorker):
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.worker.init_device()
|
||||
self._proposer = ZeroOverheadTop1Proposer(
|
||||
weakref.proxy(self), # type: ignore[arg-type]
|
||||
self.device,
|
||||
self.vocab_size,
|
||||
max_proposal_len=self.max_model_len,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass sample_len times. Returns the list of
|
||||
sampler outputs, one per model forward pass, along with an indicator of
|
||||
whether the torch tensors in the sampler output need to be transposed in the later
|
||||
sampler_output_to_torch logic.
|
||||
|
||||
For multi step worker, this indicator shall be True.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
# Expand the batch for sequences with a bonus token.
|
||||
# Perform a forward pass on the expanded batch and filter the
|
||||
# response to retain only the original sequences' responses.
|
||||
expanded_request, indices_of_seq_with_bonus_tokens =\
|
||||
self._expand_execute_model_request(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
# Run model sample_len times.
|
||||
model_outputs: List[SamplerOutput] = []
|
||||
if current_platform.is_cuda_alike() and isinstance(
|
||||
self.model_runner, TP1DraftModelRunner
|
||||
) and self.model_runner.supports_gpu_multi_step(expanded_request):
|
||||
# Here we run the draft_model_runner with multi-step prepare
|
||||
# on the GPU directly
|
||||
expanded_request.num_steps = sample_len
|
||||
self.model_runner.set_indices_of_seq_with_bonus_tokens(
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs = self.execute_model(
|
||||
execute_model_req=expanded_request)
|
||||
else:
|
||||
# Here we run multi-step directly, with every step prepared
|
||||
# on the CPU.
|
||||
# TODO: Remove this branch once DraftModelRunner supports TP>1
|
||||
# and other restrictions that are part of DraftModelRunner's
|
||||
# supports_gpu_multi_step(..)
|
||||
|
||||
set_spec_step(SpecStepKind.FIRST_PROPOSAL)
|
||||
for _ in range(sample_len):
|
||||
model_output: List[SamplerOutput] = self.worker.execute_model(
|
||||
execute_model_req=expanded_request)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
model_output = model_output[0]
|
||||
set_spec_step(SpecStepKind.OTHER_PROPOSAL)
|
||||
self._append_new_tokens(
|
||||
model_output, expanded_request.seq_group_metadata_list,
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs.append(model_output)
|
||||
set_spec_step(SpecStepKind.SCORE_DECODE)
|
||||
|
||||
filtered_model_outputs = self._filter_model_output_zero_overhead(
|
||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
||||
|
||||
return filtered_model_outputs, True
|
||||
|
||||
def _filter_model_output_zero_overhead(self,
|
||||
expanded_batch_outputs: List[SamplerOutput],
|
||||
output_indices_to_retain: List[int]) -> List[SamplerOutput]:
|
||||
"""
|
||||
Filters the model output to include only the specified sequence
|
||||
outputs. This method contracts the expanded batch output from the
|
||||
model to retain the outputs of only those sequences indicated by the
|
||||
provided indices.
|
||||
|
||||
Args:
|
||||
expanded_batch_output (List[SamplerOutput]): The expanded output
|
||||
batch from the model.
|
||||
output_indices_to_retain (torch.Tensor): Indices of the model
|
||||
outputs to retain.
|
||||
|
||||
Returns:
|
||||
List[SamplerOutput]: A list containing the filtered model
|
||||
outputs for the specified indices.
|
||||
"""
|
||||
|
||||
indices_of_seq_with_bonus_tokens = async_tensor_h2d(output_indices_to_retain, torch.int32,
|
||||
self.device,
|
||||
True)
|
||||
|
||||
return [
|
||||
SamplerOutput(
|
||||
outputs=[
|
||||
expanded_batch_output.outputs[i]
|
||||
for i in output_indices_to_retain
|
||||
] if len(expanded_batch_output.outputs) > 0 else [],
|
||||
sampled_token_probs=(
|
||||
expanded_batch_output.
|
||||
sampled_token_probs[indices_of_seq_with_bonus_tokens]
|
||||
if expanded_batch_output.sampled_token_probs is not None
|
||||
else None),
|
||||
logprobs=(
|
||||
expanded_batch_output.logprobs[indices_of_seq_with_bonus_tokens]
|
||||
if expanded_batch_output.logprobs is not None else None),
|
||||
sampled_token_ids=(expanded_batch_output.
|
||||
sampled_token_ids[indices_of_seq_with_bonus_tokens]
|
||||
if expanded_batch_output.sampled_token_ids
|
||||
is not None else None))
|
||||
for expanded_batch_output in expanded_batch_outputs
|
||||
]
|
||||
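The contraction in _filter_model_output_zero_overhead boils down to index selection over the expanded batch; a standalone torch sketch with made-up shapes:

import torch

expanded_probs = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # 6 expanded rows
indices_to_retain = torch.tensor([0, 2, 3, 5])                        # original sequences only
print(expanded_probs[indices_to_retain].shape)                        # torch.Size([4, 2])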
565
vllm/zero_overhead/spec_decode/spec_decode_worker.py
Normal file
@@ -0,0 +1,565 @@
|
||||
|
||||
import os
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
|
||||
from vllm.distributed.communication_op import (broadcast_tensor_dict,
|
||||
get_tp_group,
|
||||
tensor_model_parallel_gather)
|
||||
from vllm.distributed.parallel_state import model_parallel_is_initialized
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
|
||||
CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
HiddenStates, SequenceGroupMetadata,
|
||||
get_all_seq_ids_and_request_ids, Logits)
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, prepare_prefill_hidden_states
|
||||
from vllm.zero_overhead.spec_decode.batch_expansion import ZeroOverheadBatchExpansionTop1Scorer
|
||||
from vllm.zero_overhead.utils import SpecStepKind, record_accepted_token_ids, set_spec_step
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.medusa_worker import MedusaWorker
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
|
||||
from vllm.spec_decode.mqa_scorer import MQAScorer
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
|
||||
from vllm.spec_decode.target_model_runner import TargetModelRunner
|
||||
from vllm.spec_decode.util import (Timer, create_logprobs_output,
|
||||
create_sequence_group_output,
|
||||
get_all_num_logprobs,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.utils import async_tensor_h2d, resolve_obj_by_qualname
|
||||
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
|
||||
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize both scorer and proposer models.
|
||||
"""
|
||||
# The scorer worker model is initialized first in case the proposer
|
||||
# model has a smaller TP degree than the target worker.
|
||||
self.scorer_worker.init_device()
|
||||
self.proposer_worker.init_device()
|
||||
|
||||
# NOTE(cade): load_model is not part of the WorkerBase interface.
|
||||
self.scorer_worker.load_model()
|
||||
self.proposer_worker.load_model()
|
||||
|
||||
if self._enable_lm_head_weight_load:
|
||||
# NOTE(Shangming): gather lm_head weight when tp enabled
|
||||
target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
|
||||
self.scorer_worker.model_runner.model_runner.model.lm_head.\
|
||||
weight.data,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.proposer_worker.maybe_load_lm_head_weight(
|
||||
target_lm_head_weight)
|
||||
|
||||
self._metrics.init_tensors(self.rank, device_type=self.device)
|
||||
if model_parallel_is_initialized():
|
||||
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
|
||||
device_type=self.device)
|
||||
else:
|
||||
self.spec_decode_sampler.init_tensors(self.rank,
|
||||
device_type=self.device)
|
||||
|
||||
scorer_cls: Type[SpeculativeScorer]
|
||||
if self.disable_mqa_scorer:
|
||||
scorer_cls = ZeroOverheadBatchExpansionTop1Scorer
|
||||
logger.info("[Speculative Decoding] Use batch "
|
||||
"expansion for scoring proposals.")
|
||||
else:
|
||||
scorer_cls = MQAScorer
|
||||
logger.info(
|
||||
"[Speculative Decoding] Use MQA scorer for scoring proposals.")
|
||||
|
||||
if not self.tree_decoding:
|
||||
self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
|
||||
device=self.device,
|
||||
vocab_size=self._vocab_size)
|
||||
else:
|
||||
self.scorer = BatchExpansionTreeStyleScorer(
|
||||
scorer_worker=self.scorer_worker,
|
||||
device=self.device,
|
||||
vocab_size=self._vocab_size)
|
||||
|
||||
self._configure_model_sampler_for_spec_decode()
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_no_spec")
|
||||
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
|
||||
skip_proposer: bool) -> List[SamplerOutput]:
|
||||
"""Run a single generation step without any speculation. The input is
|
||||
sent to the proposer and scorer model so that the KV cache is consistent
|
||||
between the two. When skip_proposer is True, the proposer model is
|
||||
not called, meaning that the proposer's KV cache is not updated for these
|
||||
requests, so they cannot use spec decode for the remaining decode steps.
|
||||
"""
|
||||
if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
|
||||
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
|
||||
self.kvcache_slot_to_be_moved = None
|
||||
|
||||
set_spec_step(SpecStepKind.PREFILL)
|
||||
sampler_output = self.scorer_worker.execute_model(execute_model_req)
|
||||
|
||||
assert len(sampler_output) == 1
|
||||
sampler_output = sampler_output[0]
|
||||
|
||||
# Store hidden states from target model execution, BxD.
|
||||
hidden_states = sampler_output.hidden_states
|
||||
if hidden_states is not None:
|
||||
# Only decodes and prefill terminal chunks need a hidden state.
|
||||
seq_group_meta_with_hidden = [
|
||||
sg for sg in execute_model_req.seq_group_metadata_list
|
||||
if sg.do_sample
|
||||
]
|
||||
if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
|
||||
# Drop hidden_states with no prediction (eg non-terminal chunks)
|
||||
hidden_states = hidden_states[
|
||||
torch.where(sampler_output.sampled_token_ids -
|
||||
VLLM_INVALID_TOKEN_ID)[0]]
|
||||
# if not skip_proposer:
|
||||
# if self.previous_hidden_states is None and len(
|
||||
# seq_group_meta_with_hidden):
|
||||
# self.previous_hidden_states = HiddenStates(
|
||||
# hidden_states, seq_group_meta_with_hidden)
|
||||
# elif self.previous_hidden_states and len(
|
||||
# seq_group_meta_with_hidden):
|
||||
# self.previous_hidden_states.update(hidden_states,
|
||||
# seq_group_meta_with_hidden)
|
||||
if self.previous_hidden_states is None and len(
|
||||
seq_group_meta_with_hidden):
|
||||
self.previous_hidden_states = HiddenStates(
|
||||
hidden_states, seq_group_meta_with_hidden)
|
||||
elif self.previous_hidden_states and len(
|
||||
seq_group_meta_with_hidden):
|
||||
self.previous_hidden_states.update(hidden_states,
|
||||
seq_group_meta_with_hidden)
|
||||
|
||||
# Store logits from target model execution.
|
||||
if self.tree_decoding:
|
||||
logits = sampler_output.logits
|
||||
if logits is not None:
|
||||
if self.previous_logits is None:
|
||||
self.previous_logits = Logits(
|
||||
logits, execute_model_req.seq_group_metadata_list)
|
||||
else:
|
||||
self.previous_logits.update(
|
||||
logits, execute_model_req.seq_group_metadata_list)
|
||||
|
||||
if not skip_proposer:
|
||||
# We prepare the prefill hidden states here so that there is no
|
||||
# additional complexity in worker for spec_decode vs non_spec_decode
|
||||
# flow and execute_model doesn't need additional modifications.
|
||||
execute_model_req.previous_hidden_states = \
|
||||
prepare_prefill_hidden_states(
|
||||
sampler_output.prefill_hidden_states)
|
||||
for i in range(self._num_spec_prefill_steps):
|
||||
execute_model_req.spec_step_idx = i
|
||||
self.proposer_worker.execute_model(execute_model_req)
|
||||
|
||||
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
|
||||
execute_model_req=execute_model_req, sampler_output=sampler_output)
|
||||
if self._disable_logprobs else
|
||||
[sampler_output])
|
||||
|
||||
# Clear device tensors from sampler output. This reduces communication
|
||||
# overhead when the engine runs in a different process than the workers.
|
||||
sampler_output.sampled_token_probs = None
|
||||
sampler_output.sampled_token_ids = None
|
||||
sampler_output.logprobs = None
|
||||
return sampler_output_to_return
|
||||
|
||||
@nvtx_range("spec_decode_worker._verify_tokens")
|
||||
def _verify_tokens(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_scores: SpeculativeScores,
|
||||
proposals: SpeculativeProposals,
|
||||
max_proposal_len: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
|
||||
"""Determine which speculative tokens are accepted using the
|
||||
probabilities of each token according to the proposer and scorer models.
|
||||
|
||||
Returns a tuple of Tensors, one for the accepted token ids and one for
|
||||
the logprobs according to the scoring model.
|
||||
"""
|
||||
proposal_lens_list = proposals.proposal_lens
|
||||
|
||||
# vLLM currently only supports proposal lens equal to zero or the batch
|
||||
# proposal len. This adds some complexity (splitting the batch into spec
|
||||
# and non spec sequences) and should be removed in the future. It can be
|
||||
# done by supporting per-sequence proposal lens.
|
||||
(_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
|
||||
seq_group_metadata_list, proposal_lens_list)
|
||||
original_indices = spec_indices + non_spec_indices
|
||||
|
||||
# Get probabilities of target model, including bonus tokens.
|
||||
if non_spec_indices:
|
||||
proposal_verifier_probs = proposal_scores.probs[spec_indices]
|
||||
else:
|
||||
proposal_verifier_probs = proposal_scores.probs
|
||||
|
||||
if self.tree_decoding:
|
||||
retrieve_indices = proposals.retrieve_indices
|
||||
proposal_verifier_probs = proposal_verifier_probs[:, retrieve_indices]
|
||||
|
||||
# Get non-speculative sampled tokens from target model.
|
||||
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
|
||||
|
||||
# Get bonus tokens from target model.
|
||||
bonus_token_ids = proposal_scores.token_ids[:, -1:]
|
||||
if non_spec_indices:
|
||||
bonus_token_ids = bonus_token_ids[spec_indices, :]
|
||||
|
||||
# Get probabilities according to proposal method.
|
||||
proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
|
||||
if proposal_probs is not None and non_spec_indices:
|
||||
proposal_probs = proposal_probs[spec_indices]
|
||||
|
||||
# Get proposed tokens.
|
||||
proposal_token_ids = proposals.proposal_token_ids
|
||||
if non_spec_indices:
|
||||
proposal_token_ids = proposal_token_ids[spec_indices]
|
||||
|
||||
# Get tree buffers.
|
||||
cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
|
||||
if cart_candidates is not None and non_spec_indices:
|
||||
cart_candidates = cart_candidates[spec_indices]
|
||||
|
||||
# Sampler arguments
|
||||
sampler_extra_kwargs: Dict[str, Any] = {}
|
||||
if self.generators and isinstance(self.spec_decode_sampler,
|
||||
SpecDecodeStochasticBaseSampler):
|
||||
sampler_extra_kwargs["seeded_seqs"] = {
|
||||
idx: self.generators[sgm.request_id]
|
||||
for idx, sgm in enumerate(seq_group_metadata_list)
|
||||
if sgm.sampling_params.seed is not None
|
||||
}
|
||||
|
||||
if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
|
||||
sampler_extra_kwargs["cart_candidates"] = cart_candidates
|
||||
sampler_extra_kwargs["best_candidates"] = []
|
||||
sampler_extra_kwargs["accept_lengths"] = []
|
||||
|
||||
first_step_flags = []
|
||||
for i, sgm in enumerate(seq_group_metadata_list):
|
||||
seq = next(iter(sgm.seq_data.values()))
|
||||
first_step_flags.append(True if seq.get_first_step_flag() else False)
|
||||
|
||||
sampler_extra_kwargs["first_step_flags"] = first_step_flags
|
||||
|
||||
accepted_token_ids = self.spec_decode_sampler(
|
||||
target_with_bonus_probs=proposal_verifier_probs,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
draft_probs=proposal_probs,
|
||||
draft_token_ids=proposal_token_ids,
|
||||
**sampler_extra_kwargs,
|
||||
)
|
||||
# Append output tokens from non-speculative sequences to
|
||||
# the accepted token ids tensor.
|
||||
if not self.tree_decoding:
|
||||
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
|
||||
1).clone()
|
||||
else:
|
||||
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len).clone()
|
||||
|
||||
non_spec_token_ids[:, 1:] = -1
|
||||
accepted_token_ids = torch.cat(
|
||||
[accepted_token_ids, non_spec_token_ids])
|
||||
logprobs = proposal_scores.logprobs
|
||||
# Rearrange so that results are in the order of the original seq group
|
||||
# metadata.
|
||||
original_indices = async_tensor_h2d(original_indices, torch.int32,
|
||||
self.device,
|
||||
True)
|
||||
accepted_token_ids[original_indices] = accepted_token_ids.clone()
|
||||
|
||||
# B x K+1 x D
|
||||
hidden_states = proposal_scores.hidden_states
|
||||
|
||||
select_indices = None
|
||||
accept_lengths = None
|
||||
|
||||
select_indices_list = []
|
||||
|
||||
if cart_candidates is None:
|
||||
if hidden_states is not None:
|
||||
# Only get terminal hidden states for next step
|
||||
terminal_metadata = [
|
||||
sg for sg in seq_group_metadata_list if sg.do_sample
|
||||
]
|
||||
# Contract hidden states based on accepted tokens
|
||||
hs_size = hidden_states.shape[-1]
|
||||
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
|
||||
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b
|
||||
# Drop non-terminal prefill chunks hidden states.
|
||||
hidden_states = hidden_states[accepted_index !=
|
||||
VLLM_INVALID_TOKEN_ID]
|
||||
accepted_index = accepted_index[accepted_index !=
|
||||
VLLM_INVALID_TOKEN_ID]
|
||||
assert len(accepted_index) == hidden_states.shape[0] == len(
|
||||
terminal_metadata)
|
||||
index = accepted_index[:, None, None].expand(-1, 1,
|
||||
hs_size) # b x 1 x d
|
||||
second_last_token_hidden_states = hidden_states[:, -2] # b x d
|
||||
hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
|
||||
|
||||
# Store hidden states from target model for subsequent decode step
|
||||
self.previous_hidden_states = HiddenStates(
|
||||
hidden_states, terminal_metadata,
|
||||
second_last_token_hidden_states)
|
||||
else:
|
||||
retrieve_indices = proposals.retrieve_indices
|
||||
|
||||
batch_size = len(seq_group_metadata_list)
|
||||
|
||||
best_candidates = sampler_extra_kwargs["best_candidates"]
|
||||
accept_lengths = sampler_extra_kwargs["accept_lengths"]
|
||||
|
||||
# Contract hidden states based on accepted tokens
|
||||
hs_size = hidden_states.shape[-1]
|
||||
hidden_states = hidden_states.view(batch_size, -1, hs_size)
|
||||
|
||||
# Store logits from target model for subsequent proposal
|
||||
logits = proposal_scores.logits
|
||||
logits = logits.view(batch_size, -1, logits.shape[-1])
|
||||
logits = logits[:, retrieve_indices] # [batch_size, retrieve_size, max_depth, vocab_size]
|
||||
|
||||
previous_logits_list = []
|
||||
|
||||
previous_hidden_state_list = []
|
||||
|
||||
retrieve_indices = retrieve_indices.cpu()
|
||||
|
||||
for i in range(batch_size):
|
||||
logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0)
|
||||
previous_logits_list.append(logit)
|
||||
select_indices = retrieve_indices[best_candidates[i], :accept_lengths[i]+1]
|
||||
hidden_state = hidden_states[i, select_indices[-1]].unsqueeze(0)
|
||||
select_indices_list.append(select_indices)
|
||||
previous_hidden_state_list.append(hidden_state)
|
||||
|
||||
logits = torch.cat(previous_logits_list, dim=0)
|
||||
self.previous_logits = Logits(logits, seq_group_metadata_list)
|
||||
|
||||
hidden_states = torch.cat(previous_hidden_state_list, dim=0) # [batch_size, 1, vocab_size]
|
||||
self.previous_hidden_states = HiddenStates(hidden_states,
|
||||
seq_group_metadata_list,)
|
||||
|
||||
return accepted_token_ids, logprobs, select_indices_list, accept_lengths
|
||||
|
||||
def _create_output_sampler_list(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
|
||||
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
|
||||
prompt_logprobs: Optional[
|
||||
torch.Tensor], # shape: [nprompt_tokens, vocab_size]
|
||||
k: int,
|
||||
stage_times: Tuple[float, float, float],
|
||||
) -> List[SamplerOutput]:
|
||||
"""Given the accepted token ids, create a list of SamplerOutput.
|
||||
|
||||
The output is padded with -1 tokens such that each sequence has
|
||||
the same number of outputs.
|
||||
"""
|
||||
batch_size, num_steps = accepted_token_ids.shape
|
||||
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
|
||||
if self._disable_logprobs:
|
||||
# We are skipping the logprobs. Hence don't serialize the
|
||||
# logprobs related tensors from the GPU. Instead create
|
||||
# empty/dummy lists.
|
||||
(accepted_token_id_ranks_by_step,
|
||||
accepted_token_id_logprobs_by_step,
|
||||
topk_logprobs_by_step, topk_indices_by_step) =\
|
||||
self._create_dummy_logprob_lists(
|
||||
batch_size, num_steps,
|
||||
self.scorer_worker.model_config.max_logprobs)
|
||||
else:
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
target_logprobs_by_step = target_logprobs.transpose(0, 1)
|
||||
# Serialize all tensors into Python lists.
|
||||
(accepted_token_id_ranks_by_step,
|
||||
accepted_token_id_logprobs_by_step,
|
||||
topk_logprobs_by_step, topk_indices_by_step) =\
|
||||
self._create_logprob_lists_from_tensors(
|
||||
target_logprobs_by_step, accepted_token_ids_by_step,
|
||||
self.scorer_worker.model_config.max_logprobs)
|
||||
|
||||
# Get the sequence ids and num_logprobs (sampling parameter) in the
|
||||
# batch.
|
||||
seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
|
||||
seq_group_metadata_list)
|
||||
|
||||
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
|
||||
|
||||
# Serialize tensor to CPU Python list.
|
||||
#accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
record_accepted_token_ids(accepted_token_ids, seq_ids)
|
||||
|
||||
# Construct the output on a per-step, per-sequence basis.
|
||||
# Non-terminal prefill chunks will end up here as rows with just -1s
|
||||
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
|
||||
# terminal chunks will only have one generated token at time 0.
|
||||
sampler_output_list: List[SamplerOutput] = []
|
||||
|
||||
# Prefills are not multi-step (return at most 1 token), in order to
|
||||
# avoid padding or repetition to fit decodes, we separate them.
|
||||
for i, sg in enumerate(seq_group_metadata_list):
|
||||
if not sg.is_prompt:
|
||||
# Requests are ordered as prefills|decodes=>no more prefills.
|
||||
break
|
||||
num_logprobs = num_logprobs_per_seq[i]
|
||||
seq_kwargs = dict(token_id=-1,
|
||||
token_id_logprob_rank=0,
|
||||
token_id_logprob=-float('inf'),
|
||||
topk_token_ids=[-1] * num_logprobs,
|
||||
topk_logprobs=[-float('inf')] * num_logprobs,
|
||||
seq_id=seq_ids[i])
|
||||
# Terminal chunk, has token.
|
||||
if sg.do_sample:
|
||||
seq_kwargs.update(
|
||||
dict(
|
||||
token_id=accepted_token_ids[i][0].item(),
|
||||
token_id_logprob_rank=accepted_token_id_ranks_by_step[
|
||||
0][i],
|
||||
token_id_logprob=accepted_token_id_logprobs_by_step[0]
|
||||
[i],
|
||||
topk_token_ids=topk_indices_by_step[0][i]
|
||||
[:num_logprobs],
|
||||
# output only so step is 0
|
||||
topk_logprobs=topk_logprobs_by_step[0][i]
|
||||
[:num_logprobs],
|
||||
))
|
||||
needs_plogs = (sg.sampling_params.prompt_logprobs
|
||||
and sg.sampling_params.prompt_logprobs > 0)
|
||||
plogs = None
|
||||
if prompt_logprobs is not None:
|
||||
# Even non-terminal prompt chunks can have logprobs here.
|
||||
plogs = prompt_logprobs[i]
|
||||
elif needs_plogs:
|
||||
# Prompt logprobs are requested but `_disable_logprobs` is set.
|
||||
seq_data = next(iter(sg.seq_data.values()))
|
||||
# Get only the tokens in this chunk!
|
||||
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||
prompt_token_ids = prompt_token_ids[
|
||||
seq_data.
|
||||
_num_computed_tokens:seq_data._num_computed_tokens +
|
||||
sg.token_chunk_size]
|
||||
|
||||
is_first_chunk = seq_data._num_computed_tokens == 0
|
||||
# There's no prob generated for the first token in a sequence.
|
||||
if is_first_chunk:
|
||||
prompt_token_ids = prompt_token_ids[1:]
|
||||
plogs = [
|
||||
create_logprobs_output(
|
||||
token_id=p_token_id,
|
||||
token_id_logprob_rank=-1,
|
||||
token_id_logprob=0.0,
|
||||
topk_token_ids=[],
|
||||
topk_logprobs=[],
|
||||
) for p_token_id in prompt_token_ids
|
||||
]
|
||||
seq_kwargs.update(dict(prompt_logprobs=plogs))
|
||||
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(
|
||||
outputs=[create_sequence_group_output(
|
||||
**seq_kwargs)])) # type: ignore
|
||||
|
||||
# Decodes, create one SamplerOutput per-step (at most K+1).
|
||||
for step_index in range(num_steps):
|
||||
# if all(token_id == -1 for sg, token_id in zip(
|
||||
# seq_group_metadata_list,
|
||||
# accepted_token_ids_by_step[step_index])
|
||||
# if not sg.is_prompt):
|
||||
# break
|
||||
step_output_token_ids: List[CompletionSequenceGroupOutput] = []
|
||||
for sequence_index in range(batch_size):
|
||||
seq_meta = seq_group_metadata_list[sequence_index]
|
||||
# Prompts already processed above.
|
||||
if seq_meta.is_prompt:
|
||||
continue
|
||||
|
||||
# Each sequence may have a different num_logprobs; retrieve it.
|
||||
num_logprobs = num_logprobs_per_seq[sequence_index]
|
||||
step_output_token_ids.append(
|
||||
create_sequence_group_output(
|
||||
token_id = 0,
|
||||
token_id_logprob_rank=accepted_token_id_ranks_by_step[
|
||||
step_index][sequence_index],
|
||||
token_id_logprob=accepted_token_id_logprobs_by_step[
|
||||
step_index][sequence_index],
|
||||
seq_id=seq_ids[sequence_index],
|
||||
topk_token_ids=topk_indices_by_step[step_index]
|
||||
[sequence_index][:num_logprobs],
|
||||
topk_logprobs=topk_logprobs_by_step[step_index]
|
||||
[sequence_index][:num_logprobs],
|
||||
))
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
|
||||
# Populate the data structures needed to keep track of sequences with
|
||||
# bonus tokens.
|
||||
self._track_sequences_with_bonus_tokens(seq_ids,
|
||||
request_ids_seq_ids_mapping,
|
||||
accepted_token_ids_by_step)
|
||||
maybe_rejsample_metrics = (
|
||||
self._metrics.maybe_collect_rejsample_metrics(k))
|
||||
if maybe_rejsample_metrics is not None and sampler_output_list:
|
||||
sampler_output_list[
|
||||
0].spec_decode_worker_metrics = maybe_rejsample_metrics
|
||||
|
||||
# Log time spent in each stage periodically.
|
||||
# This is periodic because the rejection sampler emits metrics
|
||||
# periodically.
|
||||
self._maybe_log_stage_times(*stage_times)
|
||||
# First `n_prefills` entries will contain prefills SamplerOutput when
|
||||
# chunked prefill is enabled, the rest is decodes in multi-step format.
|
||||
return sampler_output_list
|
||||
|
||||
def _track_sequences_with_bonus_tokens(
|
||||
self, seq_ids: List[int],
|
||||
request_ids_seq_ids_mapping: Dict[str, Set[int]],
|
||||
accepted_token_ids_by_step: List[List[int]]):
|
||||
"""
|
||||
Updates the internal data structures which keep track of sequences
|
||||
which have been assigned bonus tokens in their last forward pass.
|
||||
"""
|
||||
for seq_index, seq_id in enumerate(seq_ids):
|
||||
# last_token_id = accepted_token_ids_by_step[-1][seq_index]
|
||||
# if last_token_id == -1:
|
||||
# self._seq_with_bonus_token_in_last_step.discard(seq_id)
|
||||
# else:
|
||||
self._seq_with_bonus_token_in_last_step.add(seq_id)
|
||||
for request_id, sequences in request_ids_seq_ids_mapping.items():
|
||||
self._request_id_seq_id_mapping[request_id].update(sequences)
|
||||
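The reordering step in _verify_tokens (accepted_token_ids[original_indices] = accepted_token_ids.clone()) scatters rows back to their original batch positions; a standalone torch sketch with made-up values:

import torch

accepted = torch.tensor([[7, 8], [1, -1], [3, 4]])   # rows ordered spec-first
original_indices = torch.tensor([1, 2, 0])            # original position of each row
accepted[original_indices] = accepted.clone()
print(accepted)                                        # rows restored to original order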
84
vllm/zero_overhead/spec_decode/top1_proproser.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.spec_decode.util import sampler_output_to_torch
|
||||
from vllm.utils import async_tensor_h2d
|
||||
from vllm.zero_overhead.utils import record_proposal_lens_list
|
||||
|
||||
class ZeroOverheadTop1Proposer(Top1Proposer):
|
||||
|
||||
def _merge_outputs(
|
||||
self,
|
||||
batch_size: int,
|
||||
proposal_len: int,
|
||||
maybe_sampler_output: Optional[List[SamplerOutput]],
|
||||
proposal_lens: List[int],
|
||||
nonzero_proposal_len_indices: List[int],
|
||||
sampler_transposed: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""After speculations are produced, merge the speculation results with
|
||||
the skipped sequences.
|
||||
"""
|
||||
if maybe_sampler_output is None:
|
||||
# If no speculative tokens, the sampler output will be None.
|
||||
# In this case we return empty proposals.
|
||||
proposal_tokens = torch.tensor(-1,
|
||||
dtype=torch.long,
|
||||
device=self._device).expand(
|
||||
batch_size, proposal_len)
|
||||
proposal_probs = torch.tensor(0,
|
||||
dtype=torch.float32,
|
||||
device=self._device).expand(
|
||||
batch_size, proposal_len,
|
||||
self._vocab_size)
|
||||
proposal_lens_tensor = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=self._device).expand(
|
||||
len(proposal_lens))
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
|
||||
sampler_output, sampler_transposed)
|
||||
|
||||
proposal_lens_list = [0 for i in range(batch_size)]
|
||||
for indices in nonzero_proposal_len_indices:
|
||||
proposal_lens_list[indices] = proposal_len
|
||||
record_proposal_lens_list(proposal_lens_list)
|
||||
|
||||
nonzero_proposal_len_indices = async_tensor_h2d(nonzero_proposal_len_indices, torch.int32,
|
||||
self._device,
|
||||
True)
|
||||
|
||||
# Now, reformat the output GPU tensors such that each sequence has
|
||||
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
|
||||
|
||||
entire_proposal_tokens = proposal_tokens.new_full(
|
||||
size=(batch_size, *proposal_tokens.shape[1:]),
|
||||
fill_value=-1,
|
||||
)
|
||||
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
|
||||
entire_proposal_probs = proposal_probs.new_zeros(
|
||||
batch_size,
|
||||
*proposal_probs.shape[1:],
|
||||
)
|
||||
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
|
||||
|
||||
proposal_tokens, proposal_probs = (
|
||||
entire_proposal_tokens,
|
||||
entire_proposal_probs,
|
||||
)
|
||||
|
||||
proposal_lens_tensor = async_tensor_h2d(proposal_lens_list, torch.long,
|
||||
self._device,
|
||||
True)
|
||||
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
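The merge performed by _merge_outputs, where skipped sequences receive an all -1 proposal, can be reproduced in isolation; sizes and token ids below are made-up:

import torch

batch_size, proposal_len = 4, 3
proposal_tokens = torch.tensor([[5, 6, 7], [8, 9, 10]])    # only two sequences proposed
nonzero_indices = torch.tensor([1, 3])                      # their positions in the batch
entire = proposal_tokens.new_full((batch_size, proposal_len), fill_value=-1)
entire[nonzero_indices] = proposal_tokens
print(entire)                                               # rows 0 and 2 stay all -1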
77
vllm/zero_overhead/stop_check.py
Normal file
@@ -0,0 +1,77 @@
|
||||
|
||||
|
||||
from typing import Optional
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceStatus
|
||||
from vllm.zero_overhead.sequence import ZeroOverheadSequence
|
||||
|
||||
|
||||
class ZeroOverheadStopChecker(StopChecker):
|
||||
def __init__(self, max_model_len, get_tokenizer_for_seq):
|
||||
super().__init__(max_model_len, get_tokenizer_for_seq)
|
||||
|
||||
|
||||
def maybe_stop_sequence(
|
||||
self,
|
||||
seq: ZeroOverheadSequence,
|
||||
new_char_count: int,
|
||||
sampling_params: SamplingParams,
|
||||
lora_req: Optional[LoRARequest] = None,
|
||||
) -> None:
|
||||
"""Stop the finished sequences.
|
||||
|
||||
new_char_count is the number of chars added to the
|
||||
sequence's output text for the newly generated token
|
||||
"""
|
||||
|
||||
# Check if the minimum number of tokens has been generated yet;
|
||||
# skip the stop string/token checks if not
|
||||
if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
|
||||
return
|
||||
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.zero_overhead_get_last_token_id() == seq.eos_token_id):
|
||||
# Remove the last EOS token unless explicitly specified
|
||||
# This prevents unintended exposure of the EOS token
|
||||
if new_char_count and (
|
||||
not sampling_params.include_stop_str_in_output):
|
||||
seq.output_text = seq.output_text[:-new_char_count]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if a stop token was encountered.
|
||||
# This assumes a single token produced per step.
|
||||
last_token_id = seq.zero_overhead_get_last_token_id()
|
||||
if last_token_id in (sampling_params.stop_token_ids or ()):
|
||||
if new_char_count and (
|
||||
not sampling_params.include_stop_str_in_output):
|
||||
# Remove last token
|
||||
seq.output_text = seq.output_text[:-new_char_count]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = last_token_id
|
||||
return
|
||||
|
||||
# Check if any stop strings are matched.
|
||||
stop = self.check_stop_strings(
|
||||
seq.output_text, new_char_count, sampling_params.stop,
|
||||
sampling_params.include_stop_str_in_output)
|
||||
if stop is not None:
|
||||
stop_str, truncate_to = stop
|
||||
if truncate_to != -1:
|
||||
seq.output_text = seq.output_text[:truncate_to]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = stop_str
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.zero_overhead_get_output_len() == sampling_params.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
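The ordering of the checks in maybe_stop_sequence (min_tokens gate first, then EOS, stop token ids, and length caps) can be summarized by a toy function; toy_should_stop and its arguments are hypothetical, not the vLLM API:

def toy_should_stop(output_ids, eos_id, stop_ids, min_tokens, max_tokens):
    # Gate on min_tokens before any stop condition is considered.
    if len(output_ids) < min_tokens:
        return None
    last = output_ids[-1]
    if last == eos_id:
        return "eos"
    if last in stop_ids:
        return "stop_token"
    if len(output_ids) >= max_tokens:
        return "length"
    return None


print(toy_should_stop([1, 2, 50256], eos_id=50256, stop_ids=(), min_tokens=1, max_tokens=16))  # eos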
84
vllm/zero_overhead/tokenizer.py
Normal file
@@ -0,0 +1,84 @@
|
||||
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import VLLM_INVALID_TOKEN_ID
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally
|
||||
from vllm.zero_overhead.sequence import ZeroOverheadSequence
|
||||
|
||||
|
||||
class ZeroOverheadDetokenizer(Detokenizer):
|
||||
def __init__(self, tokenizer_group):
|
||||
super().__init__(tokenizer_group)
|
||||
|
||||
def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
|
||||
prms: SamplingParams) -> int:
|
||||
"""Decodes the new token for a sequence. In-place operation.
|
||||
|
||||
Args:
|
||||
seq: The sequence to decode.
|
||||
prms: The sampling parameters used to generate the sequence.
|
||||
|
||||
Returns:
|
||||
The number of characters added to the output text.
|
||||
"""
|
||||
eff_length = seq.get_prompt_len() + seq.effective_output_len
|
||||
all_input_ids = seq.get_token_ids()[ : eff_length]
|
||||
|
||||
token_id_generated_this_iteration = all_input_ids[-1]
|
||||
tokenizer = self.get_tokenizer_for_seq(seq)
|
||||
|
||||
# Convert prompt token IDs to tokens if necessary.
|
||||
# Do it here so that we don't have to repeat this
|
||||
# computation for each logprob.
|
||||
if seq.tokens is None:
|
||||
(seq.tokens, seq.prefix_offset,
|
||||
seq.read_offset) = convert_prompt_ids_to_tokens(
|
||||
tokenizer=tokenizer,
|
||||
prompt_ids=all_input_ids[:-1],
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
)
|
||||
|
||||
(new_tokens, new_decoded_token_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
all_input_ids=all_input_ids,
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
read_offset=seq.read_offset,
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
spaces_between_special_tokens=prms.spaces_between_special_tokens,
|
||||
)
|
||||
|
||||
# Decode logprobs
|
||||
logprobs = seq.output_logprobs[-1]
|
||||
if logprobs:
|
||||
previous_tokens = all_input_ids[:-1]
|
||||
for token_id, sample_logprob in logprobs.items():
|
||||
# If the token was generated this iteration,
|
||||
# use the provided text.
|
||||
if token_id == token_id_generated_this_iteration:
|
||||
sample_logprob.decoded_token = new_decoded_token_text
|
||||
continue
|
||||
|
||||
if (sample_logprob.decoded_token is None
|
||||
and token_id != VLLM_INVALID_TOKEN_ID):
|
||||
all_input_ids_with_logprob = previous_tokens + [token_id]
|
||||
(_, new_text, _, _) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
all_input_ids=all_input_ids_with_logprob,
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
read_offset=seq.read_offset,
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
spaces_between_special_tokens=prms.
|
||||
spaces_between_special_tokens,
|
||||
)
|
||||
sample_logprob.decoded_token = new_text
|
||||
|
||||
seq.tokens.extend(new_tokens)
|
||||
seq.prefix_offset = prefix_offset
|
||||
seq.read_offset = read_offset
|
||||
seq.output_text += new_decoded_token_text
|
||||
|
||||
return len(new_decoded_token_text)
|
||||
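The key difference from the stock detokenizer is the slice taken before detokenization: only the validated prefix (prompt plus effective output) is decoded, so trailing placeholder tokens never reach the tokenizer. A plain-Python sketch with made-up token ids:

prompt_len = 4
effective_output_len = 2
all_token_ids = [10, 11, 12, 13, 20, 21, 0, 0]   # two trailing placeholder tokens
eff_length = prompt_len + effective_output_len
print(all_token_ids[:eff_length])                 # [10, 11, 12, 13, 20, 21]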
71
vllm/zero_overhead/utils.py
Normal file
@@ -0,0 +1,71 @@
|
||||
|
||||
|
||||
from enum import Enum
|
||||
import os
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
|
||||
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
|
||||
|
||||
def is_zero_no_thread():
|
||||
return zero_no_thread and envs.VLLM_ZERO_OVERHEAD
|
||||
|
||||
class SpecStepKind(Enum):
|
||||
KIND_DEFAULT = 0
|
||||
PREFILL = 1
|
||||
FIRST_PROPOSAL = 2
|
||||
OTHER_PROPOSAL = 3
|
||||
SCORE_DECODE = 4
|
||||
|
||||
class ZeroOverheadSpecContext():
|
||||
def __init__(self):
|
||||
self.step_kind = SpecStepKind.KIND_DEFAULT
|
||||
self.last_step = SpecStepKind.KIND_DEFAULT
|
||||
self.proposal_lens_list = None
|
||||
self.proposal_token_ids = None
|
||||
self.accepted_token_ids = None
|
||||
self.accepted_seq_ids = None
|
||||
|
||||
spec_context = ZeroOverheadSpecContext()
|
||||
|
||||
def set_spec_step(_step):
|
||||
global spec_context
|
||||
spec_context.last_step = spec_context.step_kind
|
||||
spec_context.step_kind = _step
|
||||
|
||||
def get_spec_step():
|
||||
return spec_context.step_kind
|
||||
|
||||
def get_spec_last_step():
|
||||
return spec_context.last_step
|
||||
|
||||
def record_proposal_lens_list(lens_list):
|
||||
global spec_context
|
||||
spec_context.proposal_lens_list = lens_list
|
||||
|
||||
def get_proposal_lens_list():
|
||||
return spec_context.proposal_lens_list
|
||||
|
||||
def record_proposal_token_ids(tensor):
|
||||
global spec_context
|
||||
spec_context.proposal_token_ids = tensor
|
||||
|
||||
def get_proposal_token_ids():
|
||||
return spec_context.proposal_token_ids
|
||||
|
||||
def record_accepted_token_ids(tensor, seq_ids):
|
||||
global spec_context
|
||||
spec_context.accepted_token_ids = tensor
|
||||
spec_context.accepted_seq_ids = seq_ids
|
||||
|
||||
def get_accepted_token_ids():
|
||||
return spec_context.accepted_token_ids, spec_context.accepted_seq_ids
|
||||
|
||||
# Zero-overhead scheduling does not run inference on the default stream; this avoids stream-synchronization stalls caused by runtime memory allocations.
|
||||
alloc_stream = {}
|
||||
|
||||
def zero_overhead_stream(target_device):
|
||||
"""Asynchronously create a tensor and copy it from host to device."""
|
||||
if target_device not in alloc_stream.keys():
|
||||
alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
|
||||
return alloc_stream[target_device]
|
||||
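A hedged usage sketch of zero_overhead_stream, assuming this module is importable as vllm.zero_overhead.utils and CUDA is available; the copy runs on the per-device side stream and the default stream waits on it before the tensor is used:

import torch
from vllm.zero_overhead.utils import zero_overhead_stream

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    stream = zero_overhead_stream(device)          # lazily created, cached per device
    with torch.cuda.stream(stream):
        t = torch.ones(4, pin_memory=True).to(device, non_blocking=True)
    # Make the default stream wait for the side-stream copy before using t.
    torch.cuda.current_stream(device).wait_stream(stream)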
359
vllm/zero_overhead/v1/core.py
Normal file
@@ -0,0 +1,359 @@
|
||||
|
||||
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
|
||||
|
||||
requsets_valid_token_len = {}
|
||||
|
||||
def check_stop(request: Request,
|
||||
max_model_len: int,
|
||||
pooler_output: Optional[torch.Tensor] = None,
|
||||
use_valid_token_len:bool = False) -> bool:
|
||||
if use_valid_token_len:
|
||||
if request.request_id not in requsets_valid_token_len:
|
||||
requsets_valid_token_len[request.request_id] = 0
|
||||
return False
|
||||
valid_output_len = requsets_valid_token_len[request.request_id]
|
||||
else:
|
||||
valid_output_len = request.num_output_tokens
|
||||
valid_num_tokens = request.num_prompt_tokens + valid_output_len
|
||||
if (valid_num_tokens >= max_model_len
|
||||
or valid_output_len >= request.max_tokens):
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
return True
|
||||
|
||||
if request.pooling_params:
|
||||
if pooler_output is not None:
|
||||
request.status = RequestStatus.FINISHED_STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
assert sampling_params is not None
|
||||
last_token_id = request.output_token_ids[valid_output_len - 1]
|
||||
if (not sampling_params.ignore_eos
|
||||
and last_token_id == request.eos_token_id):
|
||||
request.status = RequestStatus.FINISHED_STOPPED
|
||||
return True
|
||||
|
||||
if last_token_id in (sampling_params.stop_token_ids or ()):
|
||||
request.status = RequestStatus.FINISHED_STOPPED
|
||||
request.stop_reason = last_token_id
|
||||
return True
|
||||
return False
|
||||
|
||||
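A toy sketch of the gating that check_stop adds when use_valid_token_len is set: the first call for a request only registers it in the watermark table, and later length checks run against the validated output length rather than the raw one. toy_check_stop and its arguments are hypothetical:

valid_len = {}

def toy_check_stop(req_id, prompt_len, max_model_len, max_tokens):
    if req_id not in valid_len:
        valid_len[req_id] = 0       # first call only registers the request
        return False
    n_valid = valid_len[req_id]
    return prompt_len + n_valid >= max_model_len or n_valid >= max_tokens


print(toy_check_stop("req-0", prompt_len=8, max_model_len=16, max_tokens=4))  # False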
def zero_overhead_update_from_output(scheduler:Scheduler,
|
||||
scheduler_output: SchedulerOutput,
|
||||
model_runner_output: ZeroV1ModelRunnerOutput):
|
||||
global requsets_valid_token_len
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
spec_token_ids = model_runner_output.spec_token_ids
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
pooler_outputs = model_runner_output.pooler_output
|
||||
num_nans_in_logits = model_runner_output.num_nans_in_logits
|
||||
|
||||
new_running: list[Request] = []
|
||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
||||
|
||||
# fix last model out in zero overhead
|
||||
if model_runner_output.fix_req_ids is not None:
|
||||
for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
|
||||
if req_id not in scheduler.requests:
|
||||
continue
|
||||
request = scheduler.requests[req_id]
|
||||
generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx]
|
||||
if req_id not in requsets_valid_token_len:
|
||||
requsets_valid_token_len[req_id] = 0
|
||||
valid_output_len = requsets_valid_token_len[req_id]
|
||||
fix_offset = valid_output_len - request.num_output_tokens
|
||||
if isinstance(generated_token_ids, int):
|
||||
request._output_token_ids[fix_offset] = generated_token_ids
|
||||
request._all_token_ids[fix_offset] = generated_token_ids
|
||||
requsets_valid_token_len[req_id] += 1
|
||||
generated_token_ids = [generated_token_ids]
|
||||
else:
|
||||
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
|
||||
if valid_output_end == 0:
|
||||
request._output_token_ids[fix_offset : ] = generated_token_ids
|
||||
request._all_token_ids[fix_offset : ] = generated_token_ids
|
||||
else:
|
||||
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
|
||||
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
|
||||
requsets_valid_token_len[req_id] += len(generated_token_ids)
|
||||
|
||||
|
||||
stopped = False
|
||||
new_logprobs = None
|
||||
new_token_ids = generated_token_ids
|
||||
kv_transfer_params = None
|
||||
|
||||
# Check for stop and update request state.
|
||||
# This must be called before we make the EngineCoreOutput.
|
||||
for num_new, output_token_id in enumerate(new_token_ids, 1):
|
||||
stopped = check_stop(request, scheduler.max_model_len, use_valid_token_len=True)
|
||||
if stopped:
|
||||
kv_transfer_params = scheduler._free_request(request)
|
||||
del new_token_ids[num_new:] # Trim new tokens if needed.
|
||||
break
|
||||
|
||||
pooler_output = None
|
||||
if pooler_outputs:
|
||||
pooler_output = pooler_outputs[req_idx]
|
||||
stopped = check_stop(request, scheduler.max_model_len,
|
||||
pooler_output, True)
|
||||
if stopped:
|
||||
kv_transfer_params = scheduler._free_request(request)
|
||||
|
||||
# Extract sample logprobs if needed.
|
||||
if request.sampling_params is not None \
|
||||
and request.sampling_params.logprobs is not None and logprobs:
|
||||
# NOTE: once we support N tokens per step (spec decode),
|
||||
# the outer lists can be of length > 1.
|
||||
new_logprobs = logprobs.slice(req_idx, req_idx + 1)
|
||||
|
||||
if new_token_ids and scheduler.structured_output_manager.should_advance(
|
||||
request):
|
||||
# NOTE: structured_output_request
|
||||
# should not be None if use_structured_output, we have
|
||||
# check above, so safe to ignore type warning
|
||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||
req_id, new_token_ids)
|
||||
|
||||
# spec_token_ids comes from the model runner output
|
||||
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||
|
||||
# Get prompt logprobs for this request.
|
||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||
if new_token_ids or pooler_output is not None \
|
||||
or kv_transfer_params:
|
||||
|
||||
# Add EngineCoreOutput for this Request.
|
||||
outputs[request.client_index].append(
|
||||
EngineCoreOutput(
|
||||
request_id=req_id,
|
||||
new_token_ids=new_token_ids,
|
||||
finish_reason=request.get_finished_reason(),
|
||||
new_logprobs=new_logprobs,
|
||||
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
||||
pooling_output=pooler_output,
|
||||
stop_reason=request.stop_reason,
|
||||
events=request.take_events(),
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
num_cached_tokens=request.num_cached_tokens,
|
||||
))
|
||||
|
||||
else:
|
||||
# Invariant: EngineCore returns no partial prefill outputs.
|
||||
assert not prompt_logprobs_tensors
|
||||
|
||||
# fix last model out in zero overhead
|
||||
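# Same idea for the previous step's draft tokens: now that their real
# values are on the host, validate them against the grammar (if any) and
# attach them to the request as spec_token_ids.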
if model_runner_output.fix_draft_req_ids is not None:
|
||||
for req_idx, req_id in enumerate(model_runner_output.fix_draft_req_ids):
|
||||
if req_id not in scheduler.requests:
|
||||
continue
|
||||
request = scheduler.requests[req_id]
|
||||
|
||||
# Add newly generated spec token ids to the request.
|
||||
if model_runner_output.fix_draft_tokens_ids is not None:
|
||||
if scheduler.structured_output_manager.should_advance(request):
|
||||
metadata = request.structured_output_request
|
||||
# Needs to happen after new_token_ids are accepted.
|
||||
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
|
||||
model_runner_output.fix_draft_tokens_ids[req_idx])
|
||||
else:
|
||||
request.spec_token_ids = model_runner_output.fix_draft_tokens_ids[req_idx]
|
||||
|
||||
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
|
||||
# loop can be a performance bottleneck. We should do our best to avoid
|
||||
# expensive operations inside the loop.
|
||||
for request in scheduler.running:
|
||||
req_id = request.request_id
|
||||
if request.is_finished():
|
||||
if req_id in requsets_valid_token_len:
|
||||
requsets_valid_token_len.pop(req_id)
|
||||
continue
|
||||
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
|
||||
if num_tokens_scheduled == 0:
|
||||
# The request was not scheduled in this step.
|
||||
new_running.append(request)
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = sampled_token_ids[
|
||||
req_index] if sampled_token_ids else []
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
if scheduled_spec_token_ids:
|
||||
# num_computed_tokens represents the number of tokens
|
||||
# processed in the current step, considering scheduled
|
||||
# tokens and rejections. If some tokens are rejected,
|
||||
# num_computed_tokens is decreased by the number of rejected
|
||||
# tokens, which is given by:
|
||||
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
|
||||
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
|
||||
len(generated_token_ids))
|
||||
request.num_computed_tokens -= num_tokens_rejected
|
||||
spec_decoding_stats = scheduler.make_spec_decoding_stats(
|
||||
spec_decoding_stats,
|
||||
num_draft_tokens=len(scheduled_spec_token_ids),
|
||||
num_accepted_tokens=len(generated_token_ids) - 1)
|
||||
|
||||
# NOTE(woosuk): This has to be executed after updating
|
||||
# `request.num_computed_tokens`.
|
||||
if request.has_encoder_inputs:
|
||||
scheduler._free_encoder_inputs(request)
|
||||
|
||||
stopped = False
|
||||
new_logprobs = None
|
||||
new_token_ids = generated_token_ids
|
||||
kv_transfer_params = None
|
||||
|
||||
# Append generated tokens and check for stop. Note that if
|
||||
# a request is still being prefilled, we expect the model runner
|
||||
# to return empty token ids for the request.
|
||||
for num_new, output_token_id in enumerate(new_token_ids, 1):
|
||||
request.append_output_token_ids(output_token_id)
|
||||
|
||||
# Check for stop and update request state.
|
||||
# This must be called before we make the EngineCoreOutput.
|
||||
|
||||
if model_runner_output.is_output_valid:
|
||||
stopped = check_stop(request, scheduler.max_model_len,
|
||||
False)
|
||||
if stopped:
|
||||
kv_transfer_params = scheduler._free_request(request)
|
||||
del new_token_ids[num_new:] # Trim new tokens if needed.
|
||||
break
|
||||
|
||||
pooler_output = None
|
||||
if pooler_outputs:
|
||||
if model_runner_output.is_output_valid:
|
||||
pooler_output = pooler_outputs[req_index]
|
||||
stopped = check_stop(request, scheduler.max_model_len,
|
||||
pooler_output,
|
||||
False)
|
||||
if stopped:
|
||||
kv_transfer_params = scheduler._free_request(request)
|
||||
|
||||
# Extract sample logprobs if needed.
|
||||
if request.sampling_params is not None \
|
||||
and request.sampling_params.logprobs is not None and logprobs:
|
||||
# NOTE: once we support N tokens per step (spec decode),
|
||||
# the outer lists can be of length > 1.
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
if new_token_ids and scheduler.structured_output_manager.should_advance(
|
||||
request):
|
||||
# NOTE: structured_output_request
|
||||
# should not be None if use_structured_output, we have
|
||||
# check above, so safe to ignore type warning
|
||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||
req_id, new_token_ids)
|
||||
|
||||
# spec_token_ids comes from the model runner output
|
||||
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||
|
||||
# Add newly generated spec token ids to the request.
|
||||
if spec_token_ids is not None:
|
||||
if scheduler.structured_output_manager.should_advance(request):
|
||||
metadata = request.structured_output_request
|
||||
# Needs to happen after new_token_ids are accepted.
|
||||
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
|
||||
spec_token_ids[req_index])
|
||||
else:
|
||||
request.spec_token_ids = spec_token_ids[req_index]
|
||||
|
||||
if model_runner_output.is_output_valid:
|
||||
# Get prompt logprobs for this request.
|
||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||
if new_token_ids or pooler_output is not None \
|
||||
or kv_transfer_params:
|
||||
|
||||
# Add EngineCoreOutput for this Request.
|
||||
outputs[request.client_index].append(
|
||||
EngineCoreOutput(
|
||||
request_id=req_id,
|
||||
new_token_ids=new_token_ids,
|
||||
finish_reason=request.get_finished_reason(),
|
||||
new_logprobs=new_logprobs,
|
||||
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
||||
pooling_output=pooler_output,
|
||||
stop_reason=request.stop_reason,
|
||||
events=request.take_events(),
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
num_cached_tokens=request.num_cached_tokens,
|
||||
))
|
||||
if stopped:
|
||||
if req_id in requsets_valid_token_len:
|
||||
requsets_valid_token_len.pop(req_id)
|
||||
else:
|
||||
new_running.append(request)
|
||||
|
||||
scheduler.running = new_running
|
||||
|
||||
# KV Connector: update state for finished KV Transfers.
|
||||
scheduler._update_from_kv_xfer_finished(model_runner_output)
|
||||
|
||||
# Create EngineCoreOutputs for all clients that have requests with
|
||||
# outputs in this step.
|
||||
engine_core_outputs = {
|
||||
client_index: EngineCoreOutputs(outputs=outs)
|
||||
for client_index, outs in outputs.items()
|
||||
}
|
||||
|
||||
finished_req_ids = scheduler.finished_req_ids_dict
|
||||
if finished_req_ids:
|
||||
# Include ids of requests that finished since last outputs
|
||||
# were sent.
|
||||
for client_index, finished_set in finished_req_ids.items():
|
||||
# Set finished request set in EngineCoreOutputs for this client.
|
||||
if (eco := engine_core_outputs.get(client_index)) is not None:
|
||||
eco.finished_requests = finished_set
|
||||
else:
|
||||
engine_core_outputs[client_index] = EngineCoreOutputs(
|
||||
finished_requests=finished_set)
|
||||
finished_req_ids.clear()
|
||||
|
||||
if engine_core_outputs:
|
||||
# Return stats to only one of the front-ends.
|
||||
next(iter(engine_core_outputs.values())).scheduler_stats = (
|
||||
scheduler.make_stats(spec_decoding_stats))
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
|
||||
def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
|
||||
"""Schedule, execute, and make output.
|
||||
|
||||
Returns tuple of outputs and a flag indicating whether the model
|
||||
was executed.
|
||||
"""
|
||||
|
||||
# Check for any requests remaining in the scheduler - unfinished,
|
||||
# or finished and not yet removed from the batch.
|
||||
if not core.scheduler.has_requests():
|
||||
return {}, False
|
||||
scheduler_output = core.scheduler.schedule()
|
||||
model_output = core.execute_model(scheduler_output)
|
||||
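# Zero-overhead runners return a ZeroV1ModelRunnerOutput that carries
# placeholder tokens plus fix_* data for the previous step, so it needs
# the dedicated update path; a plain ModelRunnerOutput goes through the
# stock scheduler update.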
if isinstance(model_output, ZeroV1ModelRunnerOutput):
|
||||
engine_core_outputs = zero_overhead_update_from_output(core.scheduler,
|
||||
scheduler_output, model_output) # type: ignore
|
||||
else:
|
||||
engine_core_outputs = core.scheduler.update_from_output(
|
||||
scheduler_output, model_output) # type: ignore
|
||||
|
||||
return (engine_core_outputs,
|
||||
scheduler_output.total_num_scheduled_tokens > 0)
|
||||
317
vllm/zero_overhead/v1/eagle.py
Normal file
317
vllm/zero_overhead/v1/eagle.py
Normal file
@@ -0,0 +1,317 @@
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
|
||||
|
||||
|
||||
class V1ZeroEagleProposer(EagleProposer):
|
||||
def __init__(self, vllm_config, device, runner=None):
|
||||
super().__init__(vllm_config, device, runner)
|
||||
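# Largest number of tokens scheduled for any request in the current step;
# set by the model runner and used as max_query_len for the MTP metadata.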
self.spec_scheduler_max_num_tokens = 0
|
||||
|
||||
|
||||
def propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
target_token_ids: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_slot_mapping: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
# [batch_size + 1] starting with 0
|
||||
cu_num_tokens: torch.Tensor,
|
||||
# [batch_size, max_num_blocks_per_req]
|
||||
block_table: torch.Tensor,
|
||||
# [batch_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
decoding: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = cu_num_tokens[1:] - 1
|
||||
|
||||
if self.method == "eagle3":
|
||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
target_hidden_states)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
# FA requires seq_len to have dtype int32.
|
||||
seq_lens = (target_positions[last_token_indices] + 1).int()
|
||||
|
||||
if self.method in ["eagle", "eagle3"]:
|
||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||
max_seq_len = seq_lens.max().item()
|
||||
max_num_tokens = (cu_num_tokens[1:] -
|
||||
cu_num_tokens[:-1]).max().item()
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_num_tokens,
|
||||
query_start_loc=cu_num_tokens,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table,
|
||||
slot_mapping=target_slot_mapping,
|
||||
# TODO(woosuk): Support cascade attention.
|
||||
use_cascade=False,
|
||||
common_prefix_len=0,
|
||||
cu_prefix_query_lens=None,
|
||||
prefix_kv_lens=None,
|
||||
suffix_kv_lens=None,
|
||||
)
|
||||
elif self.method == "deepseek_mtp":
|
||||
max_query_len = self.spec_scheduler_max_num_tokens
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=cu_num_tokens,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=target_slot_mapping,
|
||||
spec_layer_decoding=decoding
|
||||
)
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
attn_metadata = self.runner.attn_metadata_builders[0].build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {self.method}")
|
||||
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
if self.use_cuda_graph and \
|
||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
num_input_tokens = num_tokens
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
|
||||
if (decoding and self.use_full_cuda_graph
|
||||
and num_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
assert self.attn_metadata_cudagraph
|
||||
if self.method in ["eagle", "eagle3"]:
|
||||
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
|
||||
attn_metadata.seq_lens)
|
||||
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
|
||||
attn_metadata.slot_mapping)
|
||||
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
|
||||
attn_metadata.query_start_loc)
|
||||
self.attn_metadata_cudagraph.block_table[:batch_size] = (
|
||||
attn_metadata.block_table)
|
||||
elif self.method == "deepseek_mtp":
|
||||
self.attn_metadata_cudagraph.num_actual_tokens = (
|
||||
attn_metadata.num_actual_tokens)
|
||||
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
|
||||
attn_metadata.query_start_loc)
|
||||
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
|
||||
attn_metadata.slot_mapping)
|
||||
self.attn_metadata_cudagraph.num_decodes = (
|
||||
attn_metadata.num_decodes)
|
||||
self.attn_metadata_cudagraph.num_decode_tokens = (
|
||||
attn_metadata.num_decode_tokens)
|
||||
self.attn_metadata_cudagraph.num_prefills = (
|
||||
attn_metadata.num_prefills)
|
||||
|
||||
if attn_metadata.decode is not None:
|
||||
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
|
||||
attn_metadata.decode.block_table)
|
||||
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
|
||||
attn_metadata.decode.seq_lens)
|
||||
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
skip_cuda_graphs=not decoding):
|
||||
ret_hidden_states = self.model(
|
||||
self.input_ids[:num_input_tokens],
|
||||
self.positions[:num_input_tokens],
|
||||
self.hidden_states[:num_input_tokens],
|
||||
)
|
||||
if self.method == "deepseek_mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1:
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
|
||||
# TODO: Currently, MTP module released by deepseek only has
|
||||
# one layer. Adapt this code to support multiple layers once
|
||||
# there's a multi-layer MTP module.
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
|
||||
positions = target_positions[last_token_indices]
|
||||
|
||||
if self.method == "deepseek_mtp":
|
||||
hidden_states = last_hidden_states[last_token_indices]
|
||||
else:
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
|
||||
if self.use_cuda_graph and \
|
||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||
else:
|
||||
input_batch_size = batch_size
|
||||
attn_metadata.num_actual_tokens = batch_size
|
||||
attn_metadata.max_query_len = 1
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
attn_metadata.num_decodes = batch_size
|
||||
attn_metadata.num_decode_tokens = batch_size
|
||||
attn_metadata.num_prefills = 0
|
||||
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
|
||||
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
|
||||
block_table_tensor=block_table,
|
||||
seq_lens=seq_lens,
|
||||
)
|
||||
|
||||
for i in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
positions += 1
|
||||
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
# but adjust the position ids and slot mappings to avoid the
|
||||
# out-of-range access during the model execution. The draft tokens
|
||||
# generated with this adjustment should be ignored.
|
||||
exceeds_max_model_len = positions >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions)
|
||||
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
attn_metadata.decode.seq_lens += 1
|
||||
else:
|
||||
attn_metadata.seq_lens += 1
|
||||
|
||||
# Increment the sequence lengths.
|
||||
attn_metadata.max_seq_len += 1
|
||||
# Consider max model length.
|
||||
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
|
||||
self.max_model_len)
|
||||
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
|
||||
|
||||
# Compute the slot mapping.
|
||||
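# e.g. with block_size == 16, position 37 falls into logical block 2 at
# offset 5, so its slot is block_table[req][2] * 16 + 5.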
block_numbers = clamped_positions // self.block_size
|
||||
block_ids = block_table.gather(dim=1,
|
||||
index=block_numbers.view(-1, 1))
|
||||
block_ids = block_ids.view(-1)
|
||||
attn_metadata.slot_mapping = (block_ids * self.block_size +
|
||||
clamped_positions % self.block_size)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
|
||||
PADDING_SLOT_ID)
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:batch_size] = hidden_states
|
||||
|
||||
if (self.use_full_cuda_graph
|
||||
and batch_size <= self.cudagraph_batch_sizes[-1]):
|
||||
assert self.attn_metadata_cudagraph
|
||||
if self.method in ["eagle", "eagle3"]:
|
||||
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
|
||||
attn_metadata.seq_lens)
|
||||
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
|
||||
attn_metadata.slot_mapping)
|
||||
if i == 0:
|
||||
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
    attn_metadata.query_start_loc)
|
||||
self.attn_metadata_cudagraph.block_table[:batch_size] = (
|
||||
attn_metadata.block_table)
|
||||
elif self.method == "deepseek_mtp":
|
||||
self.attn_metadata_cudagraph.num_actual_tokens = (
|
||||
attn_metadata.num_actual_tokens)
|
||||
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
|
||||
attn_metadata.slot_mapping)
|
||||
self.attn_metadata_cudagraph.num_decodes = (
|
||||
attn_metadata.num_decodes)
|
||||
self.attn_metadata_cudagraph.num_decode_tokens = (
|
||||
attn_metadata.num_decode_tokens)
|
||||
self.attn_metadata_cudagraph.num_prefills = (
|
||||
attn_metadata.num_prefills)
|
||||
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
|
||||
attn_metadata.decode.seq_lens)
|
||||
|
||||
if i == 0:
|
||||
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
|
||||
attn_metadata.query_start_loc)
|
||||
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
|
||||
attn_metadata.decode.block_table)
|
||||
|
||||
# Run the model.
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
ret_hidden_states = self.model(
|
||||
self.input_ids[:input_batch_size],
|
||||
self.positions[:input_batch_size],
|
||||
self.hidden_states[:input_batch_size],
|
||||
)
|
||||
if self.method == "deepseek_mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states[:batch_size]
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||
None)
|
||||
|
||||
# TODO(wenlong): get more than one token for tree attention
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
# [batch_size, num_speculative_tokens]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
|
||||
return draft_token_ids
|
||||
749
vllm/zero_overhead/v1/gpu_model_runner.py
Normal file
749
vllm/zero_overhead/v1/gpu_model_runner.py
Normal file
@@ -0,0 +1,749 @@
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
import torch
|
||||
import numpy as np
|
||||
from vllm import envs
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import async_tensor_h2d, round_up
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
|
||||
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
|
||||
from vllm.profiler.prof import profile
|
||||
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
|
||||
|
||||
|
||||
class V1ZeroModelRunner(GPUModelRunner):
|
||||
def __init__(self, vllm_config, device):
|
||||
super().__init__(vllm_config, device)
|
||||
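# Results of the previous step kept on the GPU, plus CUDA events guarding
# their asynchronous copies to the host. They let this runner report
# placeholder outputs immediately and "fix" them one step later.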
self.last_sampled_token_ids = None
|
||||
self.last_sampled_req_ids = []
|
||||
self.last_sampled_token_lens = []
|
||||
self.last_sampler_event = torch.cuda.Event(enable_timing=False)
|
||||
self.last_sampler_host_tokens = None
|
||||
self.token_ids_cpu_fix_record = []
|
||||
self.last_draft_token_ids = None
|
||||
self.last_draft_host_tokens = None
|
||||
self.last_draft_event = torch.cuda.Event(enable_timing=False)
|
||||
self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
|
||||
self.spec_scheduler_max_num_tokens = 0
|
||||
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
|
||||
self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
|
||||
self)
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[dict[str, Any], bool, torch.Tensor,
|
||||
Optional[SpecDecodeMetadata], np.ndarray]:
|
||||
"""
|
||||
:return: tuple[
|
||||
attn_metadata: layer-to-attention_metadata mapping,
|
||||
attention_cuda_graphs: whether attention can run in cudagraph,
logits_indices, spec_decode_metadata, num_scheduled_tokens
|
||||
]
|
||||
"""
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table.commit(num_reqs)
|
||||
|
||||
# Get the number of scheduled tokens for each request.
|
||||
req_ids = self.input_batch.req_ids
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
self.spec_scheduler_max_num_tokens = max_num_scheduled_tokens
|
||||
|
||||
# Get request indices.
|
||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||
num_scheduled_tokens)
|
||||
|
||||
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
|
||||
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Get positions.
|
||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
arange,
|
||||
out=positions_np)
|
||||
|
||||
# Calculate M-RoPE positions.
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
self._calc_mrope_positions(scheduler_output)
|
||||
|
||||
# Get token indices.
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
||||
# where M is the max_model_len.
|
||||
token_indices = (positions_np +
|
||||
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||||
|
||||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||||
# because torch.index_select is much faster than np.take for large
|
||||
# tensors.
|
||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||
0,
|
||||
torch.from_numpy(token_indices),
|
||||
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
||||
|
||||
# Calculate the slot mapping for each KV cache group.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
block_size = kv_cache_group_spec.kv_cache_spec.block_size
|
||||
block_table: BlockTable = self.input_batch.block_table[
|
||||
kv_cache_group_id]
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
|
||||
# where K is the max_num_blocks_per_req and the block size is 2.
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||
# here because M (max_model_len) is not necessarily divisible by
|
||||
# block_size.
|
||||
block_table_indices = (
|
||||
req_indices * block_table.max_num_blocks_per_req +
|
||||
positions_np // block_size)
|
||||
block_table_cpu = block_table.get_cpu_tensor()
|
||||
block_numbers = block_table_cpu.flatten(
|
||||
)[block_table_indices].numpy()
|
||||
block_offsets = positions_np % block_size
|
||||
np.add(
|
||||
block_numbers * block_size,
|
||||
block_offsets,
|
||||
out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
# Prepare the attention metadata.
|
||||
self.query_start_loc_np[0] = 0
|
||||
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
||||
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Copy the tensors to the GPU.
|
||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
|
||||
self.zero_prepare_inputs(scheduler_output, self.input_ids)
|
||||
|
||||
if self.uses_mrope:
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
||||
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
else:
|
||||
# Common case (1D positions)
|
||||
self.positions[:total_num_scheduled_tokens].copy_(
|
||||
self.positions_cpu[:total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
|
||||
self.query_start_loc[:num_reqs + 1].copy_(
|
||||
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
|
||||
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
|
||||
# Fill unused entries with 0 (needed for reshape_and_cache).
|
||||
self.seq_lens[num_reqs:].fill_(0)
|
||||
# Note: pad query_start_loc to be non-decreasing, as kernels
|
||||
# like FlashAttention require that
|
||||
self.query_start_loc[num_reqs + 1:].fill_(
|
||||
self.query_start_loc_cpu[num_reqs].item())
|
||||
|
||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
||||
seq_lens = self.seq_lens[:num_reqs]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
)
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
|
||||
# Prepare for cascade attention if enabled & beneficial.
|
||||
common_prefix_len = 0
|
||||
builder = self.attn_metadata_builders[kv_cache_group_id]
|
||||
if self.cascade_attn_enabled:
|
||||
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||
num_scheduled_tokens,
|
||||
scheduler_output.
|
||||
num_common_prefix_blocks[kv_cache_group_id],
|
||||
kv_cache_group_spec.kv_cache_spec,
|
||||
builder,
|
||||
)
|
||||
|
||||
attn_metadata_i = (builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
))
|
||||
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
attention_cuda_graphs = all(
|
||||
b.can_run_in_cudagraph(common_attn_metadata)
|
||||
for b in self.attn_metadata_builders)
|
||||
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if not use_spec_decode:
|
||||
# NOTE(woosuk): Due to chunked prefills, the batch may contain
|
||||
# partial requests. While we should not sample any token
|
||||
# from these partial requests, we do so for simplicity.
|
||||
# We will ignore the sampled tokens from the partial requests.
|
||||
# TODO: Support prompt logprobs.
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
spec_decode_metadata = None
|
||||
else:
|
||||
# Get the number of draft tokens for each request.
|
||||
# Iterate over the dictionary rather than all requests since not all
|
||||
# requests have draft tokens.
|
||||
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||||
for req_id, draft_token_ids in (
|
||||
scheduler_output.scheduled_spec_decode_tokens.items()):
|
||||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||||
num_draft_tokens[req_idx] = len(draft_token_ids)
|
||||
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
|
||||
# Hot-Swap lora model
|
||||
if self.lora_config:
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
|
||||
return (attn_metadata, attention_cuda_graphs, logits_indices,
|
||||
spec_decode_metadata, num_scheduled_tokens)
|
||||
|
||||
def zero_prepare_inputs(self, scheduler_output, input_ids):
|
||||
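# Patch the device-side input_ids in place with the previous step's draft
# and sampled token ids, which were kept on the GPU rather than being
# synchronized to the host (the zero-overhead hand-off).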
req_ids = self.input_batch.req_ids
|
||||
update_req_indices = []
|
||||
input_ids_indices = []
|
||||
token_idx = 0
|
||||
if self.last_draft_token_ids is not None:
|
||||
draft_tokens_num = self.last_draft_token_ids.shape[1]
|
||||
for req_id in req_ids:
|
||||
if req_id in self.last_sampled_req_ids:
|
||||
req_idx = self.last_sampled_req_ids.index(req_id) * draft_tokens_num
|
||||
for num_idx in range(draft_tokens_num):
|
||||
update_req_indices.append(req_idx + num_idx)
|
||||
input_ids_indices.append(token_idx + num_idx + 1)
|
||||
token_idx += draft_tokens_num + 1
|
||||
if len(update_req_indices) > 0:
|
||||
update_req_indices_tensor = async_tensor_h2d(update_req_indices, torch.int32,
|
||||
self.device,
|
||||
True)
|
||||
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
|
||||
self.device,
|
||||
True)
|
||||
last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int)
|
||||
input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor]
|
||||
|
||||
update_req_indices = []
|
||||
input_ids_indices = []
|
||||
token_idx = 0
|
||||
if self.last_sampled_token_ids is not None:
|
||||
sampled_tokens_num = self.last_sampled_token_ids.shape[1]
|
||||
for req_id in req_ids:
|
||||
if req_id in self.last_sampled_req_ids:
|
||||
req_idx = self.last_sampled_req_ids.index(req_id) * sampled_tokens_num
|
||||
update_req_indices.append(req_idx)
|
||||
input_ids_indices.append(token_idx)
|
||||
token_idx += scheduler_output.num_scheduled_tokens[req_id]
|
||||
if len(update_req_indices) > 0:
|
||||
update_req_indices_tensor = async_tensor_h2d(update_req_indices, torch.int32,
|
||||
self.device,
|
||||
True)
|
||||
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
|
||||
self.device,
|
||||
True)
|
||||
last_sampled_token_ids = self.last_sampled_token_ids.flatten()
|
||||
for i in range(sampled_tokens_num):
|
||||
input_ids[input_ids_indices_tensor + i] = last_sampled_token_ids[update_req_indices_tensor + i]
|
||||
|
||||
def propose_draft_token_ids(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
num_accepted_tokens_tensor: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
hidden_states: torch.Tensor,
|
||||
sample_hidden_states: torch.Tensor,
|
||||
aux_hidden_states: Optional[torch.Tensor],
|
||||
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
||||
attn_metadata: dict[str, Any],
|
||||
) -> list[list[int]]:
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if self.speculative_config.method == "ngram":
|
||||
assert isinstance(self.drafter, NgramProposer)
|
||||
spec_token_ids = self.propose_ngram_draft_token_ids(
|
||||
sampled_token_ids)
|
||||
elif self.speculative_config.method == "medusa":
|
||||
assert isinstance(self.drafter, MedusaProposer)
|
||||
if sample_hidden_states.shape[0] == len(sampled_token_ids):
|
||||
# The input to the target model does not include draft tokens.
|
||||
hidden_states = sample_hidden_states
|
||||
else:
|
||||
indices = []
|
||||
offset = 0
|
||||
for num_draft, tokens in zip(
|
||||
spec_decode_metadata.num_draft_tokens,
|
||||
sampled_token_ids):
|
||||
indices.append(offset + len(tokens) - 1)
|
||||
offset += num_draft + 1
|
||||
indices = torch.tensor(indices, device=self.device)
|
||||
hidden_states = sample_hidden_states[indices]
|
||||
|
||||
spec_token_ids = self.drafter.propose(
|
||||
target_hidden_states=hidden_states,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
elif self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# TODO(woosuk): Refactor the loop.
|
||||
row_indices = torch.arange(sampled_token_ids.size(0), device=sampled_token_ids.device)
|
||||
next_token_ids = sampled_token_ids[row_indices, num_accepted_tokens_tensor].flatten()
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
eagle_attn_metadata = attn_metadata[
|
||||
self.drafter.attn_layer_names[0]]
|
||||
|
||||
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
|
||||
if hasattr(eagle_attn_metadata, "block_table"):
|
||||
block_table = eagle_attn_metadata.block_table
|
||||
else:
|
||||
block_table = None
|
||||
|
||||
spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||
# TODO(woosuk): Support M-RoPE.
|
||||
target_positions = self.positions[:num_scheduled_tokens]
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
target_hidden_states = torch.cat(
|
||||
[h[:num_scheduled_tokens] for h in aux_hidden_states],
|
||||
dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
target_slot_mapping = eagle_attn_metadata.slot_mapping
|
||||
cu_num_tokens = eagle_attn_metadata.query_start_loc
|
||||
else:
|
||||
# TODO(woosuk): Refactor this.
|
||||
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
||||
eagle_attn_metadata.query_start_loc,
|
||||
num_accepted_tokens_tensor,
|
||||
)
|
||||
spec_scheduler_max_num_tokens = 1
|
||||
target_token_ids = self.input_ids[token_indices]
|
||||
# TODO(woosuk): Support M-RoPE.
|
||||
target_positions = self.positions[token_indices]
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
target_hidden_states = torch.cat(
|
||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
target_slot_mapping = eagle_attn_metadata.slot_mapping[
|
||||
token_indices]
|
||||
self.drafter.spec_scheduler_max_num_tokens = spec_scheduler_max_num_tokens
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
target_slot_mapping=target_slot_mapping,
|
||||
next_token_ids=next_token_ids,
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=block_table,
|
||||
sampling_metadata=sampling_metadata,
|
||||
decoding=spec_decode_metadata is not None,
|
||||
)
|
||||
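# Report placeholder draft ids to the scheduler; the real draft tokens stay
# on the GPU and are copied to the host asynchronously, to be delivered via
# fix_draft_tokens_ids in the next step.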
spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
|
||||
self.last_draft_token_ids = draft_token_ids
|
||||
self.last_draft_host_tokens = draft_token_ids.to('cpu', non_blocking=True)
|
||||
self.last_draft_event.record()
|
||||
return spec_token_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[ModelRunnerOutput, IntermediateTensors]:
|
||||
self._update_states(scheduler_output)
|
||||
if not scheduler_output.total_num_scheduled_tokens:
|
||||
if not has_kv_transfer_group():
|
||||
# Return empty ModelRunnerOutput if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
return self.kv_connector_no_forward(scheduler_output)
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
(attn_metadata, attention_cuda_graphs, logits_indices,
|
||||
spec_decode_metadata,
|
||||
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.use_cuda_graph
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
num_scheduled_tokens)
|
||||
else:
|
||||
# Eager mode.
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
if self.compilation_config.pass_config. \
|
||||
enable_sequence_parallelism and tp_size > 1:
|
||||
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
|
||||
else:
|
||||
num_input_tokens = num_scheduled_tokens
|
||||
|
||||
# Padding for DP
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
|
||||
num_input_tokens += num_pad
|
||||
|
||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||
# modal outputs after that to ensure the correct order
|
||||
if self.is_multimodal_model:
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||
else:
|
||||
mm_embeds = []
|
||||
|
||||
if self.is_multimodal_model and get_pp_group().is_first_rank:
|
||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
# as input to the multimodal model, even when the input is text.
|
||||
input_ids = self.input_ids[:num_scheduled_tokens]
|
||||
if mm_embeds:
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids, mm_embeds)
|
||||
else:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
# TODO(woosuk): Avoid the copy. Optimize.
|
||||
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
input_ids = None
|
||||
else:
|
||||
# For text-only models, we use token ids as input.
|
||||
# While it is possible to use embeddings as input just like the
|
||||
# multimodal models, it is not desirable for performance since
|
||||
# then the embedding layer is not included in the CUDA graph.
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_input_tokens]
|
||||
else:
|
||||
positions = self.positions[:num_input_tokens]
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
else:
|
||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||
num_input_tokens, intermediate_tensors, True)
|
||||
|
||||
# Some attention backends only support CUDA Graphs in pure decode.
|
||||
# If attention doesn't support CUDA Graphs for this batch, but we
|
||||
# compiled with full CUDA graphs, we have to skip them entirely.
|
||||
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
|
||||
if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
|
||||
model_output, finished_sending, finished_recving = \
|
||||
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
|
||||
num_tokens_across_dp, input_ids, positions,
|
||||
inputs_embeds, scheduler_output, intermediate_tensors,
|
||||
skip_cuda_graphs)
|
||||
else:
|
||||
# Run the model.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs,
|
||||
):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
self.maybe_wait_for_kv_save()
|
||||
finished_sending, finished_recving = (
|
||||
self.get_finished_kv_transfers(scheduler_output))
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
# Broadcast PP output for external_launcher (torchrun)
|
||||
# to make sure we are synced across pp ranks
|
||||
# TODO: Support overlapping micro-batches
|
||||
# https://github.com/vllm-project/vllm/issues/18019
|
||||
broadcast_pp_output = \
|
||||
self.parallel_config.distributed_executor_backend \
|
||||
== "external_launcher" and len(get_pp_group().ranks) > 0
|
||||
if not get_pp_group().is_last_rank:
|
||||
# For mid-pipeline stages, return the hidden states.
|
||||
if not broadcast_pp_output:
|
||||
return hidden_states
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(hidden_states.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
logits = None
|
||||
else:
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving)
|
||||
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
model_output_broadcast_data = {
|
||||
"logits": logits.contiguous(),
|
||||
} if logits is not None else {}
|
||||
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
|
||||
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
|
||||
assert model_output_broadcast_data is not None
|
||||
logits = model_output_broadcast_data["logits"]
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
self.apply_grammar_bitmask(scheduler_output, logits)
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
if spec_decode_metadata is None:
|
||||
sampler_output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
else:
|
||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
|
||||
assert logits is not None
|
||||
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
||||
sampler_output = self.sampler(
|
||||
logits=bonus_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
bonus_token_ids = sampler_output.sampled_token_ids
|
||||
|
||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||||
# separate storage from the original `logits` tensor. Therefore,
|
||||
# it is safe to update `target_logits` in place.
|
||||
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||||
output_token_ids = self.rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
None, # draft_probs
|
||||
target_logits,
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
sampler_output.sampled_token_ids = output_token_ids
|
||||
|
||||
num_nans_in_logits = {}
|
||||
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
|
||||
num_nans_in_logits = self._get_nans_in_logits(logits)
|
||||
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
discard_sampled_tokens_req_indices = []
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
if seq_len < req_state.num_tokens:
|
||||
# Ignore the sampled token for partial prefills.
|
||||
# Rewind the generator state as if the token was not sampled.
|
||||
# This relies on cuda-specific torch-internal impl details
|
||||
generator = self.input_batch.generators.get(i)
|
||||
if generator is not None:
|
||||
generator.set_offset(generator.get_offset() - 4)
|
||||
# Record the index of the request that should not be sampled,
|
||||
# so that we could clear the sampled tokens before returning.
|
||||
discard_sampled_tokens_req_indices.append(i)
|
||||
|
||||
# NOTE: GPU -> CPU Sync happens here.
|
||||
# Move as many CPU operations as possible before this sync point.
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
logprobs_lists = logprobs_tensors.tolists() \
|
||||
if logprobs_tensors is not None else None
|
||||
|
||||
# Compute prompt logprobs if needed.
|
||||
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
||||
hidden_states[:num_scheduled_tokens],
|
||||
scheduler_output,
|
||||
)
|
||||
|
||||
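# The fix_* fields describe corrections for the previous step: since its
# sampled ids were returned as placeholders, the real values (once copied
# to the host) are forwarded to the scheduler one step late via these fields.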
fix_req_ids = None
|
||||
fix_sampled_token_ids = None
|
||||
fix_draft_token_ids = None
|
||||
fix_draft_req_ids = self.last_sampled_req_ids
|
||||
is_output_valid = False
|
||||
# Get the valid generated tokens.
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
if not self.speculative_config:
|
||||
# Speculative decoding is not enabled.
|
||||
spec_token_ids = None
|
||||
fix_draft_req_ids = None
|
||||
else:
|
||||
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
|
||||
self.spec_sampler_event.record()
|
||||
if self.last_draft_host_tokens is not None:
|
||||
self.last_draft_event.synchronize()
|
||||
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
|
||||
|
||||
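# Rejected positions are filled with -1, so the index of the first -1 minus
# one is the index of the last valid token in each row
# (e.g. [a, b, -1, -1] -> index 1, i.e. token b).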
mask = (sampled_token_ids == -1)
|
||||
mask_int = mask.int()
|
||||
first_neg_one_indices = torch.argmax(mask_int, dim=1)
|
||||
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
|
||||
spec_token_ids = self.propose_draft_token_ids(
|
||||
scheduler_output,
|
||||
num_accepted_tokens_tensor,
|
||||
sampled_token_ids,
|
||||
sampling_metadata,
|
||||
hidden_states,
|
||||
sample_hidden_states,
|
||||
aux_hidden_states,
|
||||
spec_decode_metadata,
|
||||
attn_metadata,
|
||||
)
|
||||
|
||||
if self.speculative_config:
|
||||
self.spec_sampler_event.synchronize()
|
||||
if max_gen_len == 1:
|
||||
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
|
||||
else:
|
||||
# Includes spec decode tokens.
|
||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
||||
sampled_token_ids_cpu,
|
||||
self.input_batch.vocab_size,
|
||||
)
|
||||
self.last_sampler_host_tokens = None
|
||||
self.last_sampled_token_ids = None
|
||||
is_output_valid = True
|
||||
else:
|
||||
# No spec decode tokens.
|
||||
fix_req_ids = self.last_sampled_req_ids
|
||||
if self.last_sampler_host_tokens is not None:
|
||||
self.last_sampler_event.synchronize()
|
||||
fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
|
||||
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
|
||||
if start_idx == -1:
|
||||
continue
|
||||
req_id = fix_req_ids[req_idx]
|
||||
if req_id in self.input_batch.req_ids:
|
||||
new_req_idx = self.input_batch.req_ids.index(req_id)
|
||||
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
|
||||
for req_idx, req_id in enumerate(fix_req_ids):
|
||||
if req_id in self.requests:
|
||||
req_state = self.requests[req_id]
|
||||
token_idx = self.last_sampled_token_lens[req_idx]
|
||||
if token_idx == -1:
|
||||
continue
|
||||
fix_len = len(fix_sampled_token_ids[req_idx])
|
||||
req_state.output_token_ids[token_idx:token_idx + fix_len] = fix_sampled_token_ids[req_idx]
|
||||
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
|
||||
self.last_sampler_event.record()
|
||||
self.last_sampled_token_ids = sampled_token_ids
|
||||
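# No spec decode: hand back placeholder token ids now and defer the real
# ones (still on the GPU) to the next step via fix_sampled_token_ids.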
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
|
||||
|
||||
# Mask out the sampled tokens that should not be sampled.
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
|
||||
# Cache the sampled tokens in the model runner, so that the scheduler
|
||||
# doesn't need to send them back.
|
||||
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
||||
# the sampled tokens back, because there's no direct communication
|
||||
# between the first-stage worker and the last-stage worker.
|
||||
self.token_ids_cpu_fix_record.clear()
|
||||
self.last_sampled_req_ids = []
|
||||
self.last_sampled_token_lens = []
|
||||
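# Record which slice of token_ids_cpu was filled with the (still unverified)
# sampled ids so the next step can overwrite it once the real tokens arrive
# from the GPU.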
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
|
||||
req_id = self.input_batch.req_ids[req_idx]
|
||||
self.last_sampled_req_ids.append(req_id)
|
||||
cache_output_len = -1
|
||||
if not sampled_ids:
|
||||
self.last_sampled_token_lens.append(-1)
|
||||
self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
|
||||
continue
|
||||
|
||||
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
||||
end_idx = start_idx + len(sampled_ids)
|
||||
assert end_idx <= self.max_model_len, (
|
||||
"Sampled token IDs exceed the max model length. "
|
||||
f"Total number of tokens: {end_idx} > max_model_len: "
|
||||
f"{self.max_model_len}")
|
||||
|
||||
self.input_batch.token_ids_cpu[req_idx,
|
||||
start_idx:end_idx] = sampled_ids
|
||||
self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
|
||||
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
||||
self.input_batch.num_tokens[req_idx] = end_idx
|
||||
if req_id in self.requests:
|
||||
req_state = self.requests[req_id]
|
||||
cache_output_len = len(req_state.output_token_ids)
|
||||
req_state.output_token_ids.extend(sampled_ids)
|
||||
self.last_sampled_token_lens.append(cache_output_len)
|
||||
|
||||
|
||||
# Clear KVConnector state after all KVs are generated.
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().clear_connector_metadata()
|
||||
|
||||
self.eplb_step()
|
||||
|
||||
model_output = ZeroV1ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
num_nans_in_logits=num_nans_in_logits,
|
||||
fix_req_ids=fix_req_ids,
fix_sampled_token_ids=fix_sampled_token_ids,
fix_draft_tokens_ids=fix_draft_token_ids,
fix_draft_req_ids=fix_draft_req_ids,
is_output_valid=is_output_valid
|
||||
)
|
||||
return model_output
|
||||
14
vllm/zero_overhead/v1/outputs.py
Normal file
14
vllm/zero_overhead/v1/outputs.py
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
|
||||
|
||||
from dataclasses import dataclass
from typing import Optional
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
@dataclass
|
||||
class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
|
||||
# [num_reqs]
|
||||
fix_req_ids: Optional[list[str]] = None
fix_sampled_token_ids: Optional[list[list[int]]] = None
fix_draft_req_ids: Optional[list[str]] = None
fix_draft_tokens_ids: Optional[list[list[int]]] = None
is_output_valid: bool = True
|
||||