Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
47
vllm/v1/worker/gpu/spec_decode/utils.py
Normal file
47
vllm/v1/worker/gpu/spec_decode/utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import DraftTokenIds
|
||||
from vllm.v1.worker.gpu.async_utils import async_copy_to_np
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
|
||||
|
||||
class DraftTokensHandler:
|
||||
def __init__(self, device: torch.device | None = None):
|
||||
self.device = device
|
||||
self.copy_stream = torch.cuda.Stream(device)
|
||||
self.copy_event = torch.cuda.Event()
|
||||
|
||||
self.req_ids: list[str] = []
|
||||
self.draft_tokens_np: np.ndarray | None = None
|
||||
self.num_draft_tokens: int = 0
|
||||
|
||||
def set_draft_tokens(
|
||||
self, input_batch: InputBatch, draft_tokens: torch.Tensor
|
||||
) -> None:
|
||||
self.req_ids = input_batch.req_ids
|
||||
self.num_draft_tokens = draft_tokens.shape[1]
|
||||
if not input_batch.has_structured_output_reqs:
|
||||
# No draft token validation needs to be performed by
|
||||
# the scheduler for this batch.
|
||||
self.draft_tokens_np = None
|
||||
return
|
||||
|
||||
# For spec decoding + structured outputs, we must transfer the
|
||||
# draft tokens back to the scheduler for grammar validation.
|
||||
current_stream = torch.cuda.current_stream(self.device)
|
||||
self.copy_stream.wait_stream(current_stream)
|
||||
with torch.cuda.stream(self.copy_stream):
|
||||
self.draft_tokens_np = async_copy_to_np(draft_tokens)
|
||||
self.copy_event.record()
|
||||
|
||||
def get_draft_tokens(self) -> DraftTokenIds | None:
|
||||
if self.draft_tokens_np is not None:
|
||||
self.copy_event.synchronize()
|
||||
draft_token_ids = self.draft_tokens_np.tolist()
|
||||
else:
|
||||
# This case only happens when async scheduling is disabled.
|
||||
draft_token_ids = [[-1] * self.num_draft_tokens for _ in self.req_ids]
|
||||
return DraftTokenIds(self.req_ids, draft_token_ids)
|
||||
Reference in New Issue
Block a user