172 lines
8.7 KiB
Python
172 lines
8.7 KiB
Python
|
|
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import itertools
|
||
|
|
from typing import List, Optional, Set
|
||
|
|
from vllm.lora.layers import LoRAMapping
|
||
|
|
from vllm.multimodal.inputs import MultiModalKwargs
|
||
|
|
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||
|
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||
|
|
from vllm.sequence import SequenceGroupMetadata
|
||
|
|
from vllm.utils import async_tensor_h2d, flatten_2d_lists
|
||
|
|
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
|
||
|
|
from vllm.zero_overhead.sampler import get_last_sampler
|
||
|
|
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_proposal_token_ids, get_spec_last_step, get_spec_step
|
||
|
|
|
||
|
|
import triton
|
||
|
|
import triton.language as tl
|
||
|
|
|
||
|
|
@triton.jit
def _update_input_tokens(
    accepted_req_ids,        # ptr: seq ids that had tokens accepted last step
    accepted_req_ids_len,    # number of entries in accepted_req_ids
    accepted_token_ids,      # ptr: [num_accepted_reqs, accepted_token_len], -1 padded
    accepted_token_len,      # row width of accepted_token_ids (max proposal len)
    chidren_req_ids,         # ptr: current batch seq ids, two slots per request
    chidren_req_ids_len,     # number of entries in chidren_req_ids
    input_tokens,            # ptr: token ids to feed the model (2 per request)
    input_tokens_len,        # number of entries in input_tokens
    input_positions,         # ptr: per-token positions (2 per request)
    seq_lens,
    seq_lens_meta,
    seq_lens_tensor,
    slot_mapping,
    seq_start_loc,           # ptr: prefix-sum of seq_lens, length input_tokens_len + 1
    context_lens_tensor,
):
    # In-place GPU fix-up of the freshly built decode inputs after a
    # speculative score step: for every request whose proposals were
    # (partially) accepted, rewrite its two input-token slots, rewind its
    # positions by the number of rejected tokens, and rebuild seq_start_loc.
    #
    # NOTE(review): this kernel subscripts and mutates tl tensors with
    # scalar indices (e.g. `accepted_req_ids_[i]`, `input_pos[0] = ...`),
    # which standard Triton does not support — confirm it targets a
    # fork/extension that does. tl.arange extents are also runtime values,
    # where upstream Triton requires power-of-two constexprs.
    chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
    # Fixed: original loaded with extent chidren_req_ids_len, but the loop
    # below reads indices up to accepted_req_ids_len - 1 (out of bounds
    # whenever the accepted list is longer than the children list).
    accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, accepted_req_ids_len))

    # Each request occupies two consecutive slots in the children batch.
    # Fixed: `/ 2` (float division) -> `// 2`; range needs an integer bound.
    for seq_id_idx in range(chidren_req_ids_len // 2):
        seq_id = chidren_req_ids_[2 * seq_id_idx]
        for i in range(accepted_req_ids_len):
            if seq_id == accepted_req_ids_[i]:
                # Fixed: original nested a tl.arange inside tl.arange's end
                # argument, which is malformed; load row i of the accepted
                # token matrix via a base offset instead.
                accepted_token_ids_ = tl.load(
                    accepted_token_ids + i * accepted_token_len
                    + tl.arange(0, accepted_token_len))
                # Count accepted tokens up to the first -1 padding sentinel.
                accepted_token_counter = 0
                for j in range(accepted_token_len):
                    if accepted_token_ids_[j] == -1:
                        break
                    accepted_token_counter += 1
                if accepted_token_counter == accepted_token_len:
                    # Everything accepted: feed the last two accepted tokens.
                    tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2),
                             accepted_token_ids_[-2:])
                else:
                    # Partial acceptance: pad slot 0 with 0 and put the last
                    # accepted token in slot 1.
                    tl.store(input_tokens + seq_id_idx * 2, 0)
                    tl.store(input_tokens + seq_id_idx * 2 + 1,
                             accepted_token_ids_[accepted_token_counter - 1])
                # NOTE(review): the original source's indentation was lost;
                # these updates are reconstructed as applying to every matched
                # request (after the if/else above) — confirm against intent.
                input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
                input_pos[0] = 0
                # Rewind the real position by the number of rejected tokens.
                input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
                tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                # Slot 0 is a pad token: mark its slot mapping invalid (-1).
                input_pos[0] = -1
                tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                # seq len = position + 1 for the real token, 1 for the pad.
                input_pos[0] = 1
                input_pos[1] = input_pos[1] + 1
                tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
    # Rebuild the prefix-sum of sequence lengths (seq_start_loc[0] == 0).
    seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
    # Fixed: tl.zero_like does not exist in the Triton API -> tl.zeros_like.
    seq_start_loc_ = tl.zeros_like(seq_start_loc)
    for i in range(input_tokens_len):
        seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
    tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """ModelInputForGPUBuilder that patches the freshly built model input
    with results of the previous zero-overhead sampler / speculative step,
    so sampled token ids never take a host round-trip.
    """

    def __init__(self, runner, finished_requests_ids=None):
        super().__init__(runner, finished_requests_ids)
        # Seq ids of every sequence added to the current batch, in the same
        # order the base builder lays them out (parallel to query_lens).
        self.req_ids: List[int] = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        # Reset per-batch state before the base builder prepares itself.
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        # Record the group's seq ids in insertion order.
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def _h2d_long(self, indices) -> torch.Tensor:
        """Async-copy an index list to the runner's device as int64."""
        return async_tensor_h2d(indices, torch.long, self.runner.device,
                                self.runner.pin_memory)

    def _copy_all_sampled_tokens(self, model_input, last_sampler) -> None:
        """Overwrite every input token with last step's sampled token id
        (row i, column 0 of the sampler output, one token per request)."""
        update_indices = self._h2d_long(list(range(len(self.req_ids))))
        model_input.input_tokens[update_indices] = \
            last_sampler.sampled_token_ids_tensor[update_indices, 0]

    def build(self) -> ModelInputForGPU:
        model_input = super().build()
        last_sampler = get_last_sampler()
        spec_step = get_spec_step()
        last_step = get_spec_last_step()
        if last_sampler is not None:
            if spec_step == SpecStepKind.KIND_DEFAULT:
                # Map each current seq id to its row in the last sampler
                # output (select) and to the first query position of that
                # sequence in the flattened input (update).
                update_indices: List[int] = []
                select_indices: List[int] = []
                query_idx = 0
                for i, seq_id in enumerate(self.req_ids):
                    for j, seq_id_ in enumerate(last_sampler.seq_ids):
                        if seq_id == seq_id_:
                            select_indices.append(j)
                            update_indices.append(query_idx)
                            break
                    query_idx += model_input.query_lens[i]
                if select_indices and \
                        last_sampler.sampled_token_ids_tensor is not None:
                    select_t = self._h2d_long(select_indices)
                    update_t = self._h2d_long(update_indices)
                    model_input.input_tokens[update_t] = \
                        last_sampler.sampled_token_ids_tensor[select_t, 0]
            if spec_step == SpecStepKind.OTHER_PROPOSAL:
                # Both branches of the original were byte-identical: copy the
                # last sampled token ids to the input tokens directly.
                # TODO: when last_step is FIRST_PROPOSAL, adjust the input
                # token count to 1 per request.
                if last_step in (SpecStepKind.OTHER_PROPOSAL,
                                 SpecStepKind.FIRST_PROPOSAL):
                    self._copy_all_sampled_tokens(model_input, last_sampler)

            if spec_step == SpecStepKind.SCORE_DECODE:
                # Scatter proposal tokens after each request's bonus token;
                # per-request layout is [bonus, p_0, ..., p_{k-1}].
                proposal_token_ids = get_proposal_token_ids()
                batch_size, proposal_len = proposal_token_ids.shape
                update_indices = [
                    i * (proposal_len + 1) + j + 1
                    for i in range(batch_size)
                    for j in range(proposal_len)
                ]
                update_t = self._h2d_long(update_indices)
                model_input.input_tokens[update_t] = proposal_token_ids.view(-1)
            if spec_step == SpecStepKind.FIRST_PROPOSAL:
                if last_step == SpecStepKind.PREFILL:
                    # TODO: when the last step is prefill, just update the
                    # input ids for the last sequence id only.
                    pass
                if last_step == SpecStepKind.SCORE_DECODE:
                    # Fix input ids, seq_lens and input_positions in place on
                    # the GPU using the accepted token ids from scoring.
                    accept_token_ids, accept_seq_ids = get_accepted_token_ids()
                    chidren_req_ids = self._h2d_long(self.req_ids)
                    grid = [1, 1, 1]
                    # NOTE(review): relative to the kernel's parameter names,
                    # seq_lens_meta receives attn_metadata.seq_lens_tensor and
                    # seq_lens_tensor receives attn_metadata.seq_lens — the
                    # two look swapped; confirm the ordering is intentional.
                    _update_input_tokens[grid](
                        accept_seq_ids, accept_seq_ids.shape[0],
                        accept_token_ids, accept_token_ids.shape[1],
                        chidren_req_ids, chidren_req_ids.shape[0],
                        model_input.input_tokens,
                        model_input.input_tokens.shape[0],
                        model_input.input_positions,
                        model_input.seq_lens,
                        model_input.attn_metadata.seq_lens_tensor,
                        model_input.attn_metadata.seq_lens,
                        model_input.attn_metadata.slot_mapping,
                        model_input.attn_metadata.seq_start_loc,
                        model_input.attn_metadata.context_lens_tensor,
                    )

        return model_input
|