Move sampler into CUDA graph (#1201)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
Liangsheng Yin
2024-08-26 07:02:50 -07:00
committed by GitHub
parent 97589a60a2
commit 75ce37f401
28 changed files with 336 additions and 110 deletions

View File

@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
@dataclasses.dataclass @dataclasses.dataclass
class LogitProcessorOutput: class LogitsProcessorOutput:
# The logits of the next tokens. shape: [#seq, vocab_size] # The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor next_token_logits: torch.Tensor
# The logprobs of the next tokens. shape: [#seq, vocab_size] # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
# Return only last_logits if logprob is not requested # Return only last_logits if logprob is not requested
if not logits_metadata.return_logprob: if not logits_metadata.return_logprob:
return LogitProcessorOutput( return LogitsProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=None, next_token_logprobs=None,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
else: else:
output_top_logprobs = None output_top_logprobs = None
return LogitProcessorOutput( return LogitsProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=last_logprobs, next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
# Remove the last token logprob for the prefill tokens. # Remove the last token logprob for the prefill tokens.
input_token_logprobs = input_token_logprobs[:-1] input_token_logprobs = input_token_logprobs[:-1]
return LogitProcessorOutput( return LogitsProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=last_logprobs, next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs,

View File

@@ -1,4 +1,6 @@
import dataclasses
import logging import logging
from typing import Union
import torch import torch
from flashinfer.sampling import ( from flashinfer.sampling import (
@@ -9,6 +11,8 @@ from flashinfer.sampling import (
) )
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
# TODO: move this dict to another place # TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclasses.dataclass
class SampleOutput:
success: torch.Tensor
probs: torch.Tensor
batch_next_token_ids: torch.Tensor
class Sampler(CustomOp): class Sampler(CustomOp):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# min-token, presence, frequency
if sampling_info.linear_penalties is not None:
logits += sampling_info.linear_penalties
# repetition
if sampling_info.scaling_penalties is not None:
logits = torch.where(
logits > 0,
logits / sampling_info.scaling_penalties,
logits * sampling_info.scaling_penalties,
)
return logits
def _get_probs(
self,
logits: torch.Tensor,
sampling_info: SamplingBatchInfo,
is_torch_compile: bool = False,
):
# Post process logits # Post process logits
logits = logits.contiguous() logits = logits.contiguous()
logits.div_(sampling_info.temperatures) logits.div_(sampling_info.temperatures)
if is_torch_compile:
# FIXME: Temporary workaround for unknown bugs in torch.compile
logits.add_(0)
if sampling_info.logit_bias is not None: if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias) logits.add_(sampling_info.logit_bias)
if sampling_info.vocab_mask is not None: if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf")) logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
logits = sampling_info.penalizer_orchestrator.apply(logits) logits = self._apply_penalties(logits, sampling_info)
probs = torch.softmax(logits, dim=-1) return torch.softmax(logits, dim=-1)
def forward_cuda(
self,
logits: Union[torch.Tensor, LogitsProcessorOutput],
sampling_info: SamplingBatchInfo,
):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits
probs = self._get_probs(logits, sampling_info)
if not global_server_args_dict["disable_flashinfer_sampling"]: if not global_server_args_dict["disable_flashinfer_sampling"]:
max_top_k_round, batch_size = 32, probs.shape[0] max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand( uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device (max_top_k_round, batch_size), device=probs.device
) )
if sampling_info.min_ps.any(): if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps) probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs( batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -55,18 +100,23 @@ class Sampler(CustomOp):
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
) )
if not torch.all(success): return SampleOutput(success, probs, batch_next_token_ids)
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0) def forward_native(
argmax_ids = torch.argmax(probs, dim=-1) self,
batch_next_token_ids = torch.where( logits: Union[torch.Tensor, LogitsProcessorOutput],
success, batch_next_token_ids, argmax_ids sampling_info: SamplingBatchInfo,
):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits
probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
) )
return batch_next_token_ids return SampleOutput(success, probs, batch_next_token_ids)
def forward_native():
raise NotImplementedError("Native forward is not implemented yet.")
def top_k_top_p_min_p_sampling_from_probs_torch( def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try: try:
sampled_index = torch.multinomial(probs_sort, num_samples=1) # FIXME: torch.multiomial does not support num_samples = 1
sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
:, :1
]
except RuntimeError as e: except RuntimeError as e:
logger.warning(f"Sampling error: {e}") logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros( batch_next_token_ids = torch.zeros(

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
""" """
Copyright 2023-2024 SGLang Team Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +19,7 @@ limitations under the License.
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
if TYPE_CHECKING:
from sglang.srt.layers.sampler import SampleOutput
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access # Put some global args for easy access
@@ -671,11 +677,17 @@ class ScheduleBatch:
self.top_logprobs_nums.extend(other.top_logprobs_nums) self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs) self.return_logprob = any(req.return_logprob for req in self.reqs)
def sample(self, logits: torch.Tensor): def check_sample_results(self, sample_output: SampleOutput):
from sglang.srt.layers.sampler import Sampler if not torch.all(sample_output.success):
probs = sample_output.probs
batch_next_token_ids = sample_output.batch_next_token_ids
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
sample_output.success, batch_next_token_ids, argmax_ids
)
sample_output.probs = probs
sample_output.batch_next_token_ids = batch_next_token_ids
sampler = Sampler() return sample_output.batch_next_token_ids
batch_next_token_ids = sampler(logits, self.sampling_info)
return batch_next_token_ids

