import torch
from typing import List, Optional

import triton
import triton.language as tl

from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.utils import (SpecStepKind, get_accepted_token_ids,
                                      get_proposal_token_ids,
                                      get_spec_last_step, get_spec_step)


@triton.jit
def _update_input_tokens(
    accepted_req_ids,
    accepted_req_ids_len,
    accepted_token_ids,
    accepted_token_len,
    children_req_ids,
    children_req_ids_len,
    input_tokens,
    input_tokens_len,
    input_positions,
    seq_lens,
    seq_lens_meta,
    seq_lens_tensor,
    slot_mapping,
    seq_start_loc,
    context_lens_tensor,
):
    """Rewrite the decode inputs in place after a score-decode step.

    Each child request owns two consecutive slots. Rows of accepted_token_ids
    are right-padded with -1 past the last accepted token, and at least one
    token is assumed to be accepted per matched request.
    """
    for seq_id_idx in range(children_req_ids_len // 2):
        seq_id = tl.load(children_req_ids + 2 * seq_id_idx)
        # Request ids are unique, so at most one accepted row matches.
        for i in range(accepted_req_ids_len):
            if seq_id == tl.load(accepted_req_ids + i):
                row = accepted_token_ids + i * accepted_token_len
                # Count the contiguous prefix of accepted (non -1) tokens.
                accepted_token_counter = 0
                for j in range(accepted_token_len):
                    if accepted_token_counter == j:
                        if tl.load(row + j) != -1:
                            accepted_token_counter += 1
                if accepted_token_counter == accepted_token_len:
                    # Fully accepted: feed the last two accepted tokens.
                    tl.store(input_tokens + seq_id_idx * 2,
                             tl.load(row + accepted_token_len - 2))
                    tl.store(input_tokens + seq_id_idx * 2 + 1,
                             tl.load(row + accepted_token_len - 1))
                else:
                    # Partially accepted: pad slot 0 and feed the last
                    # accepted token.
                    tl.store(input_tokens + seq_id_idx * 2, 0)
                    tl.store(input_tokens + seq_id_idx * 2 + 1,
                             tl.load(row + accepted_token_counter - 1))
                # Roll slot 1's position back by the number of rejected
                # tokens.
                pos = (tl.load(input_positions + seq_id_idx * 2 + 1) -
                       (accepted_token_len - accepted_token_counter))
                tl.store(input_positions + seq_id_idx * 2, 0)
                tl.store(input_positions + seq_id_idx * 2 + 1, pos)
                tl.store(context_lens_tensor + seq_id_idx * 2, 0)
                tl.store(context_lens_tensor + seq_id_idx * 2 + 1, pos)
                # Slot 0 holds a padding token, so its slot mapping is
                # invalid.
                tl.store(slot_mapping + seq_id_idx * 2, -1)
                tl.store(slot_mapping + seq_id_idx * 2 + 1, pos)
                # The sequence lengths become 1 (padding slot) and pos + 1.
                tl.store(seq_lens + seq_id_idx * 2, 1)
                tl.store(seq_lens + seq_id_idx * 2 + 1, pos + 1)
                tl.store(seq_lens_meta + seq_id_idx * 2, 1)
                tl.store(seq_lens_meta + seq_id_idx * 2 + 1, pos + 1)
                tl.store(seq_lens_tensor + seq_id_idx * 2, 1)
                tl.store(seq_lens_tensor + seq_id_idx * 2 + 1, pos + 1)
    # Rebuild seq_start_loc as the exclusive prefix sum of the updated
    # sequence lengths.
    running = 0
    tl.store(seq_start_loc, 0)
    for i in range(input_tokens_len):
        running += tl.load(seq_lens + i)
        tl.store(seq_start_loc + i + 1, running)
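
# A minimal host-side reference for the kernel's per-pair update rule, meant
# only as a sketch for unit-testing _update_input_tokens on small inputs; the
# helper below is hypothetical and not part of the runtime path. It assumes
# the kernel's layout: two slots per child request, accepted_token_ids rows
# right-padded with -1, and at least one accepted token per matched request.
def _reference_update_input_tokens(accepted_req_ids: torch.Tensor,
                                   accepted_token_ids: torch.Tensor,
                                   children_req_ids: torch.Tensor,
                                   input_tokens: torch.Tensor,
                                   input_positions: torch.Tensor) -> None:
    proposal_len = accepted_token_ids.shape[1]
    for idx in range(children_req_ids.shape[0] // 2):
        seq_id = children_req_ids[2 * idx]
        matches = (accepted_req_ids == seq_id).nonzero()
        if matches.numel() == 0:
            continue
        row = accepted_token_ids[matches[0, 0]]
        # Length of the contiguous prefix of accepted (non -1) tokens.
        accepted = int((row != -1).long().cumprod(dim=0).sum())
        if accepted == proposal_len:
            # Fully accepted: feed the last two accepted tokens.
            input_tokens[2 * idx:2 * idx + 2] = row[-2:]
        else:
            # Partially accepted: pad slot 0, feed the last accepted token.
            input_tokens[2 * idx] = 0
            input_tokens[2 * idx + 1] = row[accepted - 1]
        # Roll slot 1's position back by the number of rejected tokens.
        input_positions[2 * idx] = 0
        input_positions[2 * idx + 1] -= proposal_len - accepted
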
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):

    def __init__(self, runner, finished_requests_ids=None):
        super().__init__(runner, finished_requests_ids)
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        # Record sequence ids in batch order so that build() can match them
        # against the ids seen by the previous sampler step.
        for seq_id in seq_group_metadata.seq_data.keys():
            self.req_ids.append(seq_id)
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        model_input = super().build()
        last_sampler = get_last_sampler()
        spec_step = get_spec_step()
        last_step = get_spec_last_step()
        if last_sampler is None:
            return model_input

        if spec_step == SpecStepKind.KIND_DEFAULT:
            # Overwrite the first input token of each query with the token
            # sampled for the same sequence on the previous step.
            update_indices = []
            select_indices = []
            query_idx = 0
            for i, seq_id in enumerate(self.req_ids):
                for j, seq_id_ in enumerate(last_sampler.seq_ids):
                    if seq_id == seq_id_:
                        select_indices.append(j)
                        update_indices.append(query_idx)
                        break
                query_idx += model_input.query_lens[i]
            if (len(select_indices) > 0
                    and last_sampler.sampled_token_ids_tensor is not None):
                select_indices = async_tensor_h2d(select_indices, torch.long,
                                                  self.runner.device,
                                                  self.runner.pin_memory)
                update_indices = async_tensor_h2d(update_indices, torch.long,
                                                  self.runner.device,
                                                  self.runner.pin_memory)
                model_input.input_tokens[update_indices] = (
                    last_sampler.sampled_token_ids_tensor[select_indices, 0])
        elif spec_step == SpecStepKind.OTHER_PROPOSAL:
            if last_step == SpecStepKind.OTHER_PROPOSAL:
                # Copy the last sampled token ids to input tokens directly.
                update_indices = list(range(len(self.req_ids)))
                update_indices = async_tensor_h2d(update_indices, torch.long,
                                                  self.runner.device,
                                                  self.runner.pin_memory)
                model_input.input_tokens[update_indices] = (
                    last_sampler.sampled_token_ids_tensor[update_indices, 0])
            if last_step == SpecStepKind.FIRST_PROPOSAL:
                # TODO: adjust the number of input tokens to 1 per request.
                update_indices = list(range(len(self.req_ids)))
                update_indices = async_tensor_h2d(update_indices, torch.long,
                                                  self.runner.device,
                                                  self.runner.pin_memory)
                model_input.input_tokens[update_indices] = (
                    last_sampler.sampled_token_ids_tensor[update_indices, 0])
        elif spec_step == SpecStepKind.SCORE_DECODE:
            # Each request contributes proposal_len + 1 query slots: slot 0
            # holds the bonus token, slots 1..proposal_len hold the proposals.
            proposal_token_ids = get_proposal_token_ids()
            batch_size, proposal_len = proposal_token_ids.shape
            update_indices = []
            for i in range(batch_size):
                for j in range(proposal_len):
                    update_indices.append(i * (proposal_len + 1) + j + 1)
            update_indices = async_tensor_h2d(update_indices, torch.long,
                                              self.runner.device,
                                              self.runner.pin_memory)
            model_input.input_tokens[update_indices] = (
                proposal_token_ids.view(-1))
        elif spec_step == SpecStepKind.FIRST_PROPOSAL:
            if last_step == SpecStepKind.PREFILL:
                # TODO: when the last step is prefill, only update the input
                # ids for the last sequence id.
                pass
            if last_step == SpecStepKind.SCORE_DECODE:
                # When the last step is score decode, fix input_tokens,
                # seq_lens and input_positions using the accepted token ids.
                accept_token_ids, accept_seq_ids = get_accepted_token_ids()
                children_req_ids = async_tensor_h2d(self.req_ids, torch.long,
                                                    self.runner.device,
                                                    self.runner.pin_memory)
                grid = [1, 1, 1]
                _update_input_tokens[grid](
                    accept_seq_ids,
                    accept_seq_ids.shape[0],
                    accept_token_ids,
                    accept_token_ids.shape[1],
                    children_req_ids,
                    children_req_ids.shape[0],
                    model_input.input_tokens,
                    model_input.input_tokens.shape[0],
                    model_input.input_positions,
                    model_input.seq_lens,
                    model_input.attn_metadata.seq_lens,
                    model_input.attn_metadata.seq_lens_tensor,
                    model_input.attn_metadata.slot_mapping,
                    model_input.attn_metadata.seq_start_loc,
                    model_input.attn_metadata.context_lens_tensor,
                )
        return model_input
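
# A sketch of how this builder could be plugged into a runner. It assumes the
# runner exposes the `_builder_cls` hook that vllm.worker.model_runner's
# ModelRunner uses to construct its input builder; the subclass name below is
# hypothetical.
from vllm.worker.model_runner import ModelRunner


class ZeroOverheadModelRunner(ModelRunner):
    # The base runner instantiates `_builder_cls` while preparing inputs, so
    # overriding the hook is enough to route input preparation through the
    # zero-overhead builder above.
    _builder_cls = ZeroOverheadModelInputForGpuBuilder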