Move sampler into CUDA graph (#1201)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
|
|||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitProcessorOutput:
|
class LogitsProcessorOutput:
|
||||||
# The logits of the next tokens. shape: [#seq, vocab_size]
|
# The logits of the next tokens. shape: [#seq, vocab_size]
|
||||||
next_token_logits: torch.Tensor
|
next_token_logits: torch.Tensor
|
||||||
# The logprobs of the next tokens. shape: [#seq, vocab_size]
|
# The logprobs of the next tokens. shape: [#seq, vocab_size]
|
||||||
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
# Return only last_logits if logprob is not requested
|
# Return only last_logits if logprob is not requested
|
||||||
if not logits_metadata.return_logprob:
|
if not logits_metadata.return_logprob:
|
||||||
return LogitProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=None,
|
next_token_logprobs=None,
|
||||||
normalized_prompt_logprobs=None,
|
normalized_prompt_logprobs=None,
|
||||||
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
else:
|
else:
|
||||||
output_top_logprobs = None
|
output_top_logprobs = None
|
||||||
|
|
||||||
return LogitProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=last_logprobs,
|
next_token_logprobs=last_logprobs,
|
||||||
normalized_prompt_logprobs=None,
|
normalized_prompt_logprobs=None,
|
||||||
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
# Remove the last token logprob for the prefill tokens.
|
# Remove the last token logprob for the prefill tokens.
|
||||||
input_token_logprobs = input_token_logprobs[:-1]
|
input_token_logprobs = input_token_logprobs[:-1]
|
||||||
|
|
||||||
return LogitProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=last_logprobs,
|
next_token_logprobs=last_logprobs,
|
||||||
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from flashinfer.sampling import (
|
from flashinfer.sampling import (
|
||||||
@@ -9,6 +11,8 @@ from flashinfer.sampling import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
|
||||||
# TODO: move this dict to another place
|
# TODO: move this dict to another place
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class SampleOutput:
|
||||||
|
success: torch.Tensor
|
||||||
|
probs: torch.Tensor
|
||||||
|
batch_next_token_ids: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class Sampler(CustomOp):
|
class Sampler(CustomOp):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||||
|
# min-token, presence, frequency
|
||||||
|
if sampling_info.linear_penalties is not None:
|
||||||
|
logits += sampling_info.linear_penalties
|
||||||
|
|
||||||
|
# repetition
|
||||||
|
if sampling_info.scaling_penalties is not None:
|
||||||
|
logits = torch.where(
|
||||||
|
logits > 0,
|
||||||
|
logits / sampling_info.scaling_penalties,
|
||||||
|
logits * sampling_info.scaling_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def _get_probs(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_info: SamplingBatchInfo,
|
||||||
|
is_torch_compile: bool = False,
|
||||||
|
):
|
||||||
# Post process logits
|
# Post process logits
|
||||||
logits = logits.contiguous()
|
logits = logits.contiguous()
|
||||||
logits.div_(sampling_info.temperatures)
|
logits.div_(sampling_info.temperatures)
|
||||||
|
if is_torch_compile:
|
||||||
|
# FIXME: Temporary workaround for unknown bugs in torch.compile
|
||||||
|
logits.add_(0)
|
||||||
|
|
||||||
if sampling_info.logit_bias is not None:
|
if sampling_info.logit_bias is not None:
|
||||||
logits.add_(sampling_info.logit_bias)
|
logits.add_(sampling_info.logit_bias)
|
||||||
|
|
||||||
if sampling_info.vocab_mask is not None:
|
if sampling_info.vocab_mask is not None:
|
||||||
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
|
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
|
||||||
|
|
||||||
logits = sampling_info.penalizer_orchestrator.apply(logits)
|
logits = self._apply_penalties(logits, sampling_info)
|
||||||
|
|
||||||
probs = torch.softmax(logits, dim=-1)
|
return torch.softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
logits: Union[torch.Tensor, LogitsProcessorOutput],
|
||||||
|
sampling_info: SamplingBatchInfo,
|
||||||
|
):
|
||||||
|
if isinstance(logits, LogitsProcessorOutput):
|
||||||
|
logits = logits.next_token_logits
|
||||||
|
|
||||||
|
probs = self._get_probs(logits, sampling_info)
|
||||||
|
|
||||||
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
uniform_samples = torch.rand(
|
uniform_samples = torch.rand(
|
||||||
(max_top_k_round, batch_size), device=probs.device
|
(max_top_k_round, batch_size), device=probs.device
|
||||||
)
|
)
|
||||||
if sampling_info.min_ps.any():
|
if sampling_info.need_min_p_sampling:
|
||||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||||
batch_next_token_ids, success = min_p_sampling_from_probs(
|
batch_next_token_ids, success = min_p_sampling_from_probs(
|
||||||
@@ -55,18 +100,23 @@ class Sampler(CustomOp):
|
|||||||
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
|
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
|
||||||
)
|
)
|
||||||
|
|
||||||
if not torch.all(success):
|
return SampleOutput(success, probs, batch_next_token_ids)
|
||||||
logging.warning("Sampling failed, fallback to top_k=1 strategy")
|
|
||||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
def forward_native(
|
||||||
argmax_ids = torch.argmax(probs, dim=-1)
|
self,
|
||||||
batch_next_token_ids = torch.where(
|
logits: Union[torch.Tensor, LogitsProcessorOutput],
|
||||||
success, batch_next_token_ids, argmax_ids
|
sampling_info: SamplingBatchInfo,
|
||||||
|
):
|
||||||
|
if isinstance(logits, LogitsProcessorOutput):
|
||||||
|
logits = logits.next_token_logits
|
||||||
|
|
||||||
|
probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
|
||||||
|
|
||||||
|
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
|
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
|
||||||
)
|
)
|
||||||
|
|
||||||
return batch_next_token_ids
|
return SampleOutput(success, probs, batch_next_token_ids)
|
||||||
|
|
||||||
def forward_native():
|
|
||||||
raise NotImplementedError("Native forward is not implemented yet.")
|
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
|||||||
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
||||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||||
try:
|
try:
|
||||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
# FIXME: torch.multiomial does not support num_samples = 1
|
||||||
|
sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
|
||||||
|
:, :1
|
||||||
|
]
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.warning(f"Sampling error: {e}")
|
logger.warning(f"Sampling error: {e}")
|
||||||
batch_next_token_ids = torch.zeros(
|
batch_next_token_ids = torch.zeros(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Copyright 2023-2024 SGLang Team
|
Copyright 2023-2024 SGLang Team
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -17,7 +19,7 @@ limitations under the License.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
|||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.sampler import SampleOutput
|
||||||
|
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
@@ -671,11 +677,17 @@ class ScheduleBatch:
|
|||||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
|
||||||
def sample(self, logits: torch.Tensor):
|
def check_sample_results(self, sample_output: SampleOutput):
|
||||||
from sglang.srt.layers.sampler import Sampler
|
if not torch.all(sample_output.success):
|
||||||
|
probs = sample_output.probs
|
||||||
|
batch_next_token_ids = sample_output.batch_next_token_ids
|
||||||
|
logging.warning("Sampling failed, fallback to top_k=1 strategy")
|
||||||
|
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||||
|
argmax_ids = torch.argmax(probs, dim=-1)
|
||||||
|
batch_next_token_ids = torch.where(
|
||||||
|
sample_output.success, batch_next_token_ids, argmax_ids
|
||||||
|
)
|
||||||
|
sample_output.probs = probs
|
||||||
|
sample_output.batch_next_token_ids = batch_next_token_ids
|
||||||
|
|
||||||
sampler = Sampler()
|
return sample_output.batch_next_token_ids
|
||||||
|
|
||||||
batch_next_token_ids = sampler(logits, self.sampling_info)
|
|
||||||
|
|
||||||
return batch_next_token_ids
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
|
|||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
@@ -486,21 +486,29 @@ class ModelTpServer:
|
|||||||
if self.model_runner.is_generation:
|
if self.model_runner.is_generation:
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
sample_output, logits_output = self.model_runner.forward(
|
||||||
next_token_ids = batch.sample(output.next_token_logits)
|
batch, ForwardMode.EXTEND
|
||||||
|
)
|
||||||
|
next_token_ids = batch.check_sample_results(sample_output)
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
next_token_ids
|
next_token_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if output.next_token_logprobs is not None:
|
if logits_output.next_token_logprobs is not None:
|
||||||
output.next_token_logprobs = output.next_token_logprobs[
|
logits_output.next_token_logprobs = (
|
||||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
logits_output.next_token_logprobs[
|
||||||
|
torch.arange(
|
||||||
|
len(next_token_ids), device=next_token_ids.device
|
||||||
|
),
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
].tolist()
|
].tolist()
|
||||||
output.input_token_logprobs = output.input_token_logprobs.tolist()
|
)
|
||||||
output.normalized_prompt_logprobs = (
|
logits_output.input_token_logprobs = (
|
||||||
output.normalized_prompt_logprobs.tolist()
|
logits_output.input_token_logprobs.tolist()
|
||||||
|
)
|
||||||
|
logits_output.normalized_prompt_logprobs = (
|
||||||
|
logits_output.normalized_prompt_logprobs.tolist()
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
@@ -539,12 +547,14 @@ class ModelTpServer:
|
|||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
|
self.add_logprob_return_values(
|
||||||
|
i, req, pt, next_token_ids, logits_output
|
||||||
|
)
|
||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
else:
|
else:
|
||||||
assert batch.extend_num_tokens != 0
|
assert batch.extend_num_tokens != 0
|
||||||
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
embeddings = output.embeddings.tolist()
|
embeddings = logits_output.embeddings.tolist()
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
@@ -572,7 +582,7 @@ class ModelTpServer:
|
|||||||
req: Req,
|
req: Req,
|
||||||
pt: int,
|
pt: int,
|
||||||
next_token_ids: List[int],
|
next_token_ids: List[int],
|
||||||
output: LogitProcessorOutput,
|
output: LogitsProcessorOutput,
|
||||||
):
|
):
|
||||||
if req.normalized_prompt_logprob is None:
|
if req.normalized_prompt_logprob is None:
|
||||||
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
||||||
@@ -654,15 +664,17 @@ class ModelTpServer:
|
|||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
|
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
output = self.model_runner.forward(batch, ForwardMode.DECODE)
|
sample_output, logits_output = self.model_runner.forward(
|
||||||
next_token_ids = batch.sample(output.next_token_logits)
|
batch, ForwardMode.DECODE
|
||||||
|
)
|
||||||
|
next_token_ids = batch.check_sample_results(sample_output)
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
next_token_ids
|
next_token_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if output.next_token_logprobs is not None:
|
if logits_output.next_token_logprobs is not None:
|
||||||
next_token_logprobs = output.next_token_logprobs[
|
next_token_logprobs = logits_output.next_token_logprobs[
|
||||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
].tolist()
|
].tolist()
|
||||||
@@ -688,7 +700,7 @@ class ModelTpServer:
|
|||||||
(next_token_logprobs[i], next_token_id)
|
(next_token_logprobs[i], next_token_id)
|
||||||
)
|
)
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
|
|||||||
@@ -25,16 +25,18 @@ from vllm.distributed.parallel_state import graph_capture
|
|||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import (
|
from sglang.srt.layers.logits_processor import (
|
||||||
LogitProcessorOutput,
|
|
||||||
LogitsMetadata,
|
LogitsMetadata,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
|
LogitsProcessorOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.sampler import SampleOutput
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
InputMetadata,
|
InputMetadata,
|
||||||
update_flashinfer_indices,
|
update_flashinfer_indices,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||||
|
|
||||||
|
|
||||||
@@ -143,6 +145,10 @@ class CudaGraphRunner:
|
|||||||
self.flashinfer_kv_indices.clone(),
|
self.flashinfer_kv_indices.clone(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Sampling inputs
|
||||||
|
vocab_size = model_runner.model_config.vocab_size
|
||||||
|
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
|
||||||
|
|
||||||
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
||||||
|
|
||||||
if use_torch_compile:
|
if use_torch_compile:
|
||||||
@@ -234,6 +240,7 @@ class CudaGraphRunner:
|
|||||||
def run_once():
|
def run_once():
|
||||||
input_metadata = InputMetadata(
|
input_metadata = InputMetadata(
|
||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
sampling_info=self.sampling_info[:bs],
|
||||||
batch_size=bs,
|
batch_size=bs,
|
||||||
req_pool_indices=req_pool_indices,
|
req_pool_indices=req_pool_indices,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
@@ -298,27 +305,35 @@ class CudaGraphRunner:
|
|||||||
self.flashinfer_handlers[bs],
|
self.flashinfer_handlers[bs],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Sampling inputs
|
||||||
|
self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
|
||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.graphs[bs].replay()
|
self.graphs[bs].replay()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
output = self.output_buffers[bs]
|
sample_output, logits_output = self.output_buffers[bs]
|
||||||
|
|
||||||
# Unpad
|
# Unpad
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
output = LogitProcessorOutput(
|
logits_output = LogitsProcessorOutput(
|
||||||
next_token_logits=output.next_token_logits[:raw_bs],
|
next_token_logits=logits_output.next_token_logits[:raw_bs],
|
||||||
next_token_logprobs=None,
|
next_token_logprobs=None,
|
||||||
normalized_prompt_logprobs=None,
|
normalized_prompt_logprobs=None,
|
||||||
input_token_logprobs=None,
|
input_token_logprobs=None,
|
||||||
input_top_logprobs=None,
|
input_top_logprobs=None,
|
||||||
output_top_logprobs=None,
|
output_top_logprobs=None,
|
||||||
)
|
)
|
||||||
|
sample_output = SampleOutput(
|
||||||
|
sample_output.success[:raw_bs],
|
||||||
|
sample_output.probs[:raw_bs],
|
||||||
|
sample_output.batch_next_token_ids[:raw_bs],
|
||||||
|
)
|
||||||
|
|
||||||
# Extract logprobs
|
# Extract logprobs
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
output.next_token_logprobs = torch.nn.functional.log_softmax(
|
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
|
||||||
output.next_token_logits, dim=-1
|
logits_output.next_token_logits, dim=-1
|
||||||
)
|
)
|
||||||
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
|
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
|
||||||
if return_top_logprob:
|
if return_top_logprob:
|
||||||
@@ -326,8 +341,8 @@ class CudaGraphRunner:
|
|||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
)
|
)
|
||||||
output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||||
output.next_token_logprobs, logits_metadata
|
logits_output.next_token_logprobs, logits_metadata
|
||||||
)[1]
|
)[1]
|
||||||
|
|
||||||
return output
|
return sample_output, logits_output
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Copyright 2023-2024 SGLang Team
|
Copyright 2023-2024 SGLang Team
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -16,7 +18,7 @@ limitations under the License.
|
|||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
@@ -42,6 +45,7 @@ class InputMetadata:
|
|||||||
"""Store all inforamtion of a forward pass."""
|
"""Store all inforamtion of a forward pass."""
|
||||||
|
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
|
sampling_info: SamplingBatchInfo
|
||||||
batch_size: int
|
batch_size: int
|
||||||
req_pool_indices: torch.Tensor
|
req_pool_indices: torch.Tensor
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
@@ -179,6 +183,7 @@ class InputMetadata:
|
|||||||
):
|
):
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
|
sampling_info=batch.sampling_info,
|
||||||
batch_size=batch.batch_size(),
|
batch_size=batch.batch_size(),
|
||||||
req_pool_indices=batch.req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=batch.seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
@@ -189,6 +194,8 @@ class InputMetadata:
|
|||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ret.sampling_info.prepare_penalties()
|
||||||
|
|
||||||
ret.compute_positions(batch)
|
ret.compute_positions(batch)
|
||||||
|
|
||||||
ret.compute_extend_infos(batch)
|
ret.compute_extend_infos(batch)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import importlib.resources
|
|||||||
import logging
|
import logging
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional, Type
|
from typing import Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
|
|||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.sampler import SampleOutput
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
MHATokenToKVPool,
|
MHATokenToKVPool,
|
||||||
@@ -514,7 +516,11 @@ class ModelRunner:
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_decode(self, batch: ScheduleBatch):
|
def forward_decode(self, batch: ScheduleBatch):
|
||||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
if (
|
||||||
|
self.cuda_graph_runner
|
||||||
|
and self.cuda_graph_runner.can_run(len(batch.reqs))
|
||||||
|
and not batch.sampling_info.has_bias()
|
||||||
|
):
|
||||||
return self.cuda_graph_runner.replay(batch)
|
return self.cuda_graph_runner.replay(batch)
|
||||||
|
|
||||||
input_metadata = InputMetadata.from_schedule_batch(
|
input_metadata = InputMetadata.from_schedule_batch(
|
||||||
@@ -563,7 +569,9 @@ class ModelRunner:
|
|||||||
input_metadata.image_offsets,
|
input_metadata.image_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
|
def forward(
|
||||||
|
self, batch: ScheduleBatch, forward_mode: ForwardMode
|
||||||
|
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
|
||||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
||||||
return self.forward_extend_multi_modal(batch)
|
return self.forward_extend_multi_modal(batch)
|
||||||
elif forward_mode == ForwardMode.DECODE:
|
elif forward_mode == ForwardMode.DECODE:
|
||||||
|
|||||||
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
||||||
from vllm.sequence import SamplerOutput
|
|
||||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
LoraConfig = None
|
LoraConfig = None
|
||||||
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
def sample(
|
return sample_output, logits_output
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> Optional[SamplerOutput]:
|
|
||||||
next_tokens = self.sampler(logits, sampling_metadata)
|
|
||||||
return next_tokens
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
self.model = CohereModel(config, quant_config)
|
self.model = CohereModel(config, quant_config)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
|
|||||||
positions,
|
positions,
|
||||||
input_metadata,
|
input_metadata,
|
||||||
)
|
)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
|
|||||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
expert_params_mapping = [
|
expert_params_mapping = [
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
|
|||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
@@ -632,6 +633,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -640,9 +642,11 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = GemmaModel(config, quant_config=quant_config)
|
self.model = GemmaModel(config, quant_config=quant_config)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return (sample_output, logits_output)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.activation import GeluAndMul
|
from sglang.srt.layers.activation import GeluAndMul
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -396,6 +397,7 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Gemma2Model(config, cache_config, quant_config)
|
self.model = Gemma2Model(config, cache_config, quant_config)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -406,9 +408,11 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def get_attention_sliding_window_size(self):
|
def get_attention_sliding_window_size(self):
|
||||||
return get_attention_sliding_window_size(self.config)
|
return get_attention_sliding_window_size(self.config)
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
if lora_config:
|
if lora_config:
|
||||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -297,6 +298,7 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
self.model = Grok1Model(config, quant_config=quant_config)
|
self.model = Grok1Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
# Monkey patch _prepare_weights to load pre-sharded weights
|
||||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||||
@@ -313,9 +315,11 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
|
|||||||
self.model = InternLM2Model(config, quant_config)
|
self.model = InternLM2Model(config, quant_config)
|
||||||
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.output.weight, input_metadata
|
input_ids, hidden_states, self.output.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.model = LlamaModel(config, quant_config=quant_config)
|
self.model = LlamaModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> LogitProcessorOutput:
|
) -> LogitsProcessorOutput:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def get_module_name(self, name):
|
def get_module_name(self, name):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
from sglang.srt.models.llama2 import LlamaModel
|
from sglang.srt.models.llama2 import LlamaModel
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
|
|||||||
(input_metadata.batch_size, self.config.classification_out_size)
|
(input_metadata.batch_size, self.config.classification_out_size)
|
||||||
).to(input_ids.device)
|
).to(input_ids.device)
|
||||||
|
|
||||||
return LogitProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=scores,
|
next_token_logits=scores,
|
||||||
next_token_logprobs=scores,
|
next_token_logprobs=scores,
|
||||||
normalized_prompt_logprobs=scores,
|
normalized_prompt_logprobs=scores,
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -314,9 +316,11 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
lm_head_weight = self.model.embed_tokens.weight
|
lm_head_weight = self.model.embed_tokens.weight
|
||||||
else:
|
else:
|
||||||
lm_head_weight = self.lm_head.weight
|
lm_head_weight = self.lm_head.weight
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, lm_head_weight, input_metadata
|
input_ids, hidden_states, lm_head_weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
|
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
|
|||||||
self.model = MixtralModel(config, quant_config=quant_config)
|
self.model = MixtralModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
):
|
):
|
||||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
next_tokens = self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
return next_tokens
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
Qwen2Config = None
|
Qwen2Config = None
|
||||||
@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
if not get_embedding:
|
if not get_embedding:
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
else:
|
else:
|
||||||
return self.pooler(hidden_states, input_metadata)
|
return self.pooler(hidden_states, input_metadata)
|
||||||
|
|
||||||
|
|||||||
@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
|
|||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
def compute_logits(
|
return sample_output, logits_output
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
input_metadata: InputMetadata,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
logits = self.logits_processor(
|
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
|
||||||
)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
|
|||||||
self.model = StableLMEpochModel(config, quant_config=quant_config)
|
self.model = StableLMEpochModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
logits_output = self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||||
|
return sample_output, logits_output
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -21,10 +21,63 @@ class SamplingBatchInfo:
|
|||||||
top_ps: torch.Tensor = None
|
top_ps: torch.Tensor = None
|
||||||
top_ks: torch.Tensor = None
|
top_ks: torch.Tensor = None
|
||||||
min_ps: torch.Tensor = None
|
min_ps: torch.Tensor = None
|
||||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
|
||||||
|
# Dispatch in CUDA graph
|
||||||
|
need_min_p_sampling: bool = False
|
||||||
|
|
||||||
|
# Bias Tensors
|
||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
vocab_mask: torch.Tensor = None
|
vocab_mask: torch.Tensor = None
|
||||||
|
|
||||||
|
# Penalizer
|
||||||
|
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||||
|
linear_penalties: torch.Tensor = None
|
||||||
|
scaling_penalties: torch.Tensor = None
|
||||||
|
|
||||||
|
def has_bias(self):
|
||||||
|
return (
|
||||||
|
self.logit_bias is not None
|
||||||
|
or self.vocab_mask is not None
|
||||||
|
or self.linear_penalties is not None
|
||||||
|
or self.scaling_penalties is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dummy_one(cls, max_bs: int, vocab_size: int):
|
||||||
|
ret = cls(vocab_size=vocab_size)
|
||||||
|
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
|
||||||
|
ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
|
||||||
|
ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
|
||||||
|
ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if isinstance(key, slice):
|
||||||
|
# NOTE: We do not use cuda graph when there is bias tensors
|
||||||
|
assert not self.has_bias()
|
||||||
|
return SamplingBatchInfo(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
temperatures=self.temperatures[key],
|
||||||
|
top_ps=self.top_ps[key],
|
||||||
|
top_ks=self.top_ks[key],
|
||||||
|
min_ps=self.min_ps[key],
|
||||||
|
need_min_p_sampling=self.need_min_p_sampling,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def inplace_assign(self, bs: int, other: SamplingBatchInfo):
|
||||||
|
# NOTE: We do not use cuda graph when there is bias tensors
|
||||||
|
assert not self.has_bias()
|
||||||
|
|
||||||
|
self.vocab_size = other.vocab_size
|
||||||
|
self.need_min_p_sampling = other.need_min_p_sampling
|
||||||
|
|
||||||
|
self.temperatures[:bs] = other.temperatures
|
||||||
|
self.top_ps[:bs] = other.top_ps
|
||||||
|
self.top_ks[:bs] = other.top_ks
|
||||||
|
self.min_ps[:bs] = other.min_ps
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
@@ -45,6 +98,7 @@ class SamplingBatchInfo:
|
|||||||
ret.min_ps = torch.tensor(
|
ret.min_ps = torch.tensor(
|
||||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
||||||
)
|
)
|
||||||
|
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
||||||
|
|
||||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||||
@@ -72,6 +126,25 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def prepare_penalties(self):
|
||||||
|
self.scaling_penalties = None
|
||||||
|
self.linear_penalties = None
|
||||||
|
|
||||||
|
for penalizer in self.penalizer_orchestrator.penalizers.values():
|
||||||
|
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
|
||||||
|
if penalizer.is_prepared():
|
||||||
|
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
||||||
|
else:
|
||||||
|
if penalizer.is_prepared():
|
||||||
|
if self.linear_penalties is None:
|
||||||
|
bs = self.penalizer_orchestrator.batch.batch_size()
|
||||||
|
self.linear_penalties = torch.zeros(
|
||||||
|
(bs, self.vocab_size),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||||
|
|
||||||
def update_regex_vocab_mask(self, batch: ScheduleBatch):
|
def update_regex_vocab_mask(self, batch: ScheduleBatch):
|
||||||
bs, reqs = batch.batch_size(), batch.reqs
|
bs, reqs = batch.batch_size(), batch.reqs
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ class SRTRunner:
|
|||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
dtype=get_dtype_str(torch_dtype),
|
dtype=get_dtype_str(torch_dtype),
|
||||||
port=port,
|
port=port,
|
||||||
mem_fraction_static=0.7,
|
mem_fraction_static=0.69,
|
||||||
trust_remote_code=False,
|
trust_remote_code=False,
|
||||||
is_embedding=not self.is_generation,
|
is_embedding=not self.is_generation,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user