View File

@@ -31,7 +31,7 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchEmbeddingOut, BatchEmbeddingOut,
@@ -486,21 +486,29 @@ class ModelTpServer:
if self.model_runner.is_generation: if self.model_runner.is_generation:
# Forward and sample the next tokens # Forward and sample the next tokens
if batch.extend_num_tokens != 0: if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND) sample_output, logits_output = self.model_runner.forward(
next_token_ids = batch.sample(output.next_token_logits) batch, ForwardMode.EXTEND
)
next_token_ids = batch.check_sample_results(sample_output)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids next_token_ids
) )
# Move logprobs to cpu # Move logprobs to cpu
if output.next_token_logprobs is not None: if logits_output.next_token_logprobs is not None:
output.next_token_logprobs = output.next_token_logprobs[ logits_output.next_token_logprobs = (
torch.arange(len(next_token_ids), device=next_token_ids.device), logits_output.next_token_logprobs[
torch.arange(
len(next_token_ids), device=next_token_ids.device
),
next_token_ids, next_token_ids,
].tolist() ].tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist() )
output.normalized_prompt_logprobs = ( logits_output.input_token_logprobs = (
output.normalized_prompt_logprobs.tolist() logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
) )
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
@@ -539,12 +547,14 @@ class ModelTpServer:
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
if req.return_logprob: if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output) self.add_logprob_return_values(
i, req, pt, next_token_ids, logits_output
)
pt += req.extend_input_len pt += req.extend_input_len
else: else:
assert batch.extend_num_tokens != 0 assert batch.extend_num_tokens != 0
output = self.model_runner.forward(batch, ForwardMode.EXTEND) logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
embeddings = output.embeddings.tolist() embeddings = logits_output.embeddings.tolist()
# Check finish conditions # Check finish conditions
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
@@ -572,7 +582,7 @@ class ModelTpServer:
req: Req, req: Req,
pt: int, pt: int,
next_token_ids: List[int], next_token_ids: List[int],
output: LogitProcessorOutput, output: LogitsProcessorOutput,
): ):
if req.normalized_prompt_logprob is None: if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -654,15 +664,17 @@ class ModelTpServer:
batch.prepare_for_decode() batch.prepare_for_decode()
# Forward and sample the next tokens # Forward and sample the next tokens
output = self.model_runner.forward(batch, ForwardMode.DECODE) sample_output, logits_output = self.model_runner.forward(
next_token_ids = batch.sample(output.next_token_logits) batch, ForwardMode.DECODE
)
next_token_ids = batch.check_sample_results(sample_output)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids next_token_ids
) )
# Move logprobs to cpu # Move logprobs to cpu
if output.next_token_logprobs is not None: if logits_output.next_token_logprobs is not None:
next_token_logprobs = output.next_token_logprobs[ next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device), torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids, next_token_ids,
].tolist() ].tolist()
@@ -688,7 +700,7 @@ class ModelTpServer:
(next_token_logprobs[i], next_token_id) (next_token_logprobs[i], next_token_id)
) )
if req.top_logprobs_num > 0: if req.top_logprobs_num > 0:
req.output_top_logprobs.append(output.output_top_logprobs[i]) req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
self.handle_finished_requests(batch) self.handle_finished_requests(batch)

