Clean up batch data structures: Introducing ModelWorkerBatch (#1544)
This commit is contained in:
@@ -15,18 +15,33 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Meta data for a forward pass."""
|
||||
"""
|
||||
Store information about a forward batch.
|
||||
|
||||
The following is the flow of data structures for a batch:
|
||||
|
||||
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
|
||||
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
|
||||
It contains high-level scheduling data. Most of the data is on the CPU.
|
||||
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
|
||||
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
|
||||
|
||||
class ForwardMode(IntEnum):
|
||||
@@ -69,25 +84,28 @@ class ForwardBatch:
|
||||
# The indices of output tokens in the token_to_kv_pool
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
|
||||
# Position information
|
||||
positions: torch.Tensor = None
|
||||
|
||||
# For extend
|
||||
extend_seq_lens: torch.Tensor = None
|
||||
extend_prefix_lens: torch.Tensor = None
|
||||
extend_start_loc: torch.Tensor = None
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
extend_seq_lens_cpu: List[int] = None
|
||||
extend_logprob_start_lens_cpu: List[int] = None
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
extend_prefix_lens: Optional[torch.Tensor] = None
|
||||
extend_start_loc: Optional[torch.Tensor] = None
|
||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||
|
||||
# For multimodal
|
||||
image_inputs: List[ImageInputs] = None
|
||||
image_inputs: Optional[List[ImageInputs]] = None
|
||||
|
||||
# For LoRA
|
||||
lora_paths: List[str] = None
|
||||
lora_paths: Optional[List[str]] = None
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
|
||||
# Attention backend
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
@@ -95,42 +113,61 @@ class ForwardBatch:
|
||||
attn_backend: AttentionBackend = None
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(
|
||||
def init_new(
|
||||
cls,
|
||||
batch: ScheduleBatch,
|
||||
batch: ModelWorkerBatch,
|
||||
model_runner: ModelRunner,
|
||||
):
|
||||
device = "cuda"
|
||||
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
batch_size=batch.batch_size(),
|
||||
input_ids=batch.input_ids,
|
||||
batch_size=len(batch.seq_lens),
|
||||
input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
lora_paths=[req.lora_path for req in batch.reqs],
|
||||
lora_paths=batch.lora_paths,
|
||||
sampling_info=batch.sampling_info,
|
||||
)
|
||||
|
||||
# Init position information
|
||||
if ret.forward_mode.is_decode():
|
||||
ret.positions = (ret.seq_lens - 1).to(torch.int64)
|
||||
else:
|
||||
ret.positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
|
||||
for i, req in enumerate(batch.reqs)
|
||||
np.arange(prefix_len, prefix_len + extend_len)
|
||||
for prefix_len, extend_len in zip(
|
||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device="cuda",
|
||||
device=device,
|
||||
).to(torch.int64)
|
||||
|
||||
ret.image_inputs = [r.image_inputs for r in batch.reqs]
|
||||
ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
|
||||
ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||
ret.image_inputs = batch.image_inputs
|
||||
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, device=device
|
||||
)
|
||||
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
||||
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
||||
ret.extend_seq_lens_cpu = batch.extend_lens_cpu
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||
|
||||
# Init attention information
|
||||
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||
ret.attn_backend = model_runner.attn_backend
|
||||
model_runner.attn_backend.init_forward_metadata(ret)
|
||||
|
||||
# Init lora information
|
||||
if model_runner.server_args.lora_paths is not None:
|
||||
model_runner.lora_manager.prepare_lora_batch(ret)
|
||||
|
||||
return ret
|
||||
|
||||
@@ -21,7 +21,7 @@ import importlib.resources
|
||||
import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Tuple, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -38,11 +38,12 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.constrained import disable_cache
|
||||
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
@@ -52,6 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
is_generation_model,
|
||||
is_multimodal_model,
|
||||
@@ -102,6 +104,12 @@ class ModelRunner:
|
||||
server_args.chunked_prefill_size = None
|
||||
server_args.mem_fraction_static *= 0.95
|
||||
|
||||
# Global vars
|
||||
if server_args.show_time_cost:
|
||||
enable_show_time_cost()
|
||||
if server_args.disable_disk_cache:
|
||||
disable_cache()
|
||||
|
||||
global_server_args_dict.update(
|
||||
{
|
||||
"attention_backend": server_args.attention_backend,
|
||||
@@ -491,16 +499,6 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
||||
# Attach attention information
|
||||
forward_batch.req_to_token_pool = self.req_to_token_pool
|
||||
forward_batch.token_to_kv_pool = self.token_to_kv_pool
|
||||
forward_batch.attn_backend = self.attn_backend
|
||||
forward_batch.attn_backend.init_forward_metadata(forward_batch)
|
||||
|
||||
# Attach lora information
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
return self.forward_decode(forward_batch)
|
||||
elif forward_batch.forward_mode.is_extend():
|
||||
@@ -508,16 +506,27 @@ class ModelRunner:
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
|
||||
|
||||
def _apply_logits_bias(
|
||||
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
||||
):
|
||||
def sample(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
) -> torch.Tensor:
|
||||
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
sampling_info = forward_batch.sampling_info
|
||||
sampling_info.update_regex_vocab_mask()
|
||||
sampling_info.update_penalties()
|
||||
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
|
||||
|
||||
# Sample the next tokens.
|
||||
next_token_ids = self.sampler(logits, sampling_info)
|
||||
return next_token_ids
|
||||
|
||||
def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||
# Apply logit_bias
|
||||
if sampling_info.logit_bias is not None:
|
||||
logits.add_(sampling_info.logit_bias)
|
||||
|
||||
# min-token, presence, frequency
|
||||
if sampling_info.linear_penalties is not None:
|
||||
logits += sampling_info.linear_penalties
|
||||
logits.add_(sampling_info.linear_penalties)
|
||||
|
||||
# repetition
|
||||
if sampling_info.scaling_penalties is not None:
|
||||
@@ -533,20 +542,6 @@ class ModelRunner:
|
||||
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
|
||||
) -> torch.Tensor:
|
||||
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
batch.sampling_info.update_regex_vocab_mask(batch)
|
||||
batch.sampling_info.update_penalties()
|
||||
logits = self._apply_logits_bias(
|
||||
logits_output.next_token_logits, batch.sampling_info
|
||||
)
|
||||
|
||||
# Sample the next tokens.
|
||||
next_token_ids = self.sampler(logits, batch.sampling_info)
|
||||
return next_token_ids
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_model_classes():
|
||||
|
||||
Reference in New Issue
Block a user