PrefillAdder abstraction (#968)
@@ -17,6 +17,9 @@ limitations under the License.
import random
from collections import defaultdict
from contextlib import contextmanager

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch


class PolicyScheduler:
@@ -83,3 +86,122 @@ class PolicyScheduler:
        for child in childs:
            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
        q.extend(last_node_to_reqs[cur_node])


class PrefillAdder:
    def __init__(
        self,
        tree_cache,
        rem_total_tokens,
        rem_input_tokens,
        rem_chunk_tokens,
    ):
        self.tree_cache = tree_cache
        self.rem_total_tokens = rem_total_tokens
        self.rem_input_tokens = rem_input_tokens
        self.rem_chunk_tokens = rem_chunk_tokens

        self.can_run_list = []
        self.new_inflight_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

    def no_remaining_tokens(self):
        return (
            self.rem_total_tokens <= 0
            or self.rem_input_tokens <= 0
            or (
                self.rem_chunk_tokens <= 0
                if self.rem_chunk_tokens is not None
                else False
            )
        )

    def remove_running_tokens(
        self, running_batch: ScheduleBatch, new_token_ratio: float
    ):
        self.rem_total_tokens -= sum(
            [
                (r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
                for r in running_batch.reqs
            ]
        )

    def _prefill_one_req(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
        self.rem_total_tokens -= extend_input_len + max_new_tokens
        self.rem_input_tokens -= extend_input_len
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= extend_input_len

        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_inflight_req(self, req: Req):
        req.input_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
        self.can_run_list.append(req)

        self._prefill_one_req(
            len(req.prefix_indices),
            req.extend_input_len,
            req.sampling_params.max_new_tokens if not truncated else 0,
        )

        # Return if chunked prefill not finished
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node):
        try:
            delta = self.tree_cache.inc_lock_ref(last_node)
            self.rem_total_tokens += delta
            yield None
        finally:
            delta = self.tree_cache.dec_lock_ref(last_node)
            self.rem_total_tokens += delta

    def add_one_req(self, req: Req):
        total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)

        if total_tokens >= self.rem_total_tokens:
            return False

        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
            return False

        with self._lock_node(req.last_node):
            if total_tokens > self.rem_total_tokens:
                return False

            if (
                self.rem_chunk_tokens is None
                or input_tokens <= self.rem_chunk_tokens
                or (req.return_logprob and req.normalized_prompt_logprob is None)
            ):
                # Non-chunked prefill
                self.can_run_list.append(req)
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(
                    prefix_len, input_tokens, req.sampling_params.max_new_tokens
                )
            else:
                # Chunked prefill
                trunc_len = self.rem_chunk_tokens
                if trunc_len == 0:
                    return False

                req.extend_input_len = trunc_len
                req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len]
                self.can_run_list.append(req)
                self.new_inflight_req = req
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(prefix_len, trunc_len, 0)

        return True

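For context, a minimal sketch of how a scheduler loop might drive this abstraction; the queue, cache, and budget values below are illustrative assumptions, and the real call sites appear in the ModelTpServer hunk further down.

# Hypothetical driver loop (names and budget values are assumptions for illustration).
adder = PrefillAdder(
    tree_cache,              # radix tree cache used for prefix reuse
    rem_total_tokens=16384,  # free KV-cache tokens, including evictable ones
    rem_input_tokens=8192,   # max prefill tokens allowed in one batch
    rem_chunk_tokens=2048,   # chunked-prefill budget, or None to disable chunking
)
if running_batch is not None:
    adder.remove_running_tokens(running_batch, new_token_ratio=1.0)

for req in waiting_queue:
    if not adder.add_one_req(req) or adder.no_remaining_tokens():
        break

can_run_list = adder.can_run_list      # requests admitted to this prefill batch
inflight_req = adder.new_inflight_req  # set when the last request was chunked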
@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.policy_scheduler import PolicyScheduler
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
@@ -377,151 +377,57 @@ class ModelTpServer:
        # Get priority queue
        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
        )
        if self.running_batch:
            available_size -= sum(
                [
                    (r.sampling_params.max_new_tokens - len(r.output_ids))
                    * self.new_token_ratio
                    for r in self.running_batch.reqs
                ]
            )

        # Handle the current inflight request
        take_inflight = 0
        if self.current_inflight_req:
            take_inflight = 1
            r = self.current_inflight_req
            r.input_ids = r.origin_input_ids + r.output_ids
            truncated = (
                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
            )
            r.extend_input_len = min(
                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
            )
            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
            can_run_list.append(r)
        if self.running_batch is not None:
            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)

            if not truncated:
                # Finish inflight
                self.current_inflight_req = None
                new_batch_total_tokens += (
                    r.extend_input_len + r.sampling_params.max_new_tokens
                )
                new_batch_input_tokens += r.extend_input_len
            else:
                new_batch_total_tokens += r.extend_input_len
                new_batch_input_tokens += r.extend_input_len
        has_inflight = self.current_inflight_req is not None
        if self.current_inflight_req is not None:
            self.current_inflight_req = adder.add_inflight_req(
                self.current_inflight_req
            )

        for req in self.waiting_queue:
            if req.return_logprob and req.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            res = adder.add_one_req(req)
            if (
                req.extend_input_len
                + req.sampling_params.max_new_tokens
                + new_batch_total_tokens
                < available_size
                and (
                    req.extend_input_len + new_batch_input_tokens
                    <= self.max_prefill_tokens
                    or len(can_run_list) == 0
                )
                not res
                or adder.no_remaining_tokens()
                or running_bs + len(adder.can_run_list) >= self.max_running_requests
            ):
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len
                    + req.sampling_params.max_new_tokens
                    + new_batch_total_tokens
                    < available_size
                ):
                    # Undo locking
                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    if (
                        self.chunked_prefill_size is None
                        or (
                            new_batch_input_tokens + req.extend_input_len
                            <= self.chunked_prefill_size
                        )
                        or (
                            req.return_logprob and req.normalized_prompt_logprob is None
                        )
                    ):
                        can_run_list.append(req)
                        new_batch_total_tokens += (
                            req.extend_input_len + req.sampling_params.max_new_tokens
                        )
                        new_batch_input_tokens += req.extend_input_len
                    else:
                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens

                        if trunc_len <= 0:
                            # Undo locking
                            delta = self.tree_cache.dec_lock_ref(req.last_node)
                            available_size += delta
                            break

                        req.extend_input_len = trunc_len
                        req.input_ids = req.input_ids[
                            : len(req.prefix_indices) + req.extend_input_len
                        ]
                        can_run_list.append(req)
                        self.current_inflight_req = req
                        new_batch_input_tokens += req.extend_input_len
                        new_batch_total_tokens += req.extend_input_len
                        break
            else:
                break

            if running_bs + len(can_run_list) >= self.max_running_requests:
                break
        can_run_list = adder.can_run_list

        if adder.new_inflight_req is not None:
            assert self.current_inflight_req is None
            self.current_inflight_req = adder.new_inflight_req

        if len(can_run_list) == 0:
            return None

        # Print stats
        if self.tp_rank == 0:
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
                adder.log_input_tokens + adder.log_hit_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
                f"[gpu={self.gpu_id}] Prefill batch. "
                f"#new-seq: {len(can_run_list)}, "
                f"#new-token: {new_batch_input_tokens}, "
                f"#cached-token: {hit_tokens}, "
                f"#new-token: {adder.log_input_tokens}, "
                f"#cached-token: {adder.log_hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                f"#running-req: {running_bs}, "
                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
            )

        # Return the new batch

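As a quick sanity check on the logging above, a minimal sketch of how the reported hit rate follows from the adder's counters; the token counts are hypothetical, not taken from the commit.

# Hypothetical batch: 1200 tokens served from the prefix cache and 4800 newly
# computed input tokens, starting from empty metrics.
tree_cache_metrics = {"total": 0.0, "hit": 0.0}
log_hit_tokens, log_input_tokens = 1200, 4800
tree_cache_metrics["total"] += (log_input_tokens + log_hit_tokens) / 10**9
tree_cache_metrics["hit"] += log_hit_tokens / 10**9
hit_rate = tree_cache_metrics["hit"] / tree_cache_metrics["total"]
print(f"cache hit rate: {100.0 * hit_rate:.2f}%")  # -> 20.00%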
@@ -130,7 +130,7 @@ class ModelRunner:
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_flash_infer()
        self.init_flashinfer()

        # Capture cuda graphs
        self.init_cuda_graphs()
@@ -287,7 +287,7 @@ class ModelRunner:
        c = a @ b
        return c

    def init_flash_infer(self):
    def init_flashinfer(self):
        if self.server_args.disable_flashinfer:
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = None

@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention

@@ -46,8 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(
@@ -394,14 +391,6 @@ class Qwen2MoeForCausalLM(nn.Module):
        )
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
