update

vllm/v1/worker/gpu/structured_outputs.py (new file, 115 lines)

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.input_batch import InputBatch


class StructuredOutputsWorker:
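    """Applies grammar (structured-output) token bitmasks to logits on the GPU.

    Keeps persistent device buffers for the packed bitmask and the
    bitmask -> logits-row index mapping, and uses a dedicated copy stream so
    host-to-device transfers can overlap with work on the current stream.
    """
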
    def __init__(self, max_num_logits: int, vocab_size: int, device: torch.device):
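        # Persistent device buffers sized for the worst case: one row per logit
        # position, with the vocabulary packed 32 tokens per int32 word.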
        self.logits_indices = torch.zeros(
            max_num_logits, dtype=torch.int32, device=device
        )
        self.grammar_bitmask = torch.zeros(
            (max_num_logits, cdiv(vocab_size, 32)), dtype=torch.int32, device=device
        )
        self.device = device
        self.copy_stream = torch.cuda.Stream()

    def apply_grammar_bitmask(
        self,
        logits: torch.Tensor,
        input_batch: InputBatch,
        grammar_req_ids: list[str],
        grammar_bitmask: np.ndarray,
    ) -> None:
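        """Mask disallowed tokens in-place for grammar-constrained requests.

        grammar_bitmask holds one packed int32 row per logit position owned by
        a request in grammar_req_ids; tokens whose bit is 0 are disallowed and
        have their logits set to -inf.
        """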
        if not grammar_req_ids:
            return

        # Asynchronously copy the bitmask to GPU.
        with torch.cuda.stream(self.copy_stream):
            bitmask = async_copy_to_gpu(
                grammar_bitmask, out=self.grammar_bitmask[: grammar_bitmask.shape[0]]
            )

        # Construct bitmask -> logits mapping
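        # cu_num_logits holds cumulative logit counts per request, so request
        # req_idx owns logits rows [cu_num_logits[req_idx], cu_num_logits[req_idx + 1]).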
        mapping: list[int] = []
        req_ids = input_batch.req_ids
        cu_num_logits = input_batch.cu_num_logits_np.tolist()
        req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
        for grammar_req_id in grammar_req_ids:
            req_idx = req_id_to_idx[grammar_req_id]
            logits_start_idx = cu_num_logits[req_idx]
            logits_end_idx = cu_num_logits[req_idx + 1]
            mapping.extend(range(logits_start_idx, logits_end_idx))

        # Asynchronously copy the mapping to GPU.
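        # Staging through a pinned CPU tensor lets the non_blocking copy below
        # proceed asynchronously with respect to the host.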
        with torch.cuda.stream(self.copy_stream):
            logits_indices = torch.tensor(
                mapping, dtype=torch.int32, device="cpu", pin_memory=True
            )
            logits_indices = self.logits_indices[: len(mapping)].copy_(
                logits_indices, non_blocking=True
            )

        # Ensure all async copies are complete before launching the kernel.
        current_stream = torch.cuda.current_stream()
        current_stream.wait_stream(self.copy_stream)

        num_masks = bitmask.shape[0]
        assert num_masks == len(mapping)
        vocab_size = logits.shape[-1]
        BLOCK_SIZE = 8192
        grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
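        # Grid axis 0: one program per bitmask row (and its target logits row);
        # grid axis 1: one program per BLOCK_SIZE-token slice of the vocabulary.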
        _apply_grammar_bitmask_kernel[grid](
            logits,
            logits.stride(0),
            logits_indices,
            bitmask,
            bitmask.stride(0),
            vocab_size,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        # Ensure the copy stream waits for the device tensors to finish being
        # used before it re-uses or deallocates them.
        self.copy_stream.wait_stream(current_stream)


# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
    logits_ptr,
    logits_stride,
    logits_indices_ptr,
    bitmask_ptr,
    bitmask_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
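    # Axis 0 of the grid selects one bitmask row; logits_indices gives the
    # logits row that this bitmask constrains.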
    bitmask_idx = tl.program_id(0)
    logits_idx = tl.load(logits_indices_ptr + bitmask_idx)

    # Load the bitmask.
    block_id = tl.program_id(1)
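    # Each program along axis 1 covers BLOCK_SIZE tokens, i.e. BLOCK_SIZE // 32
    # packed int32 words starting at word (block_id * BLOCK_SIZE) // 32.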
    bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
    packed_bitmask = tl.load(
        bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
        mask=bitmask_offset < bitmask_stride,
    )
    # Unpack the bitmask.
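    # Bit j of packed word i guards token i * 32 + j within this block; a
    # cleared bit (0) marks the token as disallowed, so `bitmask` is True
    # exactly where -inf must be written.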
    bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
    bitmask = bitmask.reshape(BLOCK_SIZE)

    # Apply the bitmask to the logits.
    block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(
        logits_ptr + logits_idx * logits_stride + block_offset,
        -float("inf"),
        mask=bitmask & (block_offset < vocab_size),
    )
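
For reference, a minimal, hypothetical usage sketch of the worker above; the real wiring lives in the GPU model runner. SimpleNamespace merely stands in for InputBatch (only req_ids and cu_num_logits_np are accessed here), and the sizes, request ids, and bitmask contents are illustrative assumptions.

from types import SimpleNamespace

import numpy as np
import torch

from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker

vocab_size = 32000
device = torch.device("cuda")
worker = StructuredOutputsWorker(max_num_logits=64, vocab_size=vocab_size, device=device)

# Two requests with 2 and 1 sampled positions -> cumulative counts [0, 2, 3].
input_batch = SimpleNamespace(
    req_ids=["req-a", "req-b"],
    cu_num_logits_np=np.array([0, 2, 3], dtype=np.int32),
)
logits = torch.randn(3, vocab_size, device=device)

# One packed int32 row per logit position of the grammar-constrained request
# ("req-a" owns logits rows 0 and 1). A real bitmask comes from xgrammar; all
# zeros here would simply mask the entire vocabulary.
grammar_bitmask = np.zeros((2, (vocab_size + 31) // 32), dtype=np.int32)

worker.apply_grammar_bitmask(logits, input_batch, ["req-a"], grammar_bitmask)
torch.cuda.synchronize()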