View File

@@ -25,16 +25,18 @@ from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import ( from sglang.srt.layers.logits_processor import (
LogitProcessorOutput,
LogitsMetadata, LogitsMetadata,
LogitsProcessor, LogitsProcessor,
LogitsProcessorOutput,
) )
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
ForwardMode, ForwardMode,
InputMetadata, InputMetadata,
update_flashinfer_indices, update_flashinfer_indices,
) )
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import monkey_patch_vllm_all_gather from sglang.srt.utils import monkey_patch_vllm_all_gather
@@ -143,6 +145,10 @@ class CudaGraphRunner:
self.flashinfer_kv_indices.clone(), self.flashinfer_kv_indices.clone(),
] ]
# Sampling inputs
vocab_size = model_runner.model_config.vocab_size
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
if use_torch_compile: if use_torch_compile:
@@ -234,6 +240,7 @@ class CudaGraphRunner:
def run_once(): def run_once():
input_metadata = InputMetadata( input_metadata = InputMetadata(
forward_mode=ForwardMode.DECODE, forward_mode=ForwardMode.DECODE,
sampling_info=self.sampling_info[:bs],
batch_size=bs, batch_size=bs,
req_pool_indices=req_pool_indices, req_pool_indices=req_pool_indices,
seq_lens=seq_lens, seq_lens=seq_lens,
@@ -298,27 +305,35 @@ class CudaGraphRunner:
self.flashinfer_handlers[bs], self.flashinfer_handlers[bs],
) )
# Sampling inputs
self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
# Replay # Replay
torch.cuda.synchronize() torch.cuda.synchronize()
self.graphs[bs].replay() self.graphs[bs].replay()
torch.cuda.synchronize() torch.cuda.synchronize()
output = self.output_buffers[bs] sample_output, logits_output = self.output_buffers[bs]
# Unpad # Unpad
if bs != raw_bs: if bs != raw_bs:
output = LogitProcessorOutput( logits_output = LogitsProcessorOutput(
next_token_logits=output.next_token_logits[:raw_bs], next_token_logits=logits_output.next_token_logits[:raw_bs],
next_token_logprobs=None, next_token_logprobs=None,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
input_token_logprobs=None, input_token_logprobs=None,
input_top_logprobs=None, input_top_logprobs=None,
output_top_logprobs=None, output_top_logprobs=None,
) )
sample_output = SampleOutput(
sample_output.success[:raw_bs],
sample_output.probs[:raw_bs],
sample_output.batch_next_token_ids[:raw_bs],
)
# Extract logprobs # Extract logprobs
if batch.return_logprob: if batch.return_logprob:
output.next_token_logprobs = torch.nn.functional.log_softmax( logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
output.next_token_logits, dim=-1 logits_output.next_token_logits, dim=-1
) )
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
if return_top_logprob: if return_top_logprob:
@@ -326,8 +341,8 @@ class CudaGraphRunner:
forward_mode=ForwardMode.DECODE, forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
) )
output.output_top_logprobs = LogitsProcessor.get_top_logprobs( logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
output.next_token_logprobs, logits_metadata logits_output.next_token_logprobs, logits_metadata
)[1] )[1]
return output return sample_output, logits_output

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
""" """
Copyright 2023-2024 SGLang Team Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +18,7 @@ limitations under the License.
"""ModelRunner runs the forward passes of the models.""" """ModelRunner runs the forward passes of the models."""
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List
import numpy as np import numpy as np
import torch import torch
@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
class ForwardMode(IntEnum): class ForwardMode(IntEnum):
@@ -42,6 +45,7 @@ class InputMetadata:
"""Store all inforamtion of a forward pass.""" """Store all inforamtion of a forward pass."""
forward_mode: ForwardMode forward_mode: ForwardMode
sampling_info: SamplingBatchInfo
batch_size: int batch_size: int
req_pool_indices: torch.Tensor req_pool_indices: torch.Tensor
seq_lens: torch.Tensor seq_lens: torch.Tensor
@@ -179,6 +183,7 @@ class InputMetadata:
): ):
ret = cls( ret = cls(
forward_mode=forward_mode, forward_mode=forward_mode,
sampling_info=batch.sampling_info,
batch_size=batch.batch_size(), batch_size=batch.batch_size(),
req_pool_indices=batch.req_pool_indices, req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens, seq_lens=batch.seq_lens,
@@ -189,6 +194,8 @@ class InputMetadata:
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
) )
ret.sampling_info.prepare_penalties()
ret.compute_positions(batch) ret.compute_positions(batch)
ret.compute_extend_infos(batch) ret.compute_extend_infos(batch)

