Files
enginex-hygon-vllm/vllm/zero_overhead/model_runner.py
2026-01-09 15:09:53 +08:00

172 lines
8.7 KiB
Python

import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_proposal_token_ids, get_spec_last_step, get_spec_step
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
    accepted_req_ids,
    accepted_req_ids_len,
    accepted_token_ids,
    accepted_token_len,
    chidren_req_ids,
    chidren_req_ids_len,
    input_tokens,
    input_tokens_len,
    input_positions,
    seq_lens,
    seq_lens_meta,
    seq_lens_tensor,
    slot_mapping,
    seq_start_loc,
    context_lens_tensor,
):
    # Rewrite per-sequence inputs of the batch being built from speculative
    # accepted-token results.
    # For every even-indexed entry of chidren_req_ids, find the row of
    # accepted_req_ids holding the same id. Count that row's accepted tokens,
    # i.e. the entries before the first -1. Then overwrite the sequence's two
    # slots [2k, 2k + 1] in:
    #   input_tokens, input_positions, context_lens_tensor, slot_mapping,
    #   seq_lens, seq_lens_meta and seq_lens_tensor.
    # Finally seq_start_loc is rebuilt as a running sum of seq_lens.
    # The only call site in this file launches it with grid [1, 1, 1], i.e. a
    # single program instance.
    #
    # Parameters (pointer arguments are tensors at the call site; *_len are
    # sizes):
    #   accepted_req_ids, accepted_req_ids_len:
    #       seq ids of the accepted rows and their count.
    #   accepted_token_ids, accepted_token_len:
    #       accepted token ids and the row width (shape[1] at the call site).
    #       Presumably row-major (rows, accepted_token_len) padded with -1;
    #       TODO confirm.
    #   chidren_req_ids, chidren_req_ids_len:
    #       seq ids of the batch being built, in builder order.
    #   input_tokens, input_tokens_len:
    #       flat input token ids and their count.
    #   input_positions, context_lens_tensor, slot_mapping:
    #       per-slot metadata that is rewritten.
    #   seq_lens, seq_lens_meta, seq_lens_tensor:
    #       three destinations that receive identical values.
    #   seq_start_loc:
    #       start offsets; input_tokens_len + 1 entries are written.
    #
    # NOTE(review): this body does not look compilable by Triton as written.
    # TODO confirm against the Triton version in use. Specifically:
    #   * tl.arange requires constexpr (power-of-two) bounds, but every bound
    #     here is a runtime argument;
    #   * block tensors are indexed/assigned element-wise (x[i], x[-2:],
    #     x[0] = ...), which triton.language does not support;
    #   * `break` inside a loop may be rejected by the Triton frontend;
    #   * tl.zero_like is not a triton.language function (tl.zeros_like is).
    chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
    # NOTE(review): loads chidren_req_ids_len entries, not accepted_req_ids_len;
    # this looks like a copy-paste slip.
    accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, chidren_req_ids_len))
    # NOTE(review): `/` is true division, giving a float loop bound; `// 2`
    # was probably intended.
    for seq_id_idx in range(chidren_req_ids_len / 2):
        # Only even-indexed child ids are looked up; each one owns two slots.
        seq_id = chidren_req_ids_[2 * seq_id_idx]
        for i in range(accepted_req_ids_len):
            if seq_id == accepted_req_ids_[i]:
                # NOTE(review): the second tl.arange argument is itself a
                # tensor. Probably intended:
                #   i * accepted_token_len + tl.arange(0, accepted_token_len)
                accepted_token_ids_ = tl.load(accepted_token_ids + tl.arange(i * accepted_token_len, tl.arange(0, accepted_token_len)))
                # Number of entries before the first -1 in this row.
                accepted_token_counter = 0
                for j in range(accepted_token_len):
                    if accepted_token_ids_[j] == -1:
                        break
                    accepted_token_counter += 1
                if accepted_token_counter == accepted_token_len:
                    # Row has no -1: write its last two token ids.
                    tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
                else:
                    # Otherwise write 0, then the last non -1 token id.
                    tl.store(input_tokens + seq_id_idx * 2, 0)
                    tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
                # Positions become [0, pos[1] - (accepted_req_ids_len - count)].
                # NOTE(review): this subtracts from accepted_req_ids_len (the
                # number of accepted rows), not accepted_token_len (the row
                # width). Looks suspicious; confirm.
                input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
                input_pos[0] = 0
                input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
                tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                # context_lens_tensor receives the same two values.
                tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                # slot_mapping receives [-1, adjusted position].
                # NOTE(review): this stores a position; if slot_mapping holds
                # cache slot indices, this is likely wrong. Confirm.
                input_pos[0] = -1
                tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                # All three seq-len destinations receive [1, position + 1].
                input_pos[0] = 1
                input_pos[1] = input_pos[1] + 1
                tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
    # Rebuild seq_start_loc as a running sum:
    #   seq_start_loc[0] = 0
    #   seq_start_loc[i + 1] = seq_start_loc[i] + seq_lens[i]
    # NOTE(review): input_tokens_len is used as the number of seq_lens
    # entries here; confirm that matches the batch layout.
    seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
    seq_start_loc_ = tl.zero_like(seq_start_loc)
    for i in range(input_tokens_len):
        seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
    tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """``ModelInputForGPUBuilder`` that patches token ids after ``build()``.

    The builder records the sequence id of every sequence added to the batch
    in ``self.req_ids``. After the base class has built the model input,
    :meth:`build` may overwrite entries of ``model_input.input_tokens``. The
    replacement ids come from the last sampler (``get_last_sampler()``) or
    from speculative-decoding proposals. Which source is used depends on
    ``get_spec_step()`` and ``get_spec_last_step()``.
    """

    def __init__(self, runner, finished_requests_ids=None):
        super().__init__(runner, finished_requests_ids)
        # Sequence ids of the current batch, in the order add_seq_group()
        # received them. Cleared by prepare().
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Clear the recorded sequence ids, then run the base ``prepare``."""
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record the group's sequence ids, then add it via the base class."""
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        """Build the model input and patch ``input_tokens`` for the current step.

        Returns:
            The ``ModelInputForGPU`` produced by the base builder.

            When a last sampler exists, ``input_tokens`` is modified in place
            according to the current and previous speculative step kinds.
            For FIRST_PROPOSAL after SCORE_DECODE, positions and length
            metadata are modified too.
        """
        model_input = super().build()
        # All three getters are read unconditionally and in this order, as
        # before, in case any of them has side effects.
        last_sampler = get_last_sampler()
        spec_step = get_spec_step()
        last_step = get_spec_last_step()
        if last_sampler is None:
            return model_input

        if spec_step == SpecStepKind.KIND_DEFAULT:
            self._patch_from_last_sample(model_input, last_sampler)
        elif spec_step == SpecStepKind.OTHER_PROPOSAL:
            # Both previous-step kinds were handled by identical code.
            if last_step in (SpecStepKind.OTHER_PROPOSAL,
                             SpecStepKind.FIRST_PROPOSAL):
                # TODO: after FIRST_PROPOSAL, adjust the number of input
                # tokens to one per request.
                self._copy_first_sampled_tokens(model_input, last_sampler)
        elif spec_step == SpecStepKind.SCORE_DECODE:
            self._insert_proposal_tokens(model_input)
        elif spec_step == SpecStepKind.FIRST_PROPOSAL:
            if last_step == SpecStepKind.SCORE_DECODE:
                # TODO: fix input ids, seq_lens and input_positions using the
                # accepted token ids.
                self._apply_accepted_tokens(model_input)
            # TODO: when the last step was PREFILL, update only the input ids
            # of the last sequence id.
        return model_input

    def _patch_from_last_sample(self, model_input, last_sampler) -> None:
        """Copy each sequence's last sampled token into its first query slot."""
        # First row at which each seq id appears in the sampler output.
        # A dict replaces the previous O(len(req_ids) * len(seq_ids)) scan.
        # NOTE(review): this assumes sampler seq ids hash equal to the
        # seq_data keys (plain ints). Confirm.
        row_of_seq = {}
        for row, seq_id in enumerate(last_sampler.seq_ids):
            row_of_seq.setdefault(seq_id, row)

        select_indices = []
        update_indices = []
        # Flat offset of sequence i's first token, assuming input_tokens is
        # laid out per sequence following query_lens.
        query_idx = 0
        for i, seq_id in enumerate(self.req_ids):
            row = row_of_seq.get(seq_id)
            if row is not None:
                select_indices.append(row)
                update_indices.append(query_idx)
            query_idx += model_input.query_lens[i]

        sampled = last_sampler.sampled_token_ids_tensor
        if select_indices and sampled is not None:
            device = self.runner.device
            pin_memory = self.runner.pin_memory
            select = async_tensor_h2d(select_indices, torch.long, device,
                                      pin_memory)
            update = async_tensor_h2d(update_indices, torch.long, device,
                                      pin_memory)
            model_input.input_tokens[update] = sampled[select, 0]

    def _copy_first_sampled_tokens(self, model_input, last_sampler) -> None:
        """Set ``input_tokens[i] = sampled_token_ids_tensor[i, 0]`` for every batch entry."""
        # Built on device directly; equivalent to the previous H2D-copied
        # Python list.
        indices = torch.arange(len(self.req_ids), dtype=torch.long,
                               device=self.runner.device)
        model_input.input_tokens[indices] = (
            last_sampler.sampled_token_ids_tensor[indices, 0])

    def _insert_proposal_tokens(self, model_input) -> None:
        """Scatter proposal tokens into slots 1..k of each request's k + 1 slots."""
        proposal_token_ids = get_proposal_token_ids()
        shape = proposal_token_ids.shape
        batch_size, proposal_len = shape[0], shape[1]
        stride = proposal_len + 1
        device = self.runner.device
        # Target index is i * (proposal_len + 1) + j + 1, in i-major order.
        # Built on device instead of with a Python double loop plus an H2D copy.
        row_starts = torch.arange(batch_size, dtype=torch.long,
                                  device=device).unsqueeze(1) * stride
        cols = torch.arange(1, stride, dtype=torch.long, device=device)
        update_indices = (row_starts + cols).reshape(-1)
        # reshape (unlike view) also accepts a non-contiguous proposal tensor.
        model_input.input_tokens[update_indices] = proposal_token_ids.reshape(-1)

    def _apply_accepted_tokens(self, model_input) -> None:
        """Launch ``_update_input_tokens`` to rewrite inputs from accepted tokens."""
        accept_token_ids, accept_seq_ids = get_accepted_token_ids()
        children_req_ids = async_tensor_h2d(self.req_ids, torch.long,
                                            self.runner.device,
                                            self.runner.pin_memory)
        attn_metadata = model_input.attn_metadata
        grid = (1, 1, 1)
        # Arguments are passed in the kernel's declared parameter order.
        # Previously seq_lens_tensor and seq_lens were passed in swapped
        # slots.
        # NOTE(review): the kernel body indexes block tensors element-wise
        # and uses runtime tl.arange bounds. Also, model_input.seq_lens and
        # attn_metadata.seq_lens appear to be Python lists upstream rather
        # than tensors. This launch therefore likely fails; verify before
        # relying on this path.
        _update_input_tokens[grid](
            accept_seq_ids, accept_seq_ids.shape[0],
            accept_token_ids, accept_token_ids.shape[1],
            children_req_ids, children_req_ids.shape[0],
            model_input.input_tokens, model_input.input_tokens.shape[0],
            model_input.input_positions,
            model_input.seq_lens,            # -> seq_lens
            attn_metadata.seq_lens,          # -> seq_lens_meta
            attn_metadata.seq_lens_tensor,   # -> seq_lens_tensor
            attn_metadata.slot_mapping,
            attn_metadata.seq_start_loc,
            attn_metadata.context_lens_tensor,
        )