PrefillAdder abstraction (#968)
This commit is contained in:
@@ -17,6 +17,9 @@ limitations under the License.
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
|
|
||||||
|
|
||||||
class PolicyScheduler:
|
class PolicyScheduler:
|
||||||
@@ -83,3 +86,122 @@ class PolicyScheduler:
|
|||||||
for child in childs:
|
for child in childs:
|
||||||
self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
|
self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
|
||||||
q.extend(last_node_to_reqs[cur_node])
|
q.extend(last_node_to_reqs[cur_node])
|
||||||
|
|
||||||
|
|
||||||
|
class PrefillAdder:
    """Admit requests into one prefill batch while tracking token budgets.

    Three budgets are decremented as requests are accepted:

    * ``rem_total_tokens``: reduced by the input tokens to prefill plus the
      ``max_new_tokens`` reserved for the request.
    * ``rem_input_tokens``: reduced by the input tokens to prefill.
    * ``rem_chunk_tokens``: reduced by the input tokens to prefill; ``None``
      means no chunk limit is applied.

    Accepted requests are collected in ``can_run_list``. A request that was
    truncated to fit the chunk budget is also stored in ``new_inflight_req``.
    ``log_hit_tokens`` / ``log_input_tokens`` accumulate prefix lengths and
    prefilled input lengths for the caller's statistics.
    """

    def __init__(
        self,
        tree_cache,
        rem_total_tokens,
        rem_input_tokens,
        rem_chunk_tokens,
    ):
        """Store the cache handle and the initial budgets.

        Args:
            tree_cache: Object exposing ``inc_lock_ref(node)`` and
                ``dec_lock_ref(node)``; each returns a delta that is added to
                ``rem_total_tokens``.
            rem_total_tokens: Initial total-token budget.
            rem_input_tokens: Initial input-token budget.
            rem_chunk_tokens: Initial chunk budget, or ``None`` for no limit.
        """
        self.tree_cache = tree_cache
        self.rem_total_tokens = rem_total_tokens
        self.rem_input_tokens = rem_input_tokens
        self.rem_chunk_tokens = rem_chunk_tokens

        self.can_run_list = []
        self.new_inflight_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

    def no_remaining_tokens(self):
        """Return True when any active budget has dropped to zero or below."""
        return (
            self.rem_total_tokens <= 0
            or self.rem_input_tokens <= 0
            or (self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0)
        )

    def remove_running_tokens(
        self, running_batch: "ScheduleBatch", new_token_ratio: float
    ):
        """Reserve total-token budget for requests that are already running.

        For every request in ``running_batch.reqs`` the not-yet-generated
        tokens (``max_new_tokens - len(output_ids)``) are scaled by
        ``new_token_ratio`` and subtracted from ``rem_total_tokens``.
        """
        self.rem_total_tokens -= sum(
            (r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
            for r in running_batch.reqs
        )

    def _prefill_one_req(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
        """Charge one accepted request against the budgets and the log counters."""
        self.rem_total_tokens -= extend_input_len + max_new_tokens
        self.rem_input_tokens -= extend_input_len
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= extend_input_len

        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_inflight_req(self, req: "Req"):
        """Continue a partially prefilled request in this batch.

        The request's ``input_ids`` are rebuilt from ``origin_input_ids`` and
        ``output_ids``; the part beyond ``prefix_indices`` is clipped to the
        chunk budget when one is set.

        Returns:
            ``req`` if it was clipped again (prefill still unfinished),
            otherwise ``None``.
        """
        req.input_ids = req.origin_input_ids + req.output_ids
        prefix_len = len(req.prefix_indices)
        pending = len(req.input_ids) - prefix_len

        # A chunk budget of None means "no limit", as in the other methods;
        # comparing an int against None would raise TypeError.
        truncated = (
            self.rem_chunk_tokens is not None and pending > self.rem_chunk_tokens
        )
        if truncated:
            pending = self.rem_chunk_tokens

        req.extend_input_len = pending
        req.input_ids = req.input_ids[: prefix_len + pending]
        self.can_run_list.append(req)

        # Output tokens are only reserved once the whole prompt is covered.
        self._prefill_one_req(
            prefix_len,
            pending,
            0 if truncated else req.sampling_params.max_new_tokens,
        )

        # Return if chunked prefill not finished
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node):
        """Temporarily lock ``last_node`` and apply the lock deltas to the budget.

        The lock is taken before entering ``try`` so that a failing
        ``inc_lock_ref`` is never followed by an unmatched ``dec_lock_ref``.
        """
        delta = self.tree_cache.inc_lock_ref(last_node)
        self.rem_total_tokens += delta
        try:
            yield None
        finally:
            delta = self.tree_cache.dec_lock_ref(last_node)
            self.rem_total_tokens += delta

    def add_one_req(self, req: "Req"):
        """Try to admit ``req`` into the batch.

        Returns:
            ``True`` if the request was appended to ``can_run_list`` (whole or
            truncated to the remaining chunk budget), ``False`` if it does not
            fit the current budgets.
        """
        total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)

        if total_tokens >= self.rem_total_tokens:
            return False

        # The input budget may be exceeded only by the first request of a batch.
        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
            return False

        with self._lock_node(req.last_node):
            # Re-check with the lock delta applied to the total budget.
            if total_tokens > self.rem_total_tokens:
                return False

            if (
                self.rem_chunk_tokens is None
                or input_tokens <= self.rem_chunk_tokens
                or (req.return_logprob and req.normalized_prompt_logprob is None)
            ):
                # Non-chunked prefill
                self.can_run_list.append(req)
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(
                    prefix_len, input_tokens, req.sampling_params.max_new_tokens
                )
            else:
                # Chunked prefill
                trunc_len = self.rem_chunk_tokens
                # `<= 0` rather than `== 0`: a negative budget must not become
                # a negative extend length or a negative slice bound.
                if trunc_len <= 0:
                    return False

                req.extend_input_len = trunc_len
                req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len]
                self.can_run_list.append(req)
                self.new_inflight_req = req
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(prefix_len, trunc_len, 0)

        return True
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.policy_scheduler import PolicyScheduler
|
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
FINISH_ABORT,
|
FINISH_ABORT,
|
||||||
BaseFinishReason,
|
BaseFinishReason,
|
||||||
@@ -377,151 +377,57 @@ class ModelTpServer:
|
|||||||
# Get priority queue
|
# Get priority queue
|
||||||
self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
|
self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
|
||||||
|
|
||||||
# Add requests if there is available space
|
adder = PrefillAdder(
|
||||||
can_run_list = []
|
self.tree_cache,
|
||||||
new_batch_total_tokens = 0
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
|
||||||
new_batch_input_tokens = 0
|
self.max_prefill_tokens,
|
||||||
|
self.chunked_prefill_size,
|
||||||
available_size = (
|
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
|
||||||
)
|
)
|
||||||
if self.running_batch:
|
|
||||||
available_size -= sum(
|
|
||||||
[
|
|
||||||
(r.sampling_params.max_new_tokens - len(r.output_ids))
|
|
||||||
* self.new_token_ratio
|
|
||||||
for r in self.running_batch.reqs
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle the current inflight request
|
if self.running_batch is not None:
|
||||||
take_inflight = 0
|
adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
|
||||||
if self.current_inflight_req:
|
|
||||||
take_inflight = 1
|
|
||||||
r = self.current_inflight_req
|
|
||||||
r.input_ids = r.origin_input_ids + r.output_ids
|
|
||||||
truncated = (
|
|
||||||
len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
|
|
||||||
)
|
|
||||||
r.extend_input_len = min(
|
|
||||||
len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
|
|
||||||
)
|
|
||||||
r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
|
|
||||||
can_run_list.append(r)
|
|
||||||
|
|
||||||
if not truncated:
|
has_inflight = self.current_inflight_req is not None
|
||||||
# Finish inflight
|
if self.current_inflight_req is not None:
|
||||||
self.current_inflight_req = None
|
self.current_inflight_req = adder.add_inflight_req(
|
||||||
new_batch_total_tokens += (
|
self.current_inflight_req
|
||||||
r.extend_input_len + r.sampling_params.max_new_tokens
|
)
|
||||||
)
|
|
||||||
new_batch_input_tokens += r.extend_input_len
|
|
||||||
else:
|
|
||||||
new_batch_total_tokens += r.extend_input_len
|
|
||||||
new_batch_input_tokens += r.extend_input_len
|
|
||||||
|
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
if req.return_logprob and req.normalized_prompt_logprob is None:
|
res = adder.add_one_req(req)
|
||||||
# Need at least two tokens to compute normalized logprob
|
|
||||||
if req.extend_input_len < 2:
|
|
||||||
delta = 2 - req.extend_input_len
|
|
||||||
req.extend_input_len += delta
|
|
||||||
req.prefix_indices = req.prefix_indices[:-delta]
|
|
||||||
if req.image_offset is not None:
|
|
||||||
req.image_offset += delta
|
|
||||||
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
|
|
||||||
# Need at least one token to compute logits
|
|
||||||
req.extend_input_len = 1
|
|
||||||
req.prefix_indices = req.prefix_indices[:-1]
|
|
||||||
if req.image_offset is not None:
|
|
||||||
req.image_offset += 1
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
req.extend_input_len
|
not res
|
||||||
+ req.sampling_params.max_new_tokens
|
or adder.no_remaining_tokens()
|
||||||
+ new_batch_total_tokens
|
or running_bs + len(adder.can_run_list) >= self.max_running_requests
|
||||||
< available_size
|
|
||||||
and (
|
|
||||||
req.extend_input_len + new_batch_input_tokens
|
|
||||||
<= self.max_prefill_tokens
|
|
||||||
or len(can_run_list) == 0
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
delta = self.tree_cache.inc_lock_ref(req.last_node)
|
|
||||||
available_size += delta
|
|
||||||
|
|
||||||
if not (
|
|
||||||
req.extend_input_len
|
|
||||||
+ req.sampling_params.max_new_tokens
|
|
||||||
+ new_batch_total_tokens
|
|
||||||
< available_size
|
|
||||||
):
|
|
||||||
# Undo locking
|
|
||||||
delta = self.tree_cache.dec_lock_ref(req.last_node)
|
|
||||||
available_size += delta
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Add this request to the running batch
|
|
||||||
if (
|
|
||||||
self.chunked_prefill_size is None
|
|
||||||
or (
|
|
||||||
new_batch_input_tokens + req.extend_input_len
|
|
||||||
<= self.chunked_prefill_size
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
req.return_logprob and req.normalized_prompt_logprob is None
|
|
||||||
)
|
|
||||||
):
|
|
||||||
can_run_list.append(req)
|
|
||||||
new_batch_total_tokens += (
|
|
||||||
req.extend_input_len + req.sampling_params.max_new_tokens
|
|
||||||
)
|
|
||||||
new_batch_input_tokens += req.extend_input_len
|
|
||||||
else:
|
|
||||||
trunc_len = self.chunked_prefill_size - new_batch_input_tokens
|
|
||||||
|
|
||||||
if trunc_len <= 0:
|
|
||||||
# Undo locking
|
|
||||||
delta = self.tree_cache.dec_lock_ref(req.last_node)
|
|
||||||
available_size += delta
|
|
||||||
break
|
|
||||||
|
|
||||||
req.extend_input_len = trunc_len
|
|
||||||
req.input_ids = req.input_ids[
|
|
||||||
: len(req.prefix_indices) + req.extend_input_len
|
|
||||||
]
|
|
||||||
can_run_list.append(req)
|
|
||||||
self.current_inflight_req = req
|
|
||||||
new_batch_input_tokens += req.extend_input_len
|
|
||||||
new_batch_total_tokens += req.extend_input_len
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if running_bs + len(can_run_list) >= self.max_running_requests:
|
can_run_list = adder.can_run_list
|
||||||
break
|
|
||||||
|
if adder.new_inflight_req is not None:
|
||||||
|
assert self.current_inflight_req is None
|
||||||
|
self.current_inflight_req = adder.new_inflight_req
|
||||||
|
|
||||||
if len(can_run_list) == 0:
|
if len(can_run_list) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
|
|
||||||
self.tree_cache_metrics["total"] += (
|
self.tree_cache_metrics["total"] += (
|
||||||
hit_tokens + new_batch_input_tokens
|
adder.log_input_tokens + adder.log_hit_tokens
|
||||||
) / 10**9
|
) / 10**9
|
||||||
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
|
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
|
||||||
tree_cache_hit_rate = (
|
tree_cache_hit_rate = (
|
||||||
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[gpu={self.gpu_id}] Prefill batch. "
|
f"[gpu={self.gpu_id}] Prefill batch. "
|
||||||
f"#new-seq: {len(can_run_list)}, "
|
f"#new-seq: {len(can_run_list)}, "
|
||||||
f"#new-token: {new_batch_input_tokens}, "
|
f"#new-token: {adder.log_input_tokens}, "
|
||||||
f"#cached-token: {hit_tokens}, "
|
f"#cached-token: {adder.log_hit_tokens}, "
|
||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
||||||
f"#running-req: {running_bs}, "
|
f"#running-req: {running_bs}, "
|
||||||
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
|
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return the new batch
|
# Return the new batch
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ class ModelRunner:
|
|||||||
server_args.max_total_tokens,
|
server_args.max_total_tokens,
|
||||||
)
|
)
|
||||||
self.init_cublas()
|
self.init_cublas()
|
||||||
self.init_flash_infer()
|
self.init_flashinfer()
|
||||||
|
|
||||||
# Capture cuda graphs
|
# Capture cuda graphs
|
||||||
self.init_cuda_graphs()
|
self.init_cuda_graphs()
|
||||||
@@ -287,7 +287,7 @@ class ModelRunner:
|
|||||||
c = a @ b
|
c = a @ b
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def init_flash_infer(self):
|
def init_flashinfer(self):
|
||||||
if self.server_args.disable_flashinfer:
|
if self.server_args.disable_flashinfer:
|
||||||
self.flashinfer_prefill_wrapper_ragged = None
|
self.flashinfer_prefill_wrapper_ragged = None
|
||||||
self.flashinfer_prefill_wrapper_paged = None
|
self.flashinfer_prefill_wrapper_paged = None
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
|
|||||||
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
|
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
|||||||
@@ -46,8 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
||||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -394,14 +391,6 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def sample(
|
|
||||||
self,
|
|
||||||
logits: Optional[torch.Tensor],
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> Optional[SamplerOutput]:
|
|
||||||
next_tokens = self.sampler(logits, sampling_metadata)
|
|
||||||
return next_tokens
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user