hotfix: revert sampler CUDA Graph (#1242)

Yineng Zhang authored 2024-08-28 21:16:47 +10:00, committed by GitHub
parent 184ae1c683
commit f25f4dfde5
33 changed files with 119 additions and 348 deletions
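
At a high level, this revert moves sampling out of the model forward pass: ModelRunner.forward goes back to returning only the logits-processor output, and ScheduleBatch.sample draws the next tokens afterwards, instead of the forward pass returning a (SampleOutput, LogitsProcessorOutput) pair with token ids already sampled inside the CUDA-graph-captured region. A runnable toy sketch of the restored call order (the stand-in model and shapes are hypothetical; only the call pattern mirrors the diffs below):

import torch

class ToyOutput:
    def __init__(self, logits):
        self.next_token_logits = logits  # [#seq, vocab_size], as in the diffs

def toy_forward(bs=4, vocab=32):
    # Stand-in for model_runner.forward(batch, ForwardMode.DECODE).
    return ToyOutput(torch.randn(bs, vocab))

def toy_sample(logits):
    # Stand-in for batch.sample(output.next_token_logits).
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)

output = toy_forward()
next_token_ids = toy_sample(output.next_token_logits)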

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.2.14"
version = "0.2.14.post1"
description = "SGLang is yet another fast serving framework for large language models and vision language models."
readme = "README.md"
requires-python = ">=3.8"

View File

@@ -200,14 +200,16 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
return sample_output.batch_next_token_ids, logits_output.next_token_logits, batch
output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids.cpu().numpy())
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
return sample_output.batch_next_token_ids, logits_output.next_token_logits
output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits
@torch.inference_mode()

View File

@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
@dataclasses.dataclass
class LogitsProcessorOutput:
class LogitProcessorOutput:
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor
# The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
# Return only last_logits if logprob is not requested
if not logits_metadata.return_logprob:
return LogitsProcessorOutput(
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=None,
normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
else:
output_top_logprobs = None
return LogitsProcessorOutput(
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
# Remove the last token logprob for the prefill tokens.
input_token_logprobs = input_token_logprobs[:-1]
return LogitsProcessorOutput(
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs,

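For reference, the dataclass whose name flips back from LogitsProcessorOutput to LogitProcessorOutput carries the fields visible in the constructor calls above; a reassembled sketch (the Optional annotations are an assumption):

import dataclasses
from typing import List, Optional

import torch

@dataclasses.dataclass
class LogitProcessorOutput:
    # The logits of the next tokens. shape: [#seq, vocab_size]
    next_token_logits: torch.Tensor
    # The logprobs of the next tokens. shape: [#seq, vocab_size]
    next_token_logprobs: Optional[torch.Tensor]
    # Prompt-side logprob fields, filled only when logprobs are requested.
    normalized_prompt_logprobs: Optional[torch.Tensor]
    input_token_logprobs: Optional[torch.Tensor]
    input_top_logprobs: Optional[List]
    output_top_logprobs: Optional[List]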
View File

@@ -1,6 +1,4 @@
import dataclasses
import logging
from typing import Union
import torch
from flashinfer.sampling import (
@@ -11,8 +9,6 @@ from flashinfer.sampling import (
)
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
# TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -20,71 +16,30 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class SampleOutput:
success: torch.Tensor
probs: torch.Tensor
batch_next_token_ids: torch.Tensor
class Sampler(CustomOp):
def __init__(self):
super().__init__()
def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# min-token, presence, frequency
if sampling_info.linear_penalties is not None:
logits += sampling_info.linear_penalties
# repetition
if sampling_info.scaling_penalties is not None:
logits = torch.where(
logits > 0,
logits / sampling_info.scaling_penalties,
logits * sampling_info.scaling_penalties,
)
return logits
def _get_probs(
self,
logits: torch.Tensor,
sampling_info: SamplingBatchInfo,
is_torch_compile: bool = False,
):
def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# Post process logits
logits = logits.contiguous()
logits.div_(sampling_info.temperatures)
if is_torch_compile:
# FIXME: Temporary workaround for unknown bugs in torch.compile
logits.add_(0)
if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias)
if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
logits = self._apply_penalties(logits, sampling_info)
logits = sampling_info.penalizer_orchestrator.apply(logits)
return torch.softmax(logits, dim=-1)
def forward_cuda(
self,
logits: Union[torch.Tensor, LogitsProcessorOutput],
sampling_info: SamplingBatchInfo,
):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits
probs = self._get_probs(logits, sampling_info)
probs = torch.softmax(logits, dim=-1)
if not global_server_args_dict["disable_flashinfer_sampling"]:
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.need_min_p_sampling:
if sampling_info.min_ps.any():
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -100,23 +55,18 @@ class Sampler(CustomOp):
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
)
return SampleOutput(success, probs, batch_next_token_ids)
if not torch.all(success):
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
success, batch_next_token_ids, argmax_ids
)
def forward_native(
self,
logits: Union[torch.Tensor, LogitsProcessorOutput],
sampling_info: SamplingBatchInfo,
):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits
return batch_next_token_ids
probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
)
return SampleOutput(success, probs, batch_next_token_ids)
def forward_native(self):
    raise NotImplementedError("Native forward is not implemented yet.")
def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -137,10 +87,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try:
# FIXME: torch.multinomial does not support num_samples = 1
sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
:, :1
]
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError as e:
logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros(

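The torch fallback sampler truncated in the hunk above can be reconstructed roughly as follows. This is a sketch: the min-p thresholding, the max renormalization, and the multinomial call are visible in the diff, while the top-p/top-k filtering steps are assumptions based on the function name and arguments:

import torch

def top_k_top_p_min_p_sampling_from_probs_torch(probs, top_ks, top_ps, min_ps):
    # probs: [batch, vocab]; top_ks (int), top_ps, min_ps: one entry per row.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # top-p: zero out tokens outside the nucleus (assumed step).
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # top-k: zero out tokens ranked k or worse (assumed step).
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    # min-p: drop tokens below min_p * top-probability, as in the hunk above.
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    # multinomial accepts unnormalized nonnegative weights.
    sampled_index = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=-1, index=sampled_index).view(-1)

probs = torch.softmax(torch.randn(2, 8), dim=-1)
ids = top_k_top_p_min_p_sampling_from_probs_torch(
    probs,
    top_ks=torch.tensor([4, 8]),
    top_ps=torch.tensor([0.9, 1.0]),
    min_ps=torch.tensor([0.0, 0.05]),
)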
View File

@@ -1,5 +1,3 @@
from __future__ import annotations
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +17,7 @@ limitations under the License.
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Union
import torch
@@ -31,10 +29,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
if TYPE_CHECKING:
from sglang.srt.layers.sampler import SampleOutput
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
@@ -684,17 +678,11 @@ class ScheduleBatch:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
def check_sample_results(self, sample_output: SampleOutput):
if not torch.all(sample_output.success):
probs = sample_output.probs
batch_next_token_ids = sample_output.batch_next_token_ids
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
sample_output.success, batch_next_token_ids, argmax_ids
)
sample_output.probs = probs
sample_output.batch_next_token_ids = batch_next_token_ids
def sample(self, logits: torch.Tensor):
from sglang.srt.layers.sampler import Sampler
return sample_output.batch_next_token_ids
sampler = Sampler()
batch_next_token_ids = sampler(logits, self.sampling_info)
return batch_next_token_ids

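The deleted check_sample_results fallback repaired rows where flashinfer sampling failed by scrubbing NaNs and substituting a greedy argmax. A toy reproduction of just that repair step (tensor values are made up):

import torch

success = torch.tensor([True, False])
probs = torch.tensor([[0.1, 0.9], [float("nan"), float("nan")]])
sampled = torch.tensor([1, 1])

probs = probs.masked_fill(torch.isnan(probs), 0.0)   # NaN rows become all-zero
fallback = torch.argmax(probs, dim=-1)               # greedy top_k=1 choice
next_ids = torch.where(success, sampled, fallback)   # keep row 0, repair row 1
print(next_ids)  # tensor([1, 0])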
View File

@@ -31,7 +31,7 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
@@ -505,29 +505,21 @@ class ModelTpServer:
if self.model_runner.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
sample_output, logits_output = self.model_runner.forward(
batch, ForwardMode.EXTEND
)
next_token_ids = batch.check_sample_results(sample_output)
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs[
torch.arange(
len(next_token_ids), device=next_token_ids.device
),
next_token_ids,
].tolist()
)
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
if output.next_token_logprobs is not None:
output.next_token_logprobs = output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
@@ -566,14 +558,12 @@ class ModelTpServer:
self.req_to_token_pool.free(req.req_pool_idx)
if req.return_logprob:
self.add_logprob_return_values(
i, req, pt, next_token_ids, logits_output
)
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
pt += req.extend_input_len
else:
assert batch.extend_num_tokens != 0
logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
embeddings = logits_output.embeddings.tolist()
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
embeddings = output.embeddings.tolist()
# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -601,7 +591,7 @@ class ModelTpServer:
req: Req,
pt: int,
next_token_ids: List[int],
output: LogitsProcessorOutput,
output: LogitProcessorOutput,
):
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -683,17 +673,15 @@ class ModelTpServer:
batch.prepare_for_decode()
# Forward and sample the next tokens
sample_output, logits_output = self.model_runner.forward(
batch, ForwardMode.DECODE
)
next_token_ids = batch.check_sample_results(sample_output)
output = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(output.next_token_logits)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
next_token_logprobs = logits_output.next_token_logprobs[
if output.next_token_logprobs is not None:
next_token_logprobs = output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
@@ -719,7 +707,7 @@ class ModelTpServer:
(next_token_logprobs[i], next_token_id)
)
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
req.output_top_logprobs.append(output.output_top_logprobs[i])
self.handle_finished_requests(batch)

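Both the extend and decode paths above gather each sequence's logprob at its sampled token with the same advanced-indexing idiom before moving the result to the CPU. A standalone example with toy shapes:

import torch

logits = torch.randn(4, 10)                  # [#seq, vocab_size]
next_token_ids = torch.tensor([3, 7, 0, 2])  # one sampled id per sequence
logprobs = torch.log_softmax(logits, dim=-1)

# Row i, column next_token_ids[i]: logprob of the token each sequence chose.
chosen = logprobs[torch.arange(len(next_token_ids)), next_token_ids]
print(chosen.tolist())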
View File

@@ -26,18 +26,16 @@ from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import (
LogitProcessorOutput,
LogitsMetadata,
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
InputMetadata,
update_flashinfer_indices,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import monkey_patch_vllm_all_gather
@@ -146,10 +144,6 @@ class CudaGraphRunner:
self.flashinfer_kv_indices.clone(),
]
# Sampling inputs
vocab_size = model_runner.model_config.vocab_size
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
if use_torch_compile:
@@ -241,7 +235,6 @@ class CudaGraphRunner:
def run_once():
input_metadata = InputMetadata(
forward_mode=ForwardMode.DECODE,
sampling_info=self.sampling_info[:bs],
batch_size=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
@@ -306,35 +299,27 @@ class CudaGraphRunner:
self.flashinfer_handlers[bs],
)
# Sampling inputs
self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
# Replay
torch.cuda.synchronize()
self.graphs[bs].replay()
torch.cuda.synchronize()
sample_output, logits_output = self.output_buffers[bs]
output = self.output_buffers[bs]
# Unpad
if bs != raw_bs:
logits_output = LogitsProcessorOutput(
next_token_logits=logits_output.next_token_logits[:raw_bs],
output = LogitProcessorOutput(
next_token_logits=output.next_token_logits[:raw_bs],
next_token_logprobs=None,
normalized_prompt_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
sample_output = SampleOutput(
sample_output.success[:raw_bs],
sample_output.probs[:raw_bs],
sample_output.batch_next_token_ids[:raw_bs],
)
# Extract logprobs
if batch.return_logprob:
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
output.next_token_logprobs = torch.nn.functional.log_softmax(
output.next_token_logits, dim=-1
)
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
if return_top_logprob:
@@ -342,8 +327,8 @@ class CudaGraphRunner:
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
logits_output.next_token_logprobs, logits_metadata
output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
output.next_token_logprobs, logits_metadata
)[1]
return sample_output, logits_output
return output

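For context, the runner above relies on torch's CUDA Graph capture: inputs and outputs live in fixed buffers, fresh data is copied in before each step, and replay() reruns the recorded kernels. A minimal version of that pattern (assumes a CUDA device; the warmup-on-a-side-stream step follows the torch documentation):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream so capture sees steady-state allocations.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(static_in)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = model(static_in)  # output buffer is fixed at capture time

static_in.copy_(torch.randn(8, 16, device="cuda"))  # write new inputs in place
g.replay()                                           # static_out now holds fresh results

Per-step sampling state (temperatures, penalties, logit biases) does not fit this fixed-buffer model cleanly, which is presumably why the dummy SamplingBatchInfo buffers and the inplace_assign call are being deleted above.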
View File

@@ -1,5 +1,3 @@
from __future__ import annotations
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +16,7 @@ limitations under the License.
"""ModelRunner runs the forward passes of the models."""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch
@@ -28,7 +26,6 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
class ForwardMode(IntEnum):
@@ -45,7 +42,6 @@ class InputMetadata:
"""Store all inforamtion of a forward pass."""
forward_mode: ForwardMode
sampling_info: SamplingBatchInfo
batch_size: int
req_pool_indices: torch.Tensor
seq_lens: torch.Tensor
@@ -183,7 +179,6 @@ class InputMetadata:
):
ret = cls(
forward_mode=forward_mode,
sampling_info=batch.sampling_info,
batch_size=batch.batch_size(),
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
@@ -194,8 +189,6 @@ class InputMetadata:
top_logprobs_nums=batch.top_logprobs_nums,
)
ret.sampling_info.prepare_penalties()
ret.compute_positions(batch)
ret.compute_extend_infos(batch)

View File

@@ -21,7 +21,7 @@ import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Tuple, Type
from typing import Optional, Type
import torch
import torch.nn as nn
@@ -44,8 +44,6 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
@@ -517,11 +515,7 @@ class ModelRunner:
@torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch):
if (
self.cuda_graph_runner
and self.cuda_graph_runner.can_run(len(batch.reqs))
and not batch.sampling_info.has_bias()
):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.from_schedule_batch(
@@ -570,9 +564,7 @@ class ModelRunner:
input_metadata.image_offsets,
)
def forward(
self, batch: ScheduleBatch, forward_mode: ForwardMode
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE:

View File

@@ -31,18 +31,20 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
LoraConfig = None
@@ -381,11 +383,17 @@ class ChatGLMForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))

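The same mechanical edit repeats across the model files that follow: the sglang Sampler member is removed and forward() returns the logits-processor output directly. Schematically (names as they appear in the diffs, not tied to any one model):

def forward(self, input_ids, positions, input_metadata):
    hidden_states = self.model(input_ids, positions, input_metadata)
    return self.logits_processor(
        input_ids, hidden_states, self.lm_head.weight, input_metadata
    )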
View File

@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.model = CohereModel(config, quant_config)
@torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
positions,
input_metadata,
)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [

View File

@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -45,7 +45,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -633,7 +632,6 @@ class DeepseekV2ForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward(
self,
@@ -642,11 +640,9 @@ class DeepseekV2ForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -37,7 +37,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
self.quant_config = quant_config
self.model = GemmaModel(config, quant_config=quant_config)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return (sample_output, logits_output)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -397,7 +396,6 @@ class Gemma2ForCausalLM(nn.Module):
self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -408,11 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)

View File

@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))

View File

@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -298,7 +297,6 @@ class Grok1ModelForCausalLM(nn.Module):
self.model = Grok1Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ModelForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.output.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -39,9 +39,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -303,7 +302,6 @@ class LlamaForCausalLM(nn.Module):
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -312,13 +310,11 @@ class LlamaForCausalLM(nn.Module):
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> LogitsProcessorOutput:
) -> LogitProcessorOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_module_name(self, name):
stacked_params_mapping = [

View File

@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama2 import LlamaModel
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
(input_metadata.batch_size, self.config.classification_out_size)
).to(input_ids.device)
return LogitsProcessorOutput(
return LogitProcessorOutput(
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,

View File

@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
self.scale_width = self.config.hidden_size / self.config.dim_model_base
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward(
self,
@@ -310,11 +308,9 @@ class MixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -262,11 +260,10 @@ class QWenLMHeadModel(nn.Module):
input_metadata: InputMetadata,
):
hidden_states = self.transformer(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
next_tokens = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -38,9 +38,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
Qwen2Config = None
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@torch.no_grad()
@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
if not get_embedding:
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
else:
return self.pooler(hidden_states, input_metadata)

View File

@@ -35,8 +35,10 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -47,7 +49,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -365,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -376,11 +376,20 @@ class Qwen2MoeForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def compute_logits(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
logits = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
self.model = StableLMEpochModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -21,63 +21,10 @@ class SamplingBatchInfo:
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Bias Tensors
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None
# Penalizer
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
def has_bias(self):
return (
self.logit_bias is not None
or self.vocab_mask is not None
or self.linear_penalties is not None
or self.scaling_penalties is not None
)
@classmethod
def dummy_one(cls, max_bs: int, vocab_size: int):
ret = cls(vocab_size=vocab_size)
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
return ret
def __getitem__(self, key):
if isinstance(key, slice):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
return SamplingBatchInfo(
vocab_size=self.vocab_size,
temperatures=self.temperatures[key],
top_ps=self.top_ps[key],
top_ks=self.top_ks[key],
min_ps=self.min_ps[key],
need_min_p_sampling=self.need_min_p_sampling,
)
else:
raise NotImplementedError
def inplace_assign(self, bs: int, other: SamplingBatchInfo):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
self.vocab_size = other.vocab_size
self.need_min_p_sampling = other.need_min_p_sampling
self.temperatures[:bs] = other.temperatures
self.top_ps[:bs] = other.top_ps
self.top_ks[:bs] = other.top_ks
self.min_ps[:bs] = other.min_ps
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
device = "cuda"
@@ -98,7 +45,6 @@ class SamplingBatchInfo:
ret.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizer will do nothing if it evaluates itself as not required by looking at
# the sampling_params of the requests (see {_is_required()} of each penalizer). So this
@@ -126,25 +72,6 @@ class SamplingBatchInfo:
return ret
def prepare_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values():
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
if penalizer.is_prepared():
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else:
if penalizer.is_prepared():
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device="cuda",
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self, batch: ScheduleBatch):
bs, reqs = batch.batch_size(), batch.reqs
device = "cuda"

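The deleted prepare_penalties/_apply_penalties pair folded the penaltylib state into two tensors: linear_penalties, added to the logits (min-token, presence, frequency), and scaling_penalties, applied multiplicatively by sign (repetition). A toy reconstruction of just the apply step (values are made up):

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
linear_penalties = torch.tensor([[0.0, -0.5, 0.0]])  # e.g. presence/frequency
scaling_penalties = torch.tensor([[1.2, 1.2, 1.2]])  # e.g. repetition penalty

logits = logits + linear_penalties
# A repetition penalty > 1 shrinks positive logits and amplifies negative ones.
logits = torch.where(
    logits > 0,
    logits / scaling_penalties,
    logits * scaling_penalties,
)
print(logits)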
View File

@@ -180,7 +180,7 @@ class SRTRunner:
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.69,
mem_fraction_static=0.7,
trust_remote_code=False,
is_embedding=not self.is_generation,
)

View File

@@ -1 +1 @@
__version__ = "0.2.14"
__version__ = "0.2.14.post1"