Clean up batch data structures: Introducing ModelWorkerBatch (#1544)
This commit is contained in:
@@ -62,11 +62,13 @@ import torch.distributed as dist
|
|||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server import _set_envs_and_config
|
from sglang.srt.server import _set_envs_and_config
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
allocate_init_ports,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
kill_child_process,
|
kill_child_process,
|
||||||
suppress_other_loggers,
|
suppress_other_loggers,
|
||||||
@@ -125,6 +127,11 @@ def load_model(server_args, tp_rank):
|
|||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
server_args.port, server_args.additional_ports = allocate_init_ports(
|
||||||
|
server_args.port,
|
||||||
|
server_args.additional_ports,
|
||||||
|
server_args.dp_size,
|
||||||
|
)
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig(
|
||||||
server_args.model_path,
|
server_args.model_path,
|
||||||
server_args.trust_remote_code,
|
server_args.trust_remote_code,
|
||||||
@@ -136,7 +143,7 @@ def load_model(server_args, tp_rank):
|
|||||||
gpu_id=tp_rank,
|
gpu_id=tp_rank,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
tp_size=server_args.tp_size,
|
tp_size=server_args.tp_size,
|
||||||
nccl_port=28888,
|
nccl_port=server_args.additional_ports[-1],
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
)
|
)
|
||||||
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
||||||
@@ -225,17 +232,19 @@ def extend(reqs, model_runner):
|
|||||||
tree_cache=None,
|
tree_cache=None,
|
||||||
)
|
)
|
||||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||||
forward_batch = batch.get_forward_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||||
logits_output = model_runner.forward(forward_batch)
|
logits_output = model_runner.forward(forward_batch)
|
||||||
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
|
||||||
return next_token_ids, logits_output.next_token_logits, batch
|
return next_token_ids, logits_output.next_token_logits, batch
|
||||||
|
|
||||||
|
|
||||||
def decode(input_token_ids, batch, model_runner):
|
def decode(input_token_ids, batch, model_runner):
|
||||||
batch.prepare_for_decode(input_token_ids)
|
batch.prepare_for_decode(input_token_ids)
|
||||||
forward_batch = batch.get_forward_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||||
logits_output = model_runner.forward(forward_batch)
|
logits_output = model_runner.forward(forward_batch)
|
||||||
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
|
||||||
return next_token_ids, logits_output.next_token_logits
|
return next_token_ids, logits_output.next_token_logits
|
||||||
|
|
||||||
|
|
||||||
@@ -357,7 +366,6 @@ def latency_test(
|
|||||||
tp_rank,
|
tp_rank,
|
||||||
):
|
):
|
||||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||||
_set_envs_and_config(server_args)
|
|
||||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
@@ -463,6 +471,7 @@ def plot_latency_test(
|
|||||||
|
|
||||||
|
|
||||||
def main(server_args, bench_args):
|
def main(server_args, bench_args):
|
||||||
|
_set_envs_and_config(server_args)
|
||||||
|
|
||||||
if server_args.model_path:
|
if server_args.model_path:
|
||||||
if bench_args.correctness_test:
|
if bench_args.correctness_test:
|
||||||
@@ -513,8 +522,6 @@ if __name__ == "__main__":
|
|||||||
format="%(message)s",
|
format="%(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
multiprocessing.set_start_method("spawn", force=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
main(server_args, bench_args)
|
main(server_args, bench_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -62,7 +62,11 @@ class LogitsMetadata:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
if forward_batch.return_logprob:
|
||||||
|
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||||
|
else:
|
||||||
|
return_top_logprob = False
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_extend():
|
if forward_batch.forward_mode.is_extend():
|
||||||
extend_logprob_pruned_lens_cpu = [
|
extend_logprob_pruned_lens_cpu = [
|
||||||
extend_len - start_len
|
extend_len - start_len
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Copyright 2023-2024 SGLang Team
|
Copyright 2023-2024 SGLang Team
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""Meta data for requests and batches"""
|
"""
|
||||||
|
Store information about requests and batches.
|
||||||
|
|
||||||
|
The following is the flow of data structures for a batch:
|
||||||
|
|
||||||
|
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||||
|
|
||||||
|
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
|
||||||
|
It contains high-level scheduling data. Most of the data is on the CPU.
|
||||||
|
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
|
||||||
|
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
||||||
|
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -29,7 +39,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
|
|||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -105,6 +115,8 @@ class FINISH_ABORT(BaseFinishReason):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ImageInputs:
|
class ImageInputs:
|
||||||
|
"""The image related inputs."""
|
||||||
|
|
||||||
pixel_values: torch.Tensor
|
pixel_values: torch.Tensor
|
||||||
image_hash: int
|
image_hash: int
|
||||||
image_sizes: Optional[list] = None
|
image_sizes: Optional[list] = None
|
||||||
@@ -137,7 +149,7 @@ class ImageInputs:
|
|||||||
|
|
||||||
|
|
||||||
class Req:
|
class Req:
|
||||||
"""Store all inforamtion of a request."""
|
"""The input and output status of a request."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -393,20 +405,20 @@ class ScheduleBatch:
|
|||||||
sampling_info: SamplingBatchInfo = None
|
sampling_info: SamplingBatchInfo = None
|
||||||
|
|
||||||
# Batched arguments to model runner
|
# Batched arguments to model runner
|
||||||
input_ids: torch.Tensor = None
|
input_ids: List[int] = None
|
||||||
req_pool_indices: torch.Tensor = None
|
req_pool_indices: List[int] = None
|
||||||
seq_lens: torch.Tensor = None
|
seq_lens: List[int] = None
|
||||||
position_ids_offsets: torch.Tensor = None
|
|
||||||
out_cache_loc: torch.Tensor = None
|
out_cache_loc: torch.Tensor = None
|
||||||
extend_num_tokens: int = None
|
|
||||||
|
|
||||||
# For mixed chunekd prefill
|
|
||||||
prefix_lens_cpu: List[int] = None
|
|
||||||
running_bs: int = None
|
|
||||||
|
|
||||||
# For processing logprobs
|
# For processing logprobs
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: List[int] = None
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
|
|
||||||
|
# For extend and mixed chunekd prefill
|
||||||
|
prefix_lens: List[int] = None
|
||||||
|
extend_lens: List[int] = None
|
||||||
|
extend_num_tokens: int = None
|
||||||
|
running_bs: int = None
|
||||||
|
|
||||||
# Stream
|
# Stream
|
||||||
has_stream: bool = False
|
has_stream: bool = False
|
||||||
@@ -466,12 +478,12 @@ class ScheduleBatch:
|
|||||||
seq_lens = []
|
seq_lens = []
|
||||||
|
|
||||||
# Allocate memory
|
# Allocate memory
|
||||||
req_pool_indices_cpu = self.alloc_req_slots(bs)
|
req_pool_indices = self.alloc_req_slots(bs)
|
||||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
req.req_pool_idx = req_pool_indices_cpu[i]
|
req.req_pool_idx = req_pool_indices[i]
|
||||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||||
seq_lens.append(seq_len)
|
seq_lens.append(seq_len)
|
||||||
assert seq_len - pre_len == req.extend_input_len
|
assert seq_len - pre_len == req.extend_input_len
|
||||||
@@ -497,22 +509,19 @@ class ScheduleBatch:
|
|||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
with torch.device("cuda"):
|
self.input_ids = sum(input_ids, [])
|
||||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda")
|
||||||
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
|
self.seq_lens = torch.tensor(seq_lens, device="cuda")
|
||||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
|
||||||
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
|
|
||||||
|
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
if self.return_logprob:
|
||||||
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
|
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||||
self.extend_lens_cpu = [r.extend_input_len for r in reqs]
|
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||||
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
|
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||||
|
|
||||||
def get_forward_batch(self):
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
||||||
return ForwardBatch.from_schedule_batch(self)
|
|
||||||
|
|
||||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||||
self.forward_mode = ForwardMode.MIXED
|
self.forward_mode = ForwardMode.MIXED
|
||||||
@@ -522,24 +531,24 @@ class ScheduleBatch:
|
|||||||
req.fill_ids = req.origin_input_ids + req.output_ids
|
req.fill_ids = req.origin_input_ids + req.output_ids
|
||||||
req.extend_input_len = 1
|
req.extend_input_len = 1
|
||||||
|
|
||||||
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
|
input_ids = self.input_ids + running_batch.input_ids
|
||||||
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
||||||
extend_num_tokens = self.extend_num_tokens + running_bs
|
extend_num_tokens = self.extend_num_tokens + running_bs
|
||||||
|
|
||||||
self.merge(running_batch)
|
self.merge_batch(running_batch)
|
||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens = extend_num_tokens
|
||||||
|
|
||||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||||
self.prefix_lens_cpu.extend(
|
self.prefix_lens.extend(
|
||||||
[
|
[
|
||||||
len(r.origin_input_ids) + len(r.output_ids) - 1
|
len(r.origin_input_ids) + len(r.output_ids) - 1
|
||||||
for r in running_batch.reqs
|
for r in running_batch.reqs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.extend_lens_cpu.extend([1] * running_bs)
|
self.extend_lens.extend([1] * running_bs)
|
||||||
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
|
self.extend_logprob_start_lens.extend([0] * running_bs)
|
||||||
|
|
||||||
def check_decode_mem(self):
|
def check_decode_mem(self):
|
||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
@@ -631,7 +640,7 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
return retracted_reqs, new_estimate_ratio
|
return retracted_reqs, new_estimate_ratio
|
||||||
|
|
||||||
def check_for_jump_forward(self, model_runner):
|
def check_for_jump_forward(self, pad_input_ids_func):
|
||||||
jump_forward_reqs = []
|
jump_forward_reqs = []
|
||||||
filter_indices = [i for i in range(len(self.reqs))]
|
filter_indices = [i for i in range(len(self.reqs))]
|
||||||
|
|
||||||
@@ -688,7 +697,7 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
# re-applying image padding
|
# re-applying image padding
|
||||||
if req.image_inputs is not None:
|
if req.image_inputs is not None:
|
||||||
req.origin_input_ids = model_runner.model.pad_input_ids(
|
req.origin_input_ids = pad_input_ids_func(
|
||||||
req.origin_input_ids_unpadded, req.image_inputs
|
req.origin_input_ids_unpadded, req.image_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -708,7 +717,7 @@ class ScheduleBatch:
|
|||||||
for r in self.reqs
|
for r in self.reqs
|
||||||
]
|
]
|
||||||
|
|
||||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
self.input_ids = input_ids
|
||||||
self.seq_lens.add_(1)
|
self.seq_lens.add_(1)
|
||||||
|
|
||||||
# Alloc mem
|
# Alloc mem
|
||||||
@@ -731,32 +740,97 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
||||||
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
|
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
|
||||||
self.seq_lens = self.seq_lens[new_indices]
|
|
||||||
self.input_ids = None
|
|
||||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
self.seq_lens = self.seq_lens[new_indices]
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
if self.return_logprob:
|
||||||
|
self.top_logprobs_nums = [
|
||||||
|
self.top_logprobs_nums[i] for i in unfinished_indices
|
||||||
|
]
|
||||||
self.has_stream = any(req.stream for req in self.reqs)
|
self.has_stream = any(req.stream for req in self.reqs)
|
||||||
|
|
||||||
self.sampling_info.filter(unfinished_indices, new_indices)
|
self.sampling_info.filter_batch(unfinished_indices, new_indices)
|
||||||
|
|
||||||
def merge(self, other: "ScheduleBatch"):
|
def merge_batch(self, other: "ScheduleBatch"):
|
||||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||||
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
|
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
|
||||||
# needs to be called with pre-merged Batch.reqs.
|
# needs to be called with pre-merged Batch.reqs.
|
||||||
self.sampling_info.merge(other.sampling_info)
|
self.sampling_info.merge_batch(other.sampling_info)
|
||||||
|
|
||||||
self.reqs.extend(other.reqs)
|
self.reqs.extend(other.reqs)
|
||||||
self.req_pool_indices = torch.concat(
|
self.req_pool_indices = torch.concat(
|
||||||
[self.req_pool_indices, other.req_pool_indices]
|
[self.req_pool_indices, other.req_pool_indices]
|
||||||
)
|
)
|
||||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||||
self.position_ids_offsets = torch.concat(
|
|
||||||
[self.position_ids_offsets, other.position_ids_offsets]
|
|
||||||
)
|
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
if self.return_logprob and other.return_logprob:
|
||||||
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||||
|
elif self.return_logprob:
|
||||||
|
self.top_logprobs_nums.extend([0] * len(other.reqs))
|
||||||
|
elif other.return_logprob:
|
||||||
|
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||||
self.has_stream = any(req.stream for req in self.reqs)
|
self.has_stream = any(req.stream for req in self.reqs)
|
||||||
|
|
||||||
|
def get_model_worker_batch(self):
|
||||||
|
if self.forward_mode.is_decode():
|
||||||
|
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
|
||||||
|
image_inputs
|
||||||
|
) = None
|
||||||
|
else:
|
||||||
|
extend_seq_lens = self.extend_lens
|
||||||
|
extend_prefix_lens = self.prefix_lens
|
||||||
|
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||||
|
image_inputs = [r.image_inputs for r in self.reqs]
|
||||||
|
|
||||||
|
lora_paths = [req.lora_path for req in self.reqs]
|
||||||
|
self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]
|
||||||
|
|
||||||
|
return ModelWorkerBatch(
|
||||||
|
forward_mode=self.forward_mode,
|
||||||
|
input_ids=self.input_ids,
|
||||||
|
req_pool_indices=self.req_pool_indices,
|
||||||
|
seq_lens=self.seq_lens,
|
||||||
|
out_cache_loc=self.out_cache_loc,
|
||||||
|
return_logprob=self.return_logprob,
|
||||||
|
top_logprobs_nums=self.top_logprobs_nums,
|
||||||
|
extend_seq_lens=extend_seq_lens,
|
||||||
|
extend_prefix_lens=extend_prefix_lens,
|
||||||
|
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||||
|
image_inputs=image_inputs,
|
||||||
|
lora_paths=lora_paths,
|
||||||
|
sampling_info=self.sampling_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelWorkerBatch:
|
||||||
|
# The forward mode
|
||||||
|
forward_mode: ForwardMode
|
||||||
|
# The input ids
|
||||||
|
input_ids: List[int]
|
||||||
|
# The indices of requests in the req_to_token_pool
|
||||||
|
req_pool_indices: torch.Tensor
|
||||||
|
# The sequence length
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
# The indices of output tokens in the token_to_kv_pool
|
||||||
|
out_cache_loc: torch.Tensor
|
||||||
|
|
||||||
|
# For logprob
|
||||||
|
return_logprob: bool
|
||||||
|
top_logprobs_nums: Optional[List[int]]
|
||||||
|
|
||||||
|
# For extend
|
||||||
|
extend_seq_lens: Optional[List[int]]
|
||||||
|
extend_prefix_lens: Optional[List[int]]
|
||||||
|
extend_logprob_start_lens: Optional[List[int]]
|
||||||
|
|
||||||
|
# For multimodal
|
||||||
|
image_inputs: Optional[List[ImageInputs]]
|
||||||
|
|
||||||
|
# For LoRA
|
||||||
|
lora_paths: Optional[List[str]]
|
||||||
|
|
||||||
|
# Sampling info
|
||||||
|
sampling_info: SamplingBatchInfo
|
||||||
|
|||||||
@@ -141,6 +141,9 @@ class Scheduler:
|
|||||||
nccl_port=port_args.nccl_ports[0],
|
nccl_port=port_args.nccl_ports[0],
|
||||||
)
|
)
|
||||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||||
|
self.pad_input_ids_func = getattr(
|
||||||
|
self.tp_worker.model_runner.model, "pad_input_ids", None
|
||||||
|
)
|
||||||
|
|
||||||
# Get token and memory info from the tp worker
|
# Get token and memory info from the tp worker
|
||||||
(
|
(
|
||||||
@@ -292,7 +295,7 @@ class Scheduler:
|
|||||||
if self.running_batch is None:
|
if self.running_batch is None:
|
||||||
self.running_batch = new_batch
|
self.running_batch = new_batch
|
||||||
else:
|
else:
|
||||||
self.running_batch.merge(new_batch)
|
self.running_batch.merge_batch(new_batch)
|
||||||
else:
|
else:
|
||||||
# Run a decode batch
|
# Run a decode batch
|
||||||
if self.running_batch is not None:
|
if self.running_batch is not None:
|
||||||
@@ -370,7 +373,7 @@ class Scheduler:
|
|||||||
req.image_inputs = ImageInputs.from_dict(
|
req.image_inputs = ImageInputs.from_dict(
|
||||||
recv_req.image_inputs, self.model_config.vocab_size
|
recv_req.image_inputs, self.model_config.vocab_size
|
||||||
)
|
)
|
||||||
req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids(
|
req.origin_input_ids = self.pad_input_ids_func(
|
||||||
req.origin_input_ids_unpadded, req.image_inputs
|
req.origin_input_ids_unpadded, req.image_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -575,9 +578,9 @@ class Scheduler:
|
|||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
forward_batch = batch.get_forward_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||||
forward_batch, batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
next_token_ids
|
next_token_ids
|
||||||
@@ -641,8 +644,8 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert batch.extend_num_tokens != 0
|
assert batch.extend_num_tokens != 0
|
||||||
forward_batch = batch.get_forward_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
|
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
@@ -759,9 +762,7 @@ class Scheduler:
|
|||||||
|
|
||||||
# Check for jump-forward
|
# Check for jump-forward
|
||||||
if not self.disable_regex_jump_forward:
|
if not self.disable_regex_jump_forward:
|
||||||
jump_forward_reqs = batch.check_for_jump_forward(
|
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
||||||
self.tp_worker.model_runner
|
|
||||||
)
|
|
||||||
self.waiting_queue.extend(jump_forward_reqs)
|
self.waiting_queue.extend(jump_forward_reqs)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
return
|
return
|
||||||
@@ -771,9 +772,9 @@ class Scheduler:
|
|||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
|
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
forward_batch = batch.get_forward_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||||
forward_batch, batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
next_token_ids
|
next_token_ids
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import logging
|
|||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
||||||
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -108,12 +109,14 @@ class TpModelWorker:
|
|||||||
self.random_seed,
|
self.random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
|
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
next_token_ids = self.model_runner.sample(logits_output, batch)
|
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
|
||||||
return logits_output, next_token_ids
|
return logits_output, next_token_ids
|
||||||
|
|
||||||
def forward_batch_embedding(self, forward_batch: ForwardBatch):
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
embeddings = logits_output.embeddings.tolist()
|
embeddings = logits_output.embeddings.tolist()
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|||||||
@@ -15,18 +15,33 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""Meta data for a forward pass."""
|
"""
|
||||||
|
Store information about a forward batch.
|
||||||
|
|
||||||
|
The following is the flow of data structures for a batch:
|
||||||
|
|
||||||
|
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||||
|
|
||||||
|
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
|
||||||
|
It contains high-level scheduling data. Most of the data is on the CPU.
|
||||||
|
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
|
||||||
|
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
||||||
|
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.attention_backend import AttentionBackend
|
from sglang.srt.layers.attention_backend import AttentionBackend
|
||||||
from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
@@ -69,25 +84,28 @@ class ForwardBatch:
|
|||||||
# The indices of output tokens in the token_to_kv_pool
|
# The indices of output tokens in the token_to_kv_pool
|
||||||
out_cache_loc: torch.Tensor
|
out_cache_loc: torch.Tensor
|
||||||
|
|
||||||
|
# For logprob
|
||||||
|
return_logprob: bool = False
|
||||||
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
|
|
||||||
# Position information
|
# Position information
|
||||||
positions: torch.Tensor = None
|
positions: torch.Tensor = None
|
||||||
|
|
||||||
# For extend
|
# For extend
|
||||||
extend_seq_lens: torch.Tensor = None
|
extend_seq_lens: Optional[torch.Tensor] = None
|
||||||
extend_prefix_lens: torch.Tensor = None
|
extend_prefix_lens: Optional[torch.Tensor] = None
|
||||||
extend_start_loc: torch.Tensor = None
|
extend_start_loc: Optional[torch.Tensor] = None
|
||||||
|
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||||
# For logprob
|
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||||
return_logprob: bool = False
|
|
||||||
top_logprobs_nums: List[int] = None
|
|
||||||
extend_seq_lens_cpu: List[int] = None
|
|
||||||
extend_logprob_start_lens_cpu: List[int] = None
|
|
||||||
|
|
||||||
# For multimodal
|
# For multimodal
|
||||||
image_inputs: List[ImageInputs] = None
|
image_inputs: Optional[List[ImageInputs]] = None
|
||||||
|
|
||||||
# For LoRA
|
# For LoRA
|
||||||
lora_paths: List[str] = None
|
lora_paths: Optional[List[str]] = None
|
||||||
|
|
||||||
|
# Sampling info
|
||||||
|
sampling_info: SamplingBatchInfo = None
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
req_to_token_pool: ReqToTokenPool = None
|
req_to_token_pool: ReqToTokenPool = None
|
||||||
@@ -95,42 +113,61 @@ class ForwardBatch:
|
|||||||
attn_backend: AttentionBackend = None
|
attn_backend: AttentionBackend = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
batch: ScheduleBatch,
|
batch: ModelWorkerBatch,
|
||||||
|
model_runner: ModelRunner,
|
||||||
):
|
):
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=batch.forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
batch_size=batch.batch_size(),
|
batch_size=len(batch.seq_lens),
|
||||||
input_ids=batch.input_ids,
|
input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
|
||||||
req_pool_indices=batch.req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=batch.seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
lora_paths=[req.lora_path for req in batch.reqs],
|
lora_paths=batch.lora_paths,
|
||||||
|
sampling_info=batch.sampling_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Init position information
|
||||||
if ret.forward_mode.is_decode():
|
if ret.forward_mode.is_decode():
|
||||||
ret.positions = (ret.seq_lens - 1).to(torch.int64)
|
ret.positions = (ret.seq_lens - 1).to(torch.int64)
|
||||||
else:
|
else:
|
||||||
ret.positions = torch.tensor(
|
ret.positions = torch.tensor(
|
||||||
np.concatenate(
|
np.concatenate(
|
||||||
[
|
[
|
||||||
np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
|
np.arange(prefix_len, prefix_len + extend_len)
|
||||||
for i, req in enumerate(batch.reqs)
|
for prefix_len, extend_len in zip(
|
||||||
|
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||||
|
)
|
||||||
],
|
],
|
||||||
axis=0,
|
axis=0,
|
||||||
),
|
),
|
||||||
device="cuda",
|
device=device,
|
||||||
).to(torch.int64)
|
).to(torch.int64)
|
||||||
|
|
||||||
ret.image_inputs = [r.image_inputs for r in batch.reqs]
|
ret.image_inputs = batch.image_inputs
|
||||||
ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
|
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
|
||||||
ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
ret.extend_prefix_lens = torch.tensor(
|
||||||
|
batch.extend_prefix_lens, device=device
|
||||||
|
)
|
||||||
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
||||||
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
||||||
ret.extend_seq_lens_cpu = batch.extend_lens_cpu
|
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
|
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||||
|
|
||||||
|
# Init attention information
|
||||||
|
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||||
|
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||||
|
ret.attn_backend = model_runner.attn_backend
|
||||||
|
model_runner.attn_backend.init_forward_metadata(ret)
|
||||||
|
|
||||||
|
# Init lora information
|
||||||
|
if model_runner.server_args.lora_paths is not None:
|
||||||
|
model_runner.lora_manager.prepare_lora_batch(ret)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import importlib.resources
|
|||||||
import logging
|
import logging
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional, Tuple, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -38,11 +38,12 @@ from vllm.model_executor.model_loader import get_model
|
|||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||||
|
from sglang.srt.constrained import disable_cache
|
||||||
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import Sampler
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.lora.lora_manager import LoRAManager
|
from sglang.srt.lora.lora_manager import LoRAManager
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
MHATokenToKVPool,
|
MHATokenToKVPool,
|
||||||
MLATokenToKVPool,
|
MLATokenToKVPool,
|
||||||
@@ -52,6 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
is_generation_model,
|
is_generation_model,
|
||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
@@ -102,6 +104,12 @@ class ModelRunner:
|
|||||||
server_args.chunked_prefill_size = None
|
server_args.chunked_prefill_size = None
|
||||||
server_args.mem_fraction_static *= 0.95
|
server_args.mem_fraction_static *= 0.95
|
||||||
|
|
||||||
|
# Global vars
|
||||||
|
if server_args.show_time_cost:
|
||||||
|
enable_show_time_cost()
|
||||||
|
if server_args.disable_disk_cache:
|
||||||
|
disable_cache()
|
||||||
|
|
||||||
global_server_args_dict.update(
|
global_server_args_dict.update(
|
||||||
{
|
{
|
||||||
"attention_backend": server_args.attention_backend,
|
"attention_backend": server_args.attention_backend,
|
||||||
@@ -491,16 +499,6 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
||||||
# Attach attention information
|
|
||||||
forward_batch.req_to_token_pool = self.req_to_token_pool
|
|
||||||
forward_batch.token_to_kv_pool = self.token_to_kv_pool
|
|
||||||
forward_batch.attn_backend = self.attn_backend
|
|
||||||
forward_batch.attn_backend.init_forward_metadata(forward_batch)
|
|
||||||
|
|
||||||
# Attach lora information
|
|
||||||
if self.server_args.lora_paths is not None:
|
|
||||||
self.lora_manager.prepare_lora_batch(forward_batch)
|
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
return self.forward_decode(forward_batch)
|
return self.forward_decode(forward_batch)
|
||||||
elif forward_batch.forward_mode.is_extend():
|
elif forward_batch.forward_mode.is_extend():
|
||||||
@@ -508,16 +506,27 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
|
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
|
||||||
|
|
||||||
def _apply_logits_bias(
|
def sample(
|
||||||
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||||
):
|
) -> torch.Tensor:
|
||||||
|
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||||
|
sampling_info = forward_batch.sampling_info
|
||||||
|
sampling_info.update_regex_vocab_mask()
|
||||||
|
sampling_info.update_penalties()
|
||||||
|
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
|
||||||
|
|
||||||
|
# Sample the next tokens.
|
||||||
|
next_token_ids = self.sampler(logits, sampling_info)
|
||||||
|
return next_token_ids
|
||||||
|
|
||||||
|
def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||||
# Apply logit_bias
|
# Apply logit_bias
|
||||||
if sampling_info.logit_bias is not None:
|
if sampling_info.logit_bias is not None:
|
||||||
logits.add_(sampling_info.logit_bias)
|
logits.add_(sampling_info.logit_bias)
|
||||||
|
|
||||||
# min-token, presence, frequency
|
# min-token, presence, frequency
|
||||||
if sampling_info.linear_penalties is not None:
|
if sampling_info.linear_penalties is not None:
|
||||||
logits += sampling_info.linear_penalties
|
logits.add_(sampling_info.linear_penalties)
|
||||||
|
|
||||||
# repetition
|
# repetition
|
||||||
if sampling_info.scaling_penalties is not None:
|
if sampling_info.scaling_penalties is not None:
|
||||||
@@ -533,20 +542,6 @@ class ModelRunner:
|
|||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def sample(
|
|
||||||
self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
|
||||||
batch.sampling_info.update_regex_vocab_mask(batch)
|
|
||||||
batch.sampling_info.update_penalties()
|
|
||||||
logits = self._apply_logits_bias(
|
|
||||||
logits_output.next_token_logits, batch.sampling_info
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sample the next tokens.
|
|
||||||
next_token_ids = self.sampler(logits, batch.sampling_info)
|
|
||||||
return next_token_ids
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def import_model_classes():
|
def import_model_classes():
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
|
from sglang.srt.constrained import RegexGuide
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
@@ -22,13 +23,17 @@ class SamplingBatchInfo:
|
|||||||
top_ks: torch.Tensor = None
|
top_ks: torch.Tensor = None
|
||||||
min_ps: torch.Tensor = None
|
min_ps: torch.Tensor = None
|
||||||
|
|
||||||
# Dispatch in CUDA graph
|
|
||||||
need_min_p_sampling: bool = False
|
|
||||||
|
|
||||||
# Bias Tensors
|
# Bias Tensors
|
||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
vocab_mask: torch.Tensor = None
|
vocab_mask: torch.Tensor = None
|
||||||
|
|
||||||
|
# FSM states
|
||||||
|
regex_fsms: List[RegexGuide] = None
|
||||||
|
regex_fsm_states: List[int] = None
|
||||||
|
|
||||||
|
# Dispatch in CUDA graph
|
||||||
|
need_min_p_sampling: bool = False
|
||||||
|
|
||||||
# Penalizer
|
# Penalizer
|
||||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||||
linear_penalties: torch.Tensor = None
|
linear_penalties: torch.Tensor = None
|
||||||
@@ -54,6 +59,8 @@ class SamplingBatchInfo:
|
|||||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ret.regex_fsms = [r.regex_fsm for r in reqs]
|
||||||
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||||
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
||||||
|
|
||||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||||
@@ -102,24 +109,22 @@ class SamplingBatchInfo:
|
|||||||
)
|
)
|
||||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||||
|
|
||||||
def update_regex_vocab_mask(self, batch: ScheduleBatch):
|
def update_regex_vocab_mask(self):
|
||||||
has_regex = any(req.regex_fsm is not None for req in batch.reqs)
|
|
||||||
|
|
||||||
# Reset the vocab mask
|
# Reset the vocab mask
|
||||||
self.vocab_mask = None
|
self.vocab_mask = None
|
||||||
|
|
||||||
if has_regex:
|
if any(regex_fsm is not None for regex_fsm in self.regex_fsms):
|
||||||
self.vocab_mask = torch.zeros(
|
self.vocab_mask = torch.zeros(
|
||||||
batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
|
len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda"
|
||||||
)
|
)
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, regex_fsm in enumerate(self.regex_fsms):
|
||||||
if req.regex_fsm is not None:
|
if regex_fsm is not None:
|
||||||
self.vocab_mask[i].fill_(1)
|
self.vocab_mask[i].fill_(1)
|
||||||
self.vocab_mask[i][
|
self.vocab_mask[i][
|
||||||
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
|
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
|
||||||
] = 0
|
] = 0
|
||||||
|
|
||||||
def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
@@ -129,9 +134,11 @@ class SamplingBatchInfo:
|
|||||||
"min_ps",
|
"min_ps",
|
||||||
"logit_bias",
|
"logit_bias",
|
||||||
]:
|
]:
|
||||||
self_val = getattr(self, item, None)
|
value = getattr(self, item, None)
|
||||||
if self_val is not None: # logit_bias can be None
|
if value is not None: # logit_bias can be None
|
||||||
setattr(self, item, self_val[new_indices])
|
setattr(self, item, value[new_indices])
|
||||||
|
|
||||||
|
self.regex_fsms = [self.regex_fsms[i] for i in new_indices]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def merge_bias_tensor(
|
def merge_bias_tensor(
|
||||||
@@ -153,7 +160,7 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def merge(self, other: "SamplingBatchInfo"):
|
def merge_batch(self, other: "SamplingBatchInfo"):
|
||||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
@@ -169,3 +176,5 @@ class SamplingBatchInfo:
|
|||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other)
|
self.logit_bias, other.logit_bias, len(self), len(other)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.regex_fsms.extend(other.regex_fsms)
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.srt.constrained import disable_cache
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
@@ -72,8 +71,6 @@ from sglang.srt.utils import (
|
|||||||
allocate_init_ports,
|
allocate_init_ports,
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
enable_show_time_cost,
|
|
||||||
is_hip,
|
|
||||||
kill_child_process,
|
kill_child_process,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
prepare_model_and_tokenizer,
|
prepare_model_and_tokenizer,
|
||||||
@@ -400,14 +397,6 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
# Set ulimit
|
# Set ulimit
|
||||||
set_ulimit()
|
set_ulimit()
|
||||||
|
|
||||||
# Enable show time cost for debugging
|
|
||||||
if server_args.show_time_cost:
|
|
||||||
enable_show_time_cost()
|
|
||||||
|
|
||||||
# Disable disk cache
|
|
||||||
if server_args.disable_disk_cache:
|
|
||||||
disable_cache()
|
|
||||||
|
|
||||||
# Fix triton bugs
|
# Fix triton bugs
|
||||||
if server_args.tp_size * server_args.dp_size > 1:
|
if server_args.tp_size * server_args.dp_size > 1:
|
||||||
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
||||||
|
|||||||
Reference in New Issue
Block a user