init src 0.9.2
vllm/zero_overhead/model_runner.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import itertools
from typing import List, Optional, Set

import torch
import triton
import triton.language as tl

from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.utils import (SpecStepKind, get_accepted_token_ids,
                                      get_proposal_token_ids,
                                      get_spec_last_step, get_spec_step)
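
# The SpecStepKind members referenced below are defined in
# vllm.zero_overhead.utils, which is not part of this diff; they are sketched
# here from how they are used (an assumption, for orientation only):
#
#   KIND_DEFAULT    - regular (non-speculative) decode step
#   PREFILL         - prompt prefill step
#   FIRST_PROPOSAL  - first draft step of a speculative round
#   OTHER_PROPOSAL  - subsequent draft steps
#   SCORE_DECODE    - target-model scoring of the drafted tokens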


@triton.jit
def _update_input_tokens(
    accepted_req_ids,
    # The *_len arguments feed tl.arange, so they must be compile-time
    # constants.
    accepted_req_ids_len: tl.constexpr,
    accepted_token_ids,
    accepted_token_len: tl.constexpr,
    children_req_ids,
    children_req_ids_len: tl.constexpr,
    input_tokens,
    input_tokens_len: tl.constexpr,
    input_positions,
    seq_lens,
    seq_lens_meta,
    seq_lens_tensor,
    slot_mapping,
    seq_start_loc,
    context_lens_tensor,
):
    children_req_ids_ = tl.load(children_req_ids + tl.arange(0, children_req_ids_len))
    accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, accepted_req_ids_len))

    # Each request owns two consecutive slots in the flattened input buffers,
    # so requests are walked at stride 2.
    for seq_id_idx in range(children_req_ids_len // 2):
        seq_id = children_req_ids_[2 * seq_id_idx]
        for i in range(accepted_req_ids_len):
            if seq_id == accepted_req_ids_[i]:
                accepted_token_ids_ = tl.load(accepted_token_ids + i * accepted_token_len + tl.arange(0, accepted_token_len))
                # Count the accepted tokens; -1 marks the first rejected slot.
                accepted_token_counter = 0
                for j in range(accepted_token_len):
                    if accepted_token_ids_[j] == -1:
                        break
                    accepted_token_counter += 1
                if accepted_token_counter == accepted_token_len:
                    # Every proposal was accepted: feed the last two accepted tokens.
                    tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
                else:
                    # Partial acceptance: pad the first slot with 0 and keep
                    # the last accepted token in the second slot.
                    tl.store(input_tokens + seq_id_idx * 2, 0)
                    tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
                # Roll the positions back by the number of rejected tokens.
                input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
                input_pos[0] = 0
                input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
                tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                # A slot mapping of -1 marks the padded slot as unused.
                input_pos[0] = -1
                tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                input_pos[0] = 1
                input_pos[1] = input_pos[1] + 1
                tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
    # Rebuild the prefix sum of sequence lengths used as seq_start_loc.
    seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
    seq_start_loc_ = tl.zeros([input_tokens_len + 1], dtype=tl.int32)
    for i in range(input_tokens_len):
        seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
    tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
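
# Illustrative launch of _update_input_tokens on dummy tensors (a sketch under
# assumed shapes, kept as a comment; the real launch happens in build() below).
# Each request occupies two slots, and rejected proposal slots in
# accepted_token_ids are filled with -1:
#
#   accepted_req_ids = torch.tensor([7], device="cuda")               # one request
#   accepted_token_ids = torch.tensor([[11, 12, -1]], device="cuda")  # 2 of 3 accepted
#   children_req_ids = torch.tensor([7, 7], device="cuda")            # 2 slots/request
#   _update_input_tokens[(1, 1, 1)](
#       accepted_req_ids, 1, accepted_token_ids, 3,
#       children_req_ids, 2, input_tokens, 2, input_positions,
#       seq_lens, seq_lens_meta, seq_lens_tensor, slot_mapping,
#       seq_start_loc, context_lens_tensor)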


class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):

    def __init__(self, runner, finished_requests_ids=None):
        super().__init__(runner, finished_requests_ids)
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        # Record sequence ids in arrival order so build() can match them
        # against the ids captured by the last sampler.
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        model_input = super().build()
        last_sampler = get_last_sampler()
        spec_step = get_spec_step()
        last_step = get_spec_last_step()
        if last_sampler is not None:
            if spec_step == SpecStepKind.KIND_DEFAULT:
                # Overwrite the first input token of each request with the
                # token sampled in the previous step, matched by sequence id.
                update_indices = []
                select_indices = []
                query_idx = 0
                for i, seq_id in enumerate(self.req_ids):
                    for j, seq_id_ in enumerate(last_sampler.seq_ids):
                        if seq_id == seq_id_:
                            select_indices.append(j)
                            update_indices.append(query_idx)
                            break
                    query_idx += model_input.query_lens[i]
                if len(select_indices) > 0 and last_sampler.sampled_token_ids_tensor is not None:
                    select_indices = async_tensor_h2d(select_indices, torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
                    update_indices = async_tensor_h2d(update_indices, torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
                    model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
            if spec_step == SpecStepKind.OTHER_PROPOSAL:
                if last_step == SpecStepKind.OTHER_PROPOSAL:  # Copy the last sampled token ids into the input tokens directly.
                    update_indices = list(range(len(self.req_ids)))
                    update_indices = async_tensor_h2d(update_indices, torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
                    model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
                if last_step == SpecStepKind.FIRST_PROPOSAL:  # TODO: adjust the number of input tokens to 1 per request.
                    update_indices = list(range(len(self.req_ids)))
                    update_indices = async_tensor_h2d(update_indices, torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
                    model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]

            if spec_step == SpecStepKind.SCORE_DECODE:
                # Scatter the draft proposals into every slot after the root
                # token of each request.
                proposal_token_ids = get_proposal_token_ids()
                batch_size, proposal_len = proposal_token_ids.shape
                update_indices = []
                for i in range(batch_size):
                    for j in range(proposal_len):
                        update_indices.append(i * (proposal_len + 1) + j + 1)
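                # Worked example (illustrative): with batch_size=2 and
                # proposal_len=3, each request occupies proposal_len + 1 = 4
                # slots (one root token plus the proposals), so update_indices
                # is [1, 2, 3, 5, 6, 7]: every slot except slot 0 of each
                # request, which already holds the previously sampled token.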
                update_indices = async_tensor_h2d(update_indices, torch.long,
                                                  self.runner.device,
                                                  self.runner.pin_memory)
                model_input.input_tokens[update_indices] = proposal_token_ids.view(-1)
            if spec_step == SpecStepKind.FIRST_PROPOSAL:
                if last_step == SpecStepKind.PREFILL:  # TODO: when the last step is prefill, update the input ids for the last sequence id only.
                    pass
                if last_step == SpecStepKind.SCORE_DECODE:  # TODO: when the last step is score decode, fix the input ids, seq_lens, and input_positions using the accepted token ids.
                    accept_token_ids, accept_seq_ids = get_accepted_token_ids()

                    children_req_ids = async_tensor_h2d(self.req_ids, torch.long,
                                                        self.runner.device,
                                                        self.runner.pin_memory)
                    grid = [1, 1, 1]
                    _update_input_tokens[grid](
                        accept_seq_ids, accept_seq_ids.shape[0],
                        accept_token_ids, accept_token_ids.shape[1],
                        children_req_ids, children_req_ids.shape[0],
                        model_input.input_tokens, model_input.input_tokens.shape[0],
                        model_input.input_positions,
                        model_input.seq_lens,
                        model_input.attn_metadata.seq_lens_tensor,
                        model_input.attn_metadata.seq_lens,
                        model_input.attn_metadata.slot_mapping,
                        model_input.attn_metadata.seq_start_loc,
                        model_input.attn_metadata.context_lens_tensor,
                    )

        return model_input
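
# Illustrative wiring (an assumption; not part of this diff): a runner would
# adopt this builder by overriding the builder class of vLLM's GPU model
# runner, e.g.
#
#   class ZeroOverheadModelRunner(ModelRunner):
#       _builder_cls = ZeroOverheadModelInputForGpuBuilder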