Clean up batch data structures: Introducing ModelWorkerBatch (#1544)

Lianmin Zheng
2024-09-30 06:41:49 -07:00
committed by GitHub
parent 36d5acfca5
commit 63ba2f8d7b
9 changed files with 274 additions and 155 deletions

View File: python/sglang/srt/model_executor/forward_batch_info.py

@@ -15,18 +15,33 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Meta data for a forward pass."""
"""
Store information about a forward batch.
The following is the flow of data structures for a batch:
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch
if TYPE_CHECKING:
from sglang.srt.layers.attention_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
class ForwardMode(IntEnum):
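The docstring above introduces `ModelWorkerBatch`, but its definition (it lives in `schedule_batch.py`) is not part of this excerpt. Below is a minimal, hypothetical sketch of the structure, reconstructed only from the fields that `ForwardBatch.init_new` reads off it later in this diff; the exact types and defaults are assumptions, not the committed code.

```python
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

import torch

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ImageInputs
    from sglang.srt.model_executor.forward_batch_info import ForwardMode
    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


@dataclass
class ModelWorkerBatch:
    # Hypothetical reconstruction: every field is inferred from how
    # ForwardBatch.init_new consumes the batch in this commit.
    forward_mode: "ForwardMode"            # decode vs. extend
    input_ids: List[int]                   # still on CPU; init_new materializes the GPU tensor
    req_pool_indices: torch.Tensor         # indices into req_to_token_pool
    seq_lens: torch.Tensor                 # per-request total sequence lengths
    out_cache_loc: torch.Tensor            # KV-cache slots for the newly computed tokens
    return_logprob: bool                   # whether logprobs are requested
    top_logprobs_nums: Optional[List[int]]
    extend_prefix_lens: List[int]          # cached prefix length per request (extend mode)
    extend_seq_lens: List[int]             # number of new tokens per request (extend mode)
    extend_logprob_start_lens: List[int]   # where logprob computation starts per request
    image_inputs: Optional[List["ImageInputs"]]
    lora_paths: Optional[List[str]]
    sampling_info: "SamplingBatchInfo"
```

The design intent is visible even in this sketch: the scheduler hands the tensor-parallel worker one flat, serializable container instead of a list of `Req` objects, which is why the per-request loops over `batch.reqs` disappear from `init_new` below.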
@@ -69,25 +84,28 @@ class ForwardBatch:
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

+    # For logprob
+    return_logprob: bool = False
+    top_logprobs_nums: Optional[List[int]] = None

    # Position information
    positions: torch.Tensor = None

    # For extend
-    extend_seq_lens: torch.Tensor = None
-    extend_prefix_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    # For logprob
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-    extend_seq_lens_cpu: List[int] = None
-    extend_logprob_start_lens_cpu: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_prefix_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None

    # For multimodal
-    image_inputs: List[ImageInputs] = None
+    image_inputs: Optional[List[ImageInputs]] = None

    # For LoRA
-    lora_paths: List[str] = None
+    lora_paths: Optional[List[str]] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None

    # Attention backend
    req_to_token_pool: ReqToTokenPool = None
@@ -95,42 +113,61 @@ class ForwardBatch:
    attn_backend: AttentionBackend = None

    @classmethod
-    def from_schedule_batch(
+    def init_new(
        cls,
-        batch: ScheduleBatch,
+        batch: ModelWorkerBatch,
        model_runner: ModelRunner,
    ):
+        device = "cuda"
        ret = cls(
            forward_mode=batch.forward_mode,
-            batch_size=batch.batch_size(),
-            input_ids=batch.input_ids,
+            batch_size=len(batch.seq_lens),
+            input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            out_cache_loc=batch.out_cache_loc,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
-            lora_paths=[req.lora_path for req in batch.reqs],
+            lora_paths=batch.lora_paths,
            sampling_info=batch.sampling_info,
        )

+        # Init position information
        if ret.forward_mode.is_decode():
            ret.positions = (ret.seq_lens - 1).to(torch.int64)
        else:
            ret.positions = torch.tensor(
                np.concatenate(
                    [
-                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                        for i, req in enumerate(batch.reqs)
+                        np.arange(prefix_len, prefix_len + extend_len)
+                        for prefix_len, extend_len in zip(
+                            batch.extend_prefix_lens, batch.extend_seq_lens
+                        )
                    ],
                    axis=0,
                ),
-                device="cuda",
+                device=device,
            ).to(torch.int64)
-            ret.image_inputs = [r.image_inputs for r in batch.reqs]
-            ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-            ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            ret.image_inputs = batch.image_inputs
+            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_prefix_lens = torch.tensor(
+                batch.extend_prefix_lens, device=device
+            )
            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
-            ret.extend_seq_lens_cpu = batch.extend_lens_cpu
-            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+            ret.extend_seq_lens_cpu = batch.extend_seq_lens
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

+        # Init attention information
+        ret.req_to_token_pool = model_runner.req_to_token_pool
+        ret.token_to_kv_pool = model_runner.token_to_kv_pool
+        ret.attn_backend = model_runner.attn_backend
+        model_runner.attn_backend.init_forward_metadata(ret)

+        # Init lora information
+        if model_runner.server_args.lora_paths is not None:
+            model_runner.lora_manager.prepare_lora_batch(ret)

        return ret
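A concrete trace of the rewritten extend-mode arithmetic in `init_new` may help. Assume a hypothetical two-request batch: request A has 3 cached prefix tokens and extends by 4 new tokens, request B has no cached prefix and extends by 5 (all values invented for illustration).

```python
import numpy as np
import torch

extend_prefix_lens = [3, 0]  # cached tokens per request (assumed example values)
extend_seq_lens = [4, 5]     # new tokens to compute per request

# Positions: for each request, the absolute token positions of its new tokens.
positions = np.concatenate(
    [
        np.arange(prefix_len, prefix_len + extend_len)
        for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
    ],
    axis=0,
)
print(positions)  # [3 4 5 6 0 1 2 3 4]

# extend_start_loc: offset of each request's first new token in the flattened
# batch, i.e. an exclusive prefix sum of extend_seq_lens.
extend_seq_lens_t = torch.tensor(extend_seq_lens)
extend_start_loc = torch.zeros_like(extend_seq_lens_t)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens_t[:-1], dim=0)
print(extend_start_loc)  # tensor([0, 4])
```

The zip over two parallel lists replaces the old per-`Req` indexing (`batch.prefix_lens_cpu[i]`, `len(req.fill_ids)`), which is exactly what makes the batch representable without the request objects.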

View File: python/sglang/srt/model_executor/model_runner.py

@@ -21,7 +21,7 @@ import importlib.resources
import logging
import pkgutil
from functools import lru_cache
-from typing import Optional, Tuple, Type
+from typing import Optional, Type
import torch
import torch.nn as nn
@@ -38,11 +38,12 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
@@ -52,6 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
+    enable_show_time_cost,
    get_available_gpu_memory,
    is_generation_model,
    is_multimodal_model,
@@ -102,6 +104,12 @@ class ModelRunner:
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

+        # Global vars
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+        if server_args.disable_disk_cache:
+            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
@@ -491,16 +499,6 @@ class ModelRunner:
        )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
-        # Attach attention information
-        forward_batch.req_to_token_pool = self.req_to_token_pool
-        forward_batch.token_to_kv_pool = self.token_to_kv_pool
-        forward_batch.attn_backend = self.attn_backend
-        forward_batch.attn_backend.init_forward_metadata(forward_batch)
-
-        # Attach lora information
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(forward_batch)
-
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
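Combined with `ForwardBatch.init_new` above, the deletion here turns `forward()` into a pure dispatcher: attention metadata and LoRA state are attached once, when the batch is constructed, instead of on every call. A rough sketch of the resulting per-step sequence as a caller such as `tp_worker.py::TpModelWorker` might drive it (the surrounding variable names are illustrative; only `init_new`, `forward`, and `sample` appear in this diff):

```python
# model_worker_batch: a ModelWorkerBatch produced by the scheduler (sketch only).
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)

# Dispatches on forward_mode; all metadata was already attached by init_new.
logits_output = model_runner.forward(forward_batch)

# Folds in vocab masks, penalties, and logit bias, then samples next tokens.
next_token_ids = model_runner.sample(logits_output, forward_batch)
```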
@@ -508,16 +506,27 @@ class ModelRunner:
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

-    def _apply_logits_bias(
-        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
-    ):
+    def sample(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+        sampling_info = forward_batch.sampling_info
+        sampling_info.update_regex_vocab_mask()
+        sampling_info.update_penalties()
+        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, sampling_info)
+        return next_token_ids
+
+    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # min-token, presence, frequency
        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
+            logits.add_(sampling_info.linear_penalties)

        # repetition
        if sampling_info.scaling_penalties is not None:
@@ -533,20 +542,6 @@ class ModelRunner:
        return logits

-    def sample(
-        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
-    ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-        batch.sampling_info.update_regex_vocab_mask(batch)
-        batch.sampling_info.update_penalties()
-        logits = self._apply_logits_bias(
-            logits_output.next_token_logits, batch.sampling_info
-        )
-
-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, batch.sampling_info)
-        return next_token_ids
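Two related changes land around sampling: `sample()` now takes a `ForwardBatch` instead of a `ScheduleBatch`, and the linear-penalty update switches from `logits += ...` to `logits.add_(...)`. For `torch.Tensor`, `+=` already lowers to an in-place `add_`, so the rewrite mainly makes the in-place mutation explicit and consistent with the `logit_bias` branch. A tiny self-contained demo of that semantics, with invented values:

```python
import torch

logits = torch.tensor([[1.0, 2.0, 3.0]])             # one row per request, one column per token
logit_bias = torch.tensor([[0.0, -100.0, 0.0]])      # e.g. user-supplied bias banning token 1
linear_penalties = torch.tensor([[-0.5, 0.0, 0.0]])  # presence/frequency-style penalties

logits.add_(logit_bias)        # in-place -> [[ 1.0, -98.0, 3.0]]
logits.add_(linear_penalties)  # in-place -> [[ 0.5, -98.0, 3.0]]
print(logits)
```

Because the update is in-place, the caller's `logits_output.next_token_logits` is mutated as a side effect, which is fine here since the logits are consumed immediately by the sampler.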
@lru_cache()
def import_model_classes():