View File

@@ -21,7 +21,7 @@ import importlib.resources
import logging import logging
import pkgutil import pkgutil
from functools import lru_cache from functools import lru_cache
from typing import Optional, Type from typing import Optional, Tuple, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ( from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool, MHATokenToKVPool,
@@ -514,7 +516,11 @@ class ModelRunner:
@torch.inference_mode() @torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch): def forward_decode(self, batch: ScheduleBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): if (
self.cuda_graph_runner
and self.cuda_graph_runner.can_run(len(batch.reqs))
and not batch.sampling_info.has_bias()
):
return self.cuda_graph_runner.replay(batch) return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.from_schedule_batch( input_metadata = InputMetadata.from_schedule_batch(
@@ -563,7 +569,9 @@ class ModelRunner:
input_metadata.image_offsets, input_metadata.image_offsets,
) )
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode): def forward(
self, batch: ScheduleBatch, forward_mode: ForwardMode
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
return self.forward_extend_multi_modal(batch) return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE: elif forward_mode == ForwardMode.DECODE:

View File

@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
LoraConfig = None LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
def sample( return sample_output, logits_output
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))

View File

@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.model = CohereModel(config, quant_config) self.model = CohereModel(config, quant_config)
@torch.no_grad() @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
positions, positions,
input_metadata, input_metadata,
) )
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [ expert_params_mapping = [

View File

@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size, config.hidden_size, quant_config=quant_config
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata) hidden_states = self.model(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -45,6 +45,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -632,6 +633,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size, config.hidden_size, quant_config=quant_config
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward( def forward(
self, self,
@@ -640,9 +642,11 @@ class DeepseekV2ForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata) hidden_states = self.model(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
self.model = GemmaModel(config, quant_config=quant_config) self.model = GemmaModel(config, quant_config=quant_config)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return (sample_output, logits_output)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -396,6 +397,7 @@ class Gemma2ForCausalLM(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config) self.model = Gemma2Model(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -406,9 +408,11 @@ class Gemma2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_attention_sliding_window_size(self): def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config) return get_attention_sliding_window_size(self.config)

View File

@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))

View File

@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -297,6 +298,7 @@ class Grok1ModelForCausalLM(nn.Module):
self.model = Grok1Model(config, quant_config=quant_config) self.model = Grok1Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
# Monkey patch _prepare_weights to load pre-sharded weights # Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -313,9 +315,11 @@ class Grok1ModelForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
self.model = InternLM2Model(config, quant_config) self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.output.weight, input_metadata input_ids, hidden_states, self.output.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
self.model = LlamaModel(config, quant_config=quant_config) self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> LogitProcessorOutput: ) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_module_name(self, name): def get_module_name(self, name):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama2 import LlamaModel from sglang.srt.models.llama2 import LlamaModel
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
(input_metadata.batch_size, self.config.classification_out_size) (input_metadata.batch_size, self.config.classification_out_size)
).to(input_ids.device) ).to(input_ids.device)
return LogitProcessorOutput( return LogitsProcessorOutput(
next_token_logits=scores, next_token_logits=scores,
next_token_logprobs=scores, next_token_logprobs=scores,
normalized_prompt_logprobs=scores, normalized_prompt_logprobs=scores,

View File

@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
self.scale_width = self.config.hidden_size / self.config.dim_model_base self.scale_width = self.config.hidden_size / self.config.dim_model_base
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -314,9 +316,11 @@ class MiniCPMForCausalLM(nn.Module):
lm_head_weight = self.model.embed_tokens.weight lm_head_weight = self.model.embed_tokens.weight
else: else:
lm_head_weight = self.lm_head.weight lm_head_weight = self.lm_head.weight
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, lm_head_weight, input_metadata input_ids, hidden_states, lm_head_weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward( def forward(
self, self,
@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config) self.model = MixtralModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64 vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
): ):
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, input_metadata)
next_tokens = self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
return next_tokens sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
Qwen2Config = None Qwen2Config = None
@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module):
self.model = Qwen2Model(config, quant_config=quant_config) self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@torch.no_grad() @torch.no_grad()
@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
if not get_embedding: if not get_embedding:
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
else: else:
return self.pooler(hidden_states, input_metadata) return self.pooler(hidden_states, input_metadata)

