Higher priority for user input of max_prefill_tokens & format (#540)

Ying Sheng
2024-06-12 21:48:40 -07:00
committed by GitHub
parent 1374334d38
commit fb9296f0ed
50 changed files with 817 additions and 569 deletions
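Per the title, the substantive change sits in the ModelTpServer.__init__ hunk further down; most of the remaining diff is black/isort-style reformatting. Previously, max_prefill_tokens was always clamped by max(context_len, ...), even when the user set it explicitly; now a user-supplied value wins outright and the clamp only shapes the derived default. A minimal sketch of the new selection logic, not part of the diff, with names assumed from that hunk:

# Sketch (illustrative only) of the new default-vs-override logic for
# max_prefill_tokens, mirroring the ModelTpServer.__init__ hunk in this commit.
def resolve_max_prefill_tokens(
    user_value,            # server_args.max_prefill_tokens, None if unset
    context_len,           # self.model_config.context_len
    max_total_num_tokens,  # self.model_runner.max_total_num_tokens
):
    if user_value is not None:
        # User input now takes priority unconditionally.
        return user_value
    # Otherwise fall back to the derived default, as before.
    return max(context_len, min(max_total_num_tokens // 6, 65536))

Before this commit, a small user value (e.g. passed via the CLI flag corresponding to ServerArgs.max_prefill_tokens) could be silently raised to the model's context length; after it, the user's setting is used as given.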

View File

@@ -24,10 +24,10 @@ from sglang.api import (
# SGL Backends
from sglang.backend.anthropic import Anthropic
from sglang.backend.litellm import LiteLLM
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.backend.vertexai import VertexAI
from sglang.backend.litellm import LiteLLM
# Global Configurations
from sglang.global_config import global_config

View File

@@ -33,7 +33,8 @@ class LiteLLM(BaseBackend):
self.model_name = model_name
self.chat_template = chat_template or get_chat_template_by_model_path(
model_name)
model_name
)
self.client_params = {
"api_key": api_key,

View File

@@ -1,7 +1,7 @@
import dataclasses
import logging
import time
import warnings
import dataclasses
from typing import Callable, List, Optional, Union
import numpy as np
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
def get_chat_template(self):
return self.chat_template
def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
num_api_spec_tokens: int, spec_var_name: str):
def _prepare_spec_execution(
self,
sampling_params: SglSamplingParams,
num_api_spec_tokens: int,
spec_var_name: str,
):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else:
assert (
self.spec_kwargs["max_tokens"] == num_api_spec_tokens
)
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
)
prompt = s.messages_
else:
return self._prepare_spec_execution(sampling_params,
s.num_api_spec_tokens, spec_var_name)
return self._prepare_spec_execution(
sampling_params, s.num_api_spec_tokens, spec_var_name
)
else:
prompt = s.text_
@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
self.token_usage.completion_tokens= ret.usage.completion_tokens
self.token_usage.completion_tokens = ret.usage.completion_tokens
# TODO:
# 1. return logits as the scores
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
return decision, scores, None, None
def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
def openai_completion(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries):
try:
if is_chat:
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
return comp
def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
def openai_completion_stream(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
generator = client.chat.completions.create(
messages=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
messages=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
yield content or "", {}
else:
generator = client.completions.create(
prompt=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
prompt=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:

View File

@@ -507,7 +507,7 @@ class StreamExecutor:
)
return
else: # Speculative execution on models with completion interface
else: # Speculative execution on models with completion interface
comp, meta_info = self._spec_gen(sampling_params)
self.text_ += comp

View File

@@ -81,12 +81,10 @@ class SglSamplingParams:
"top_p": self.top_p,
"top_k": self.top_k,
}
def to_litellm_kwargs(self):
if self.regex is not None:
warnings.warn(
"Regular expression is not supported in the LiteLLM backend."
)
warnings.warn("Regular expression is not supported in the LiteLLM backend.")
return {
"max_tokens": self.max_new_tokens,
"stop": self.stop or None,

View File

@@ -10,4 +10,4 @@ if __name__ == "__main__":
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, None)
launch_server(server_args, None)

View File

@@ -1,4 +1,5 @@
"""Launch the inference server for Llava-video model."""
import argparse
import multiprocessing as mp

View File

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel

View File

@@ -1,4 +1,5 @@
"""Cache for the compressed finite state machine."""
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache

View File

@@ -8,11 +8,12 @@ from collections import defaultdict
import interegular
import outlines.caching
from sglang.srt.constrained import (
FSMInfo,
disk_cache,
make_deterministic_fsm,
make_byte_level_fsm,
make_deterministic_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache

View File

@@ -1,4 +1,5 @@
"""Conversation templates."""
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses

View File

@@ -1,10 +1,10 @@
"""Utilities for Huggingface Transformers."""
import functools
import json
import os
import warnings
import functools
from typing import Optional, Union, AbstractSet, Collection, Literal
from typing import AbstractSet, Collection, Literal, Optional, Union
from huggingface_hub import snapshot_download
from transformers import (
@@ -179,6 +179,7 @@ def get_processor(
class TiktokenTokenizer:
def __init__(self, tokenizer_path):
import tiktoken
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
# Read JSON
@@ -190,7 +191,8 @@ class TiktokenTokenizer:
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
}
special_tokens = {
bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
bytes(item["bytes"]).decode(): item["token"]
for item in tok_dict["special_tokens"]
}
assert tok_dict["word_split"] == "V1"
@@ -202,7 +204,10 @@ class TiktokenTokenizer:
}
if "default_allowed_special" in tok_dict:
default_allowed_special = set(
[bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
[
bytes(bytes_list).decode()
for bytes_list in tok_dict["default_allowed_special"]
]
)
else:
default_allowed_special = None
@@ -216,14 +221,20 @@ class TiktokenTokenizer:
self,
text: str,
*,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
allowed_special: Union[
Literal["all"], AbstractSet[str]
] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode(
self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
self,
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
tokenizer.encode = functools.partial(encode_patched, tokenizer)
# Convert to HF interface
@@ -237,10 +248,14 @@ class TiktokenTokenizer:
def decode(self, x):
return self.tokenizer.decode(x)
def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
def batch_decode(
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
):
if isinstance(batch[0], int):
batch = [[x] for x in batch]
return self.tokenizer.decode_batch(batch)
def convert_ids_to_tokens(self, index):
return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
return self.tokenizer.decode_single_token_bytes(index).decode(
"utf-8", errors="ignore"
)

View File

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import is_hip
@@ -109,12 +108,16 @@ def fused_moe_kernel(
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if use_fp8:
a_scale = tl.load(a_scale_ptr)
@@ -130,13 +133,12 @@ def fused_moe_kernel(
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator)
@@ -147,9 +149,7 @@ def fused_moe_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_fp8:
@@ -159,15 +159,14 @@ def fused_moe_kernel(
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
@@ -206,32 +205,38 @@ def moe_align_block_size(
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
ops.moe_align_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)
return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None:
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: Dict[str, Any],
compute_type: tl.dtype,
use_fp8: bool,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
grid = lambda META: (
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)
fused_moe_kernel[grid](
A,
@@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
@functools.lru_cache
def get_moe_configs(E: int, N: int,
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
@@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int,
json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.",
config_file_path)
logger.info("Using configuration from %s for MoE layer.", config_file_path)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
@@ -352,40 +358,30 @@ def fused_moe(
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
M, _ = hidden_states.shape
E, N, _ = w1.shape
if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels
topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
@@ -400,8 +396,7 @@ def fused_moe(
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
@@ -415,7 +410,7 @@ def fused_moe(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
"num_stages": 4,
}
if M <= E:
@@ -425,61 +420,72 @@ def fused_moe(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
"num_stages": 4,
}
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)
topk_ids, config["BLOCK_SIZE_M"], E
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8)
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)

View File

@@ -1,4 +1,5 @@
"""Logits processing."""
import torch
from torch import nn
from vllm.distributed import (

View File

@@ -1,6 +1,7 @@
"""Radix attention."""
import torch
import numpy as np
import torch
from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
class RadixAttention(nn.Module):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
def __init__(
self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads

View File

@@ -4,7 +4,7 @@ import asyncio
import logging
import queue
import threading
from typing import List, Callable
from typing import Callable, List
import uvloop
import zmq
@@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
# async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished:
await asyncio.sleep(self.request_dependency_delay)
await asyncio.sleep(global_config.wait_for_new_request_delay)
@@ -108,4 +110,4 @@ def start_data_parallel_worker(
step_func=model_tp_client.step,
)
worker_thread.start()
return worker_thread
return worker_thread

View File

@@ -1,17 +1,17 @@
"""Meta data for requests and batches"""
import warnings
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
import warnings
import numpy as np
import torch
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.constrained import RegexGuide
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

View File

@@ -13,15 +13,15 @@ import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.managers.io_struct import (
AbortReq,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
@@ -136,7 +136,7 @@ class Controller:
self.recv_reqs = []
if next_step_input:
await self.dispatching(next_step_input)
#else:
# else:
# logger.error("There is no live worker.")
await asyncio.sleep(global_config.wait_for_new_request_delay)

View File

@@ -1,4 +1,5 @@
"""A controller that manages a group of tensor parallel workers."""
import asyncio
import logging
import time
@@ -49,7 +50,9 @@ class ControllerSingle:
# async sleep for receiving the subsequent request and avoiding cache miss
slept = False
if len(out_pyobjs) != 0:
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished:
if self.request_dependency_delay > 0:
slept = True
@@ -94,4 +97,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
kill_parent_process()
kill_parent_process()

View File

@@ -1,4 +1,5 @@
"""ModelRunner runs the forward passes of the models."""
import importlib
import importlib.resources
import logging
@@ -12,15 +13,18 @@ import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import initialize_model_parallel, init_distributed_environment
from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
from sglang.srt.utils import (
get_available_gpu_memory,
is_multimodal_model,
monkey_patch_vllm_p2p_access_check,
)
logger = logging.getLogger("srt.model_runner")
@@ -441,7 +445,9 @@ def import_model_classes():
module = importlib.import_module(name)
if hasattr(module, "EntryClass"):
entry = module.EntryClass
if isinstance(entry, list): # To support multiple model classes in one module
if isinstance(
entry, list
): # To support multiple model classes in one module
for tmp in entry:
model_arch_name_to_cls[tmp.__name__] = tmp
else:
@@ -449,7 +455,9 @@ def import_model_classes():
# compat: some models such as chatglm has incorrect class set in config.json
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list):
if hasattr(module, "EntryClassRemapping") and isinstance(
module.EntryClassRemapping, list
):
for remap in module.EntryClassRemapping:
if isinstance(remap, tuple) and len(remap) == 2:
model_arch_name_to_cls[remap[0]] = remap[1]

View File

@@ -1,6 +1,7 @@
"""
The radix tree data structure for managing the KV cache.
"""
import heapq
import time
from collections import defaultdict

View File

@@ -1,4 +1,5 @@
"""Request scheduler heuristic."""
import random
from collections import defaultdict

View File

@@ -15,22 +15,22 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import (
FINISH_ABORT,
BaseFinishReason,
Batch,
FINISH_ABORT,
ForwardMode,
Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
@@ -96,13 +96,13 @@ class ModelTpServer:
trust_remote_code=server_args.trust_remote_code,
)
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = max(
self.model_config.context_len,
(
min(self.max_total_num_tokens // 6, 65536)
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
),
self.max_prefill_tokens = (
max(
self.model_config.context_len,
min(self.max_total_num_tokens // 6, 65536),
)
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
)
self.max_running_requests = (
self.max_total_num_tokens // 2

View File

@@ -1,4 +1,5 @@
"""DetokenizerManager is a process that detokenizes the token ids."""
import asyncio
import inspect
@@ -7,10 +8,10 @@ import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback, graceful_registry
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

View File

@@ -7,8 +7,8 @@ import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
@dataclass

View File

@@ -1,11 +1,12 @@
"""TokenizerManager is a process that tokenizes the text."""
import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
from typing import List, Dict
from typing import Dict, List
import numpy as np
import transformers
@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.managers.io_struct import (
AbortReq,
BatchStrOut,
BatchTokenIDOut,
FlushCacheReq,
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -91,7 +92,7 @@ class TokenizerManager:
)
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
self.rid_to_state: Dict[str, ReqState] = {}
async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -322,7 +323,6 @@ class TokenizerManager:
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
):

View File

@@ -1,8 +1,9 @@
from typing import Optional
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
class ModelConfig:
def __init__(
@@ -17,8 +18,12 @@ class ModelConfig:
self.trust_remote_code = trust_remote_code
self.revision = revision
self.model_overide_args = model_overide_args
self.hf_config = get_config(self.path, trust_remote_code, revision,
model_overide_args=model_overide_args)
self.hf_config = get_config(
self.path,
trust_remote_code,
revision,
model_overide_args=model_overide_args,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
if context_length is not None:
self.context_len = context_length
@@ -55,18 +60,23 @@ class ModelConfig:
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_text_config,
"multi_query", False):
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False)
)
if not new_decoder_arch_falcon and getattr(
self.hf_text_config, "multi_query", False
):
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1
# For DBRX and MPT
if self.hf_config.model_type in ["dbrx", "mpt"]:
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)
return getattr(
self.hf_config.attn_config,
"kv_n_heads",
self.hf_config.num_attention_heads,
)
attributes = [
# For Falcon:
@@ -94,13 +104,12 @@ class ModelConfig:
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1,
total_num_kv_heads // tensor_parallel_size)
return max(1, total_num_kv_heads // tensor_parallel_size)
def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
No op for pure text models.
"""
if hasattr(config, "text_config"):
# The code operates under the assumption that text_config should have

View File

@@ -5,30 +5,32 @@
from typing import Iterable, List, Optional, Tuple
import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.layers.logits_processor import LogitsProcessor
from torch import nn
from torch.nn import LayerNorm
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
LoraConfig = None
@@ -49,9 +51,11 @@ class GLMAttention(nn.Module):
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else
config.num_attention_heads)
self.total_num_kv_heads = (
config.multi_query_group_num
if config.multi_query_attention
else config.num_attention_heads
)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
@@ -91,11 +95,13 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = RadixAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
@@ -176,14 +182,16 @@ class GLMBlock(nn.Module):
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
config.apply_residual_connection_post_layernorm
)
self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = layer_norm_func(config.hidden_size,
eps=config.layernorm_epsilon)
self.input_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon
)
# Self attention.
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
@@ -191,7 +199,8 @@ class GLMBlock(nn.Module):
# Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
config.hidden_size, eps=config.layernorm_epsilon
)
# MLP
self.mlp = GLMMLP(config, quant_config)
@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers
# Transformer layers.
self.layers = nn.ModuleList([
GLMBlock(config, i, cache_config, quant_config)
for i in range(self.num_layers)
])
self.layers = nn.ModuleList(
[
GLMBlock(config, i, cache_config, quant_config)
for i in range(self.num_layers)
]
)
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
config.hidden_size, eps=config.layernorm_epsilon
)
def forward(
self,
@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module):
):
super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size)
self.embedding = VocabParallelEmbedding(
config.padded_vocab_size, config.hidden_size
)
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
def forward(
self,
@@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
"dense_h_to_4h": ["dense_h_to_4h"],
}
# LoRA specific attributes
supported_lora_modules = [
@@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config)
@@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module):
positions: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions,
input_metadata)
hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
@@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module):
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = ChatGLMForCausalLM
# compat: glm model.config class == ChatGLMModel
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]

View File

@@ -23,7 +23,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Optional, Tuple, Iterable
from typing import Iterable, Optional, Tuple
import torch
import torch.utils.checkpoint
@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention

View File

@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor

View File

@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig, CacheConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm

View File

@@ -1,7 +1,7 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
from typing import Iterable, Optional, Tuple, List
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
@@ -9,7 +9,6 @@ import torch.nn.functional as F
import tqdm
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from vllm.config import CacheConfig
from vllm.distributed import (
@@ -35,12 +34,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
use_fused = True
@@ -134,9 +132,12 @@ class Grok1MoEUnfused(nn.Module):
final_hidden_states = torch.zeros(
(hidden_states.shape[0], hidden_dim),
dtype=hidden_states.dtype, device=hidden_states.device
dtype=hidden_states.dtype,
device=hidden_states.device,
)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_total_experts).permute(2, 1, 0)
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_total_experts
).permute(2, 1, 0)
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
@@ -153,7 +154,10 @@ class Grok1MoEUnfused(nn.Module):
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
current_hidden_states = (
expert_layer(current_state)
* routing_weights[top_x_list, idx_list, None]
)
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
@@ -198,32 +202,46 @@ class Grok1MoE(nn.Module):
self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
self.gate = ReplicatedLinear(
self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None,
)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype))
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype,
)
)
self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype))
torch.empty(
self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype,
)
)
set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(
self.w13_weight,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_weight,
{
"weight_loader": self.weight_loader,
},
)
# Used for fp8.
self.w13_scale = None
@@ -233,46 +251,69 @@ class Grok1MoE(nn.Module):
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w13_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.w2_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(
self.w13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_scale,
{
"weight_loader": self.weight_loader,
},
)
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
"was not serialized fp8."
)
self.a13_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.a2_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(self.a13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(
self.a13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.a2_scale,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int, pre_sharded: bool):
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
expert_id: int,
pre_sharded: bool,
):
param_data = param.data
shard_size = self.intermediate_size
if pre_sharded:
@@ -284,8 +325,9 @@ class Grok1MoE(nn.Module):
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
shard, :
]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -298,17 +340,17 @@ class Grok1MoE(nn.Module):
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
w13_weight = torch.empty_like(
self.w13_weight.data, dtype=torch.float8_e4m3fn
)
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], self.w2_scale[
expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :])
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :]
)
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :]
)
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
@@ -319,40 +361,40 @@ class Grok1MoE(nn.Module):
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
"activation scales are None."
)
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")
"Using the maximum across experts for each layer. "
)
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=False,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale)
final_hidden_states = fused_moe(
hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=False,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
@@ -462,10 +504,12 @@ class Grok1DecoderLayer(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config)
quant_config=quant_config,
)
else:
self.block_sparse_moe = Grok1MoEUnfused(
config=config, quant_config=quant_config)
config=config, quant_config=quant_config
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -478,12 +522,21 @@ class Grok1DecoderLayer(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.post_attn_norm(self.self_attn(
positions=positions, hidden_states=self.pre_attn_norm(hidden_states),
input_metadata=input_metadata,
)) + hidden_states
hidden_states = (
self.post_attn_norm(
self.self_attn(
positions=positions,
hidden_states=self.pre_attn_norm(hidden_states),
input_metadata=input_metadata,
)
)
+ hidden_states
)
hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states
hidden_states = (
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
+ hidden_states
)
return hidden_states
@@ -525,9 +578,7 @@ class Grok1Model(nn.Module):
hidden_states.mul_(self.config.embedding_multiplier_scale)
for i in range(len(self.layers)):
hidden_states = self.layers[i](
positions, hidden_states, input_metadata
)
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
hidden_states = self.norm(hidden_states)
hidden_states.mul_(self.config.output_multiplier_scale)
@@ -572,28 +623,41 @@ class Grok1ModelForCausalLM(nn.Module):
]
if use_fused:
expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
expert_params_mapping = (
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
(
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
)
else:
expert_params_mapping = []
@@ -601,11 +665,11 @@ class Grok1ModelForCausalLM(nn.Module):
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
for name, loaded_weight in weights:
#print(get_tensor_model_parallel_rank(), name)
# print(get_tensor_model_parallel_rank(), name)
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -623,19 +687,22 @@ class Grok1ModelForCausalLM(nn.Module):
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1)
weight_loader(
param,
loaded_weight,
weight_name,
expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
@@ -645,10 +712,11 @@ def all_close_1d(x: torch.Tensor) -> bool:
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
def _prepare_presharded_weights(self,
model_name_or_path: str,
revision: Optional[str],
fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
def _prepare_presharded_weights(
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
) -> Tuple[str, List[str], bool]:
import glob
import os
@@ -668,4 +736,4 @@ def _prepare_presharded_weights(self,
return hf_folder, hf_weights_files, use_safetensors
EntryClass = Grok1ModelForCausalLM
EntryClass = Grok1ModelForCausalLM

View File

@@ -1,7 +1,7 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional, Tuple, Iterable
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import tqdm
@@ -10,7 +10,7 @@ from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,

View File

@@ -1,11 +1,17 @@
"""Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
from transformers import (
CLIPVisionConfig,
CLIPVisionModel,
LlavaConfig,
MistralConfig,
Qwen2Config,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
class LlavaLlamaForCausalLM(nn.Module):
@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
)
EntryClass = [
LlavaLlamaForCausalLM,
LlavaQwenForCausalLM,
LlavaMistralForCausalLM
]
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]

View File

@@ -1,6 +1,6 @@
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch

View File

@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
self.gate = ReplicatedLinear(
self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None,
)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype))
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype,
)
)
self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype))
torch.empty(
self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype,
)
)
set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(
self.w13_weight,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_weight,
{
"weight_loader": self.weight_loader,
},
)
# Used for fp8.
self.w13_scale = None
@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w13_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.w2_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(
self.w13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_scale,
{
"weight_loader": self.weight_loader,
},
)
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
"was not serialized fp8."
)
self.a13_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.a2_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(self.a13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(
self.a13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.a2_scale,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
expert_id: int,
):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
shard, :
]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
w13_weight = torch.empty_like(
self.w13_weight.data, dtype=torch.float8_e4m3fn
)
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], self.w2_scale[
expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :])
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :]
)
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :]
)
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
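For fp16 checkpoints, the block above quantizes each expert weight to fp8 at load time through ops.scaled_fp8_quant. A rough torch-only sketch of what per-tensor fp8 quantization amounts to, as an approximation of the kernel's behavior rather than its exact implementation:

import torch


def per_tensor_fp8_quant(w: torch.Tensor):
    # Scale so the largest magnitude maps onto the e4m3 range, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().max().float().clamp(min=1e-12) / finfo.max
    q = (w.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale


w16 = torch.randn(16, 32, dtype=torch.float16)
q, scale = per_tensor_fp8_quant(w16)
w_dequant = q.to(torch.float16) * scale  # approximate reconstruction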
@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
"activation scales are None."
)
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")
"Using the maximum across experts for each layer. "
)
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale)
final_hidden_states = fused_moe(
hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
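fused_moe performs the top-k routing and the expert MLPs in a single fused kernel. For readability, an unfused PyTorch reference of the same computation might look like the sketch below; it assumes the w13/w2 layouts used above and is not the kernel itself.

import torch
import torch.nn.functional as F


def naive_moe(hidden_states, w13, w2, router_logits, top_k, renormalize=True):
    # hidden_states: (tokens, hidden); w13: (experts, 2*inter, hidden); w2: (experts, hidden, inter)
    probs = F.softmax(router_logits, dim=-1)
    topk_probs, topk_ids = probs.topk(top_k, dim=-1)
    if renormalize:
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states)
    for e in range(w13.shape[0]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        gate, up = (x @ w13[e].t()).chunk(2, dim=-1)   # w1 then w3, matching the loader
        expert_out = (F.silu(gate) * up) @ w2[e].t()   # SiluAndMul, then down-projection
        weight = topk_probs[token_idx, slot].unsqueeze(-1)
        out.index_add_(0, token_idx, expert_out * weight)
    return out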
@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config)
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
expert_params_mapping = (
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
(
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
weight_loader(
param, loaded_weight, weight_name, expert_id=expert_id
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
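The per-expert weight_loader invoked above copies each expert's w1/w3 slice for the local tensor-parallel rank into the fused w13 parameter, and the matching column slice of w2. A standalone sketch of that slicing, with illustrative sizes in place of the real config values:

import torch

# Illustrative sizes only; the real values come from the Mixtral config.
num_experts, intermediate_size, hidden_size, tp_size, tp_rank = 8, 64, 32, 2, 0
shard_size = intermediate_size // tp_size

# Fused parameter holding w1 and w3 for every expert, sharded over TP ranks.
w13_weight = torch.empty(num_experts, 2 * shard_size, hidden_size)
w2_weight = torch.empty(num_experts, hidden_size, shard_size)


def load_expert_weight(weight_name: str, expert_id: int, loaded: torch.Tensor):
    shard = slice(shard_size * tp_rank, shard_size * (tp_rank + 1))
    if weight_name.endswith("w1.weight"):
        w13_weight[expert_id, 0:shard_size, :] = loaded[shard, :]
    elif weight_name.endswith("w3.weight"):
        w13_weight[expert_id, shard_size : 2 * shard_size, :] = loaded[shard, :]
    elif weight_name.endswith("w2.weight"):
        w2_weight[expert_id, :, :] = loaded[:, shard]


load_expert_weight("experts.0.w1.weight", 0, torch.randn(intermediate_size, hidden_size))
load_expert_weight("experts.0.w2.weight", 0, torch.randn(hidden_size, intermediate_size))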

View File

@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata

View File

@@ -1,6 +1,6 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
from typing import Any, Dict, Optional, Iterable, Tuple
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn

View File

@@ -1,7 +1,7 @@
# Adapted from llama2.py
# Modify details for the adaptation of Qwen2 model.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional, Tuple, Iterable
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn

View File

@@ -2,7 +2,7 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
model compatible with HuggingFace weights."""
from typing import Optional, Tuple, Iterable
from typing import Iterable, Optional, Tuple
import torch
from torch import nn

View File

@@ -1,14 +1,14 @@
"""Inference-only Yi-VL model."""
from typing import Tuple, Iterable, Optional
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.llava import (
LlavaLlamaForCausalLM,
monkey_path_clip_vision_embed_forward,

View File

@@ -6,7 +6,7 @@ import os
from http import HTTPStatus
from fastapi import Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from sglang.srt.conversation import (
Conversation,
@@ -40,21 +40,18 @@ chat_template_name = None
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
error = ErrorResponse(message=message,
type=err_type,
code=status_code.value)
return JSONResponse(content=error.model_dump(),
status_code=error.code)
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
return JSONResponse(content=error.model_dump(), status_code=error.code)
def create_streaming_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
error = ErrorResponse(message=message,
type=err_type,
code=status_code.value)
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
json_str = json.dumps({"error": error.model_dump()})
return json_str
@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
n_prev_token = 0
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request):
adapted_request, raw_request
):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
decode_token_logprobs=content["meta_info"][
"decode_token_logprobs"
][n_prev_token:],
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
n_prev_token:
],
decode_top_logprobs=content["meta_info"][
"decode_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
n_prev_token = len(
content["meta_info"]["decode_token_logprobs"]
)
else:
logprobs = None
@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request))
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request).__anext__()
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
stream_buffer = ""
try:
async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
if is_first:
# First chunk with role
is_first = False
@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request))
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request).__anext__()
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
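Both endpoints above stream their output as server-sent events: every chunk is written as a "data: ..." line, errors are surfaced in-stream, and the stream is terminated with "data: [DONE]". A minimal sketch of that pattern with FastAPI; the route and payload are placeholders, not the adapter's real schema.

import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/stream-demo")  # placeholder route for illustration
async def stream_demo():
    async def event_generator():
        for chunk in ["Hello", ", world"]:
            # One SSE event per chunk: a JSON payload prefixed with "data: ".
            yield f"data: {json.dumps({'text': chunk}, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")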

View File

@@ -13,7 +13,7 @@ import sys
import threading
import time
from http import HTTPStatus
from typing import Optional, Dict
from typing import Dict, Optional
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -29,10 +29,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.manager_multi import (
start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller.manager_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import (
load_chat_template_for_openai_api,
@@ -97,8 +101,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(obj))
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(obj),
)
else:
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()

View File

@@ -1,8 +1,8 @@
"""Common utilities."""
import base64
import multiprocessing
import logging
import multiprocessing
import os
import random
import socket
@@ -17,12 +17,11 @@ import requests
import rpyc
import torch
import triton
from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from rpyc.utils.server import ThreadedServer
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
@@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int):
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600
"sync_request_timeout": 3600,
},
)
t.logger.setLevel(logging.WARN)
@@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"):
config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600
"sync_request_timeout": 3600,
},
)
break
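init_rpyc_service and connect_to_rpyc_service pass the same protocol options on both ends so that public attributes, pickled payloads, and hour-long synchronous calls are allowed. A minimal matching server/client sketch, with a placeholder service and port:

import rpyc
from rpyc.utils.server import ThreadedServer

RPYC_CONFIG = {
    "allow_public_attrs": True,
    "allow_pickle": True,
    "sync_request_timeout": 3600,
}


class EchoService(rpyc.Service):  # placeholder service for illustration
    def exposed_echo(self, value):
        return value


# Server side: both ends must agree on the protocol configuration.
server = ThreadedServer(EchoService, port=18861, protocol_config=RPYC_CONFIG)
# server.start()  # blocking; run it in a separate process or thread

# Client side, e.g. from another process:
# conn = rpyc.connect("localhost", 18861, config=RPYC_CONFIG)
# assert conn.root.echo("ping") == "ping"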
@@ -423,7 +422,9 @@ def suppress_other_loggers():
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.config").setLevel(logging.ERROR)
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN)
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
@@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
device_name = torch.cuda.get_device_name(gpu_id)
if "RTX 40" not in device_name:
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
@@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
)
response = await call_next(request)
return response

View File

@@ -356,16 +356,25 @@ def test_completion_speculative():
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
s += (
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
)
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
@sgl.function
def gen_character_no_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
s += (
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
)
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
token_usage = sgl.global_config.default_backend.token_usage
@@ -378,7 +387,9 @@ def test_completion_speculative():
gen_character_no_spec().sync()
usage_with_no_spec = token_usage.prompt_tokens
assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
assert (
usage_with_spec < usage_with_no_spec
), f"{usage_with_spec} vs {usage_with_no_spec}"
def test_chat_completion_speculative():
@@ -386,8 +397,17 @@ def test_chat_completion_speculative():
def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:")
s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
s += sgl.assistant(
"Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
)
s += sgl.user("Please generate new Name, Birthday and Job.\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
s += sgl.assistant(
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
+ "\nJob:"
+ sgl.gen("job", stop="\n")
)
gen_character_spec().sync()
gen_character_spec().sync()

View File

@@ -15,7 +15,6 @@ from json import dumps
import numpy as np
import requests
logger = logging.getLogger(__name__)
@@ -255,8 +254,10 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
def graceful_registry(sub_module_name):
def graceful_shutdown(signum, frame):
logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
logger.info(
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
)
if signum == signal.SIGTERM:
logger.info(f"{sub_module_name} recive sigterm")
signal.signal(signal.SIGTERM, graceful_shutdown)
signal.signal(signal.SIGTERM, graceful_shutdown)