Files
2026-01-09 15:09:53 +08:00

72 lines
1.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from enum import Enum
import os
import torch
import vllm.envs as envs
# True only when the VLLM_ZERO_NO_THREAD environment variable is exactly '1';
# evaluated once, at import time.
zero_no_thread = os.getenv('VLLM_ZERO_NO_THREAD') == '1'
def is_zero_no_thread():
    """Report whether the no-thread variant of zero-overhead mode is enabled.

    Returns:
        ``False`` when the module-level ``zero_no_thread`` flag is off;
        otherwise whatever ``envs.VLLM_ZERO_OVERHEAD`` currently holds.
    """
    # Same short-circuit semantics as ``zero_no_thread and envs...``: the
    # envs attribute is only read when the module-level flag is set.
    if not zero_no_thread:
        return zero_no_thread
    return envs.VLLM_ZERO_OVERHEAD
class SpecStepKind(Enum):
    """Kinds of step that the zero-overhead spec context can be in.

    NOTE(review): the member names suggest phases of speculative decoding
    (prefill, draft proposals, scoring) -- confirm against the callers.
    """

    # Initial value of both ``step_kind`` and ``last_step`` in the context.
    KIND_DEFAULT = 0

    PREFILL = 1

    FIRST_PROPOSAL = 2
    OTHER_PROPOSAL = 3

    SCORE_DECODE = 4
class ZeroOverheadSpecContext:
    """Mutable record of the current/previous step kind and recorded data.

    Every attribute is a plain slot: nothing is validated or copied here.
    """

    def __init__(self):
        # Both step markers start out at the default kind.
        initial_kind = SpecStepKind.KIND_DEFAULT
        self.step_kind = initial_kind
        self.last_step = initial_kind

        # Nothing has been recorded yet for proposals ...
        self.proposal_lens_list = self.proposal_token_ids = None
        # ... nor for accepted tokens and their sequence ids.
        self.accepted_token_ids = self.accepted_seq_ids = None
# Module-level context instance read and written by the accessor functions
# defined in this module.
spec_context = ZeroOverheadSpecContext()
def set_spec_step(_step):
    """Make ``_step`` the current step kind, keeping the old one as ``last_step``.

    Args:
        _step: The new step kind to store on the module-level context.
    """
    ctx = spec_context
    # The right-hand side is evaluated first, so ``last_step`` receives the
    # step kind that was current before this call.
    ctx.last_step, ctx.step_kind = ctx.step_kind, _step
def get_spec_step():
    """Return the step kind currently stored on the module-level context."""
    current_kind = spec_context.step_kind
    return current_kind
def get_spec_last_step():
    """Return the step kind that was current before the latest ``set_spec_step``."""
    previous_kind = spec_context.last_step
    return previous_kind
def record_proposal_lens_list(list):
    """Store ``list`` as the proposal-lengths record on the module-level context.

    Args:
        list: Value to store; it is kept by reference, not copied. The
            parameter name shadows the builtin but is retained so existing
            keyword callers keep working.
    """
    ctx = spec_context
    ctx.proposal_lens_list = list
def get_proposal_lens_list():
    """Return the last recorded proposal-lengths value (``None`` if never set)."""
    recorded_lens = spec_context.proposal_lens_list
    return recorded_lens
def record_proposal_token_ids(tensor):
    """Store ``tensor`` as the proposal token ids on the module-level context.

    Args:
        tensor: Value to store; kept by reference, not copied.
    """
    ctx = spec_context
    ctx.proposal_token_ids = tensor
def get_proposal_token_ids():
    """Return the last recorded proposal token ids (``None`` if never set)."""
    recorded_ids = spec_context.proposal_token_ids
    return recorded_ids
def record_accepted_token_ids(tensor, seq_ids):
    """Store the accepted token ids and their sequence ids on the context.

    Args:
        tensor: Accepted token ids; kept by reference, not copied.
        seq_ids: Sequence ids recorded alongside ``tensor``.
    """
    ctx = spec_context
    ctx.accepted_token_ids = tensor
    ctx.accepted_seq_ids = seq_ids
def get_accepted_token_ids():
    """Return the recorded ``(accepted_token_ids, accepted_seq_ids)`` pair.

    Both elements are ``None`` until ``record_accepted_token_ids`` is called.
    """
    ctx = spec_context
    token_ids = ctx.accepted_token_ids
    seq_ids = ctx.accepted_seq_ids
    return token_ids, seq_ids
# Zero-overhead scheduling does not run inference on the default stream, in
# order to avoid the stream-synchronization problem caused by memory
# allocations that the runtime introduces.
# Cache of the dedicated stream created for each target device.
alloc_stream = {}


def zero_overhead_stream(target_device):
    """Return the dedicated CUDA stream for ``target_device``.

    A stream is created with ``torch.cuda.Stream`` the first time a device is
    requested and cached in ``alloc_stream``; later calls with the same device
    return that same stream object.

    Args:
        target_device: Device passed to ``torch.cuda.Stream(device=...)`` and
            used as the cache key, so it must be hashable.

    Returns:
        The cached stream object for ``target_device``.
    """
    # Membership is tested on the dict itself; ``.keys()`` adds nothing.
    if target_device not in alloc_stream:
        alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
    return alloc_stream[target_device]