View File

@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size, config.hidden_size, quant_config=quant_config
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
def compute_logits( return sample_output, logits_output
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
logits = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
self.model = StableLMEpochModel(config, quant_config=quant_config) self.model = StableLMEpochModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -21,10 +21,63 @@ class SamplingBatchInfo:
top_ps: torch.Tensor = None top_ps: torch.Tensor = None
top_ks: torch.Tensor = None top_ks: torch.Tensor = None
min_ps: torch.Tensor = None min_ps: torch.Tensor = None
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Bias Tensors
logit_bias: torch.Tensor = None logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None vocab_mask: torch.Tensor = None
# Penalizer
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
def has_bias(self):
return (
self.logit_bias is not None
or self.vocab_mask is not None
or self.linear_penalties is not None
or self.scaling_penalties is not None
)
@classmethod
def dummy_one(cls, max_bs: int, vocab_size: int):
ret = cls(vocab_size=vocab_size)
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
return ret
def __getitem__(self, key):
if isinstance(key, slice):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
return SamplingBatchInfo(
vocab_size=self.vocab_size,
temperatures=self.temperatures[key],
top_ps=self.top_ps[key],
top_ks=self.top_ks[key],
min_ps=self.min_ps[key],
need_min_p_sampling=self.need_min_p_sampling,
)
else:
raise NotImplementedError
def inplace_assign(self, bs: int, other: SamplingBatchInfo):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
self.vocab_size = other.vocab_size
self.need_min_p_sampling = other.need_min_p_sampling
self.temperatures[:bs] = other.temperatures
self.top_ps[:bs] = other.top_ps
self.top_ks[:bs] = other.top_ks
self.min_ps[:bs] = other.min_ps
@classmethod @classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
device = "cuda" device = "cuda"
@@ -45,6 +98,7 @@ class SamplingBatchInfo:
ret.min_ps = torch.tensor( ret.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
) )
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizers will do nothing if they evaluate themselves as not required by looking at # Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
@@ -72,6 +126,25 @@ class SamplingBatchInfo:
return ret return ret
def prepare_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values():
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
if penalizer.is_prepared():
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else:
if penalizer.is_prepared():
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device="cuda",
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self, batch: ScheduleBatch): def update_regex_vocab_mask(self, batch: ScheduleBatch):
bs, reqs = batch.batch_size(), batch.reqs bs, reqs = batch.batch_size(), batch.reqs
device = "cuda" device = "cuda"

View File

@@ -180,7 +180,7 @@ class SRTRunner:
tp_size=tp_size, tp_size=tp_size,
dtype=get_dtype_str(torch_dtype), dtype=get_dtype_str(torch_dtype),
port=port, port=port,
mem_fraction_static=0.7, mem_fraction_static=0.69,
trust_remote_code=False, trust_remote_code=False,
is_embedding=not self.is_generation, is_embedding=not self.is_generation,
) )