Organize sampling batch info better (#1562)
@@ -96,7 +96,9 @@ class Scheduler:
 
         if self.tp_rank == 0:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+            )
 
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(
@@ -141,9 +143,6 @@ class Scheduler:
             nccl_port=port_args.nccl_ports[0],
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )
 
         # Get token and memory info from the tp worker
         (
@@ -154,6 +153,9 @@ class Scheduler:
             self.random_seed,
         ) = self.tp_worker.get_token_and_memory_info()
         set_random_seed(self.random_seed)
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )
 
         # Print debug info
         logger.info(
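The two Scheduler hunks above just move the `pad_input_ids_func` lookup after `get_token_and_memory_info()`; the lookup itself is the common `getattr`-with-default idiom that treats `pad_input_ids` as an optional per-model hook. A minimal sketch of the idiom, using hypothetical model classes that are not sglang code:

# Sketch of the optional-hook idiom behind `pad_input_ids_func`.
# `VisionModel` and `TextModel` are hypothetical stand-ins.
class VisionModel:
    def pad_input_ids(self, input_ids, image_offsets):
        # A multimodal model might splice placeholder tokens here.
        return input_ids + [0] * len(image_offsets)

class TextModel:
    pass  # plain text models define no padding hook

for model in (VisionModel(), TextModel()):
    hook = getattr(model, "pad_input_ids", None)  # bound method or None
    print(type(model).__name__, hook([1, 2], [7]) if hook else "no hook")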
@@ -87,7 +87,9 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
         self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+        self.send_to_scheduler.connect(
+            f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+        )
 
         # Read model args
         self.model_path = server_args.model_path
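Together with the Scheduler hunk above, this completes the rename: the tokenizer's PUSH socket and the rank-0 scheduler's PULL socket now both read `scheduler_input_port`. The two sockets form a standard ZeroMQ PUSH/PULL pipe; below is a self-contained sketch of that pattern, where the literal port is arbitrary and stands in for `PortArgs.scheduler_input_port`:

import zmq

context = zmq.Context()

# Scheduler side (rank 0): bind and pull requests.
recv_from_tokenizer = context.socket(zmq.PULL)
recv_from_tokenizer.bind("tcp://127.0.0.1:30001")

# TokenizerManager side: connect and push requests.
send_to_scheduler = context.socket(zmq.PUSH)
send_to_scheduler.connect("tcp://127.0.0.1:30001")

send_to_scheduler.send_pyobj({"rid": "demo", "input_ids": [1, 2, 3]})
print(recv_from_tokenizer.recv_pyobj())  # {'rid': 'demo', 'input_ids': [1, 2, 3]}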
@@ -30,6 +30,7 @@ class ReqToTokenPool:
 
     def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
+        self.max_context_len = max_context_len
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device=device
@@ -54,7 +55,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
@@ -92,19 +93,15 @@ class BaseTokenToKVPool(ABC):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = np.arange(1, self.size + 1)
 
-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()
 
-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
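Dropping `ABC`/`@abstractmethod` from `BaseTokenToKVPool` trades definition-time enforcement for call-time errors: an incomplete subclass of a plain base class instantiates fine and only fails when the missing method is invoked. A small sketch of the trade-off, with illustrative class names:

from abc import ABC, abstractmethod

class StrictPool(ABC):  # the old style
    @abstractmethod
    def get_key_buffer(self, layer_id: int): ...

class LoosePool:  # the new style
    def get_key_buffer(self, layer_id: int):
        raise NotImplementedError()

class IncompleteStrict(StrictPool):
    pass

class IncompleteLoose(LoosePool):
    pass

try:
    IncompleteStrict()  # TypeError: can't instantiate abstract class
except TypeError as e:
    print("ABC:", e)

pool = IncompleteLoose()  # constructs without complaint...
try:
    pool.get_key_buffer(0)  # ...and fails only here
except NotImplementedError:
    print("plain base: error deferred to the call site")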
@@ -411,8 +411,8 @@ class ModelRunner:
 
         device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            size=max_num_reqs + 1,
+            max_context_len=self.model_config.context_len + 4,
             device=device,
         )
         if (
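Switching the `ReqToTokenPool(...)` call to keyword arguments protects the two same-typed integer parameters from being silently transposed. A quick illustration with a hypothetical stand-in for the constructor:

def make_pool(size: int, max_context_len: int, device: str):
    # Hypothetical mirror of the ReqToTokenPool signature.
    return {"size": size, "max_context_len": max_context_len, "device": device}

# Positionally, swapping the two ints would still type-check and run:
print(make_pool(2052, 4097, "cuda"))
# With keywords, the intent is explicit and order-independent:
print(make_pool(size=4097, max_context_len=2052, device="cuda"))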
@@ -14,16 +14,17 @@ if TYPE_CHECKING:
 
 @dataclasses.dataclass
 class SamplingBatchInfo:
-    # Basic Info
-    vocab_size: int
-
     # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool
 
     # Bias Tensors
+    vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
 
@@ -31,9 +32,6 @@ class SamplingBatchInfo:
     regex_fsms: List[RegexGuide] = None
     regex_fsm_states: List[int] = None
 
-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
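After these two hunks, the batched sampling tensors and `need_min_p_sampling` carry no defaults, so constructing a `SamplingBatchInfo` without them is a `TypeError`; Python also requires such non-default dataclass fields to precede the defaulted ones, which is why `vocab_size` now sits with the bias tensors. A minimal sketch of the rule:

import dataclasses
import torch

@dataclasses.dataclass
class Info:
    temperatures: torch.Tensor       # non-default fields must come first
    vocab_size: int
    logit_bias: torch.Tensor = None  # defaulted fields follow

Info(temperatures=torch.ones(2, 1), vocab_size=32000)  # ok
# Info() raises TypeError (missing required arguments), and declaring a
# defaulted field *before* a required one fails at class-definition time.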
@@ -42,25 +40,30 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        ret = cls(vocab_size=vocab_size)
-
         with torch.device("cuda"):
-            ret.temperatures = torch.tensor(
+            temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
             ).view(-1, 1)
-            ret.top_ps = torch.tensor(
+            top_ps = torch.tensor(
                 [r.sampling_params.top_p for r in reqs], dtype=torch.float
             )
-            ret.top_ks = torch.tensor(
+            top_ks = torch.tensor(
                 [r.sampling_params.top_k for r in reqs], dtype=torch.int
             )
-            ret.min_ps = torch.tensor(
+            min_ps = torch.tensor(
                 [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )
 
+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+        )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
-        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
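The factory is rewritten to build all tensors first and then create the object with a single `cls(...)` call, rather than mutating a half-initialized instance. `torch.device` used as a context manager (available since PyTorch 2.0) makes the tensor factories inside the block allocate on that device. A minimal sketch:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(device):
    # Factory calls inherit the context's default device.
    temperatures = torch.tensor([0.7, 1.0, 0.0], dtype=torch.float).view(-1, 1)
    top_ks = torch.tensor([40, 1, 50], dtype=torch.int)
print(temperatures.device, top_ks.device)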
@@ -118,6 +118,7 @@ async def health_generate(request: Request) -> Response:
 
 @app.get("/get_model_info")
 async def get_model_info():
+    """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
         "is_generation": tokenizer_manager.is_generation,
@@ -127,11 +128,13 @@ async def get_model_info():
 
 @app.get("/get_server_args")
 async def get_server_args():
+    """Get the server arguments."""
     return dataclasses.asdict(tokenizer_manager.server_args)
 
 
 @app.get("/flush_cache")
 async def flush_cache():
+    """Flush the radix cache."""
     tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
@@ -142,7 +145,7 @@ async def flush_cache():
 
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
-
+    """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
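The docstrings added in the three hunks above are more than cosmetic: FastAPI lifts an endpoint function's docstring into the generated OpenAPI schema, so it appears as the endpoint description on the /docs page. A minimal sketch:

from fastapi import FastAPI

app = FastAPI()

@app.get("/get_server_args")
async def get_server_args():
    """Get the server arguments."""  # shown as the endpoint description in /docs
    return {"example": True}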
@@ -205,7 +208,7 @@ app.put("/encode")(encode_request)
 
 
 async def judge_request(obj: RewardReqInput, request: Request):
-    """Handle an embedding request."""
+    """Handle a reward model request."""
    try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -307,7 +310,7 @@ def launch_server(
     ports = server_args.additional_ports
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        scheduler_port=ports[1],
+        scheduler_input_port=ports[1],
         detokenizer_port=ports[2],
         nccl_ports=ports[3:],
     )
@@ -627,10 +627,11 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 class PortArgs:
     # The port for tokenizer to receive inputs from detokenizer (zmq)
     tokenizer_port: int
-    # The port for scheduler to receive inputs from tokenizer (zmq)
-    scheduler_port: int
+    # The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_port: int
     # The port for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_port: int
 
     # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
+
@@ -1 +1 @@
-kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')