Higher priority for user input of max_prefill_tokens & format (#540)

This commit is contained in:
Ying Sheng
2024-06-12 21:48:40 -07:00
committed by GitHub
parent 1374334d38
commit fb9296f0ed
50 changed files with 817 additions and 569 deletions

View File

@@ -158,7 +158,9 @@ async def send_request(
timeout = aiohttp.ClientTimeout(total=3 * 3600) timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session: async with aiohttp.ClientSession(timeout=timeout) as session:
while True: while True:
async with session.post(api_url, headers=headers, json=pload) as response: async with session.post(
api_url, headers=headers, json=pload
) as response:
chunks = [] chunks = []
async for chunk, _ in response.content.iter_chunks(): async for chunk, _ in response.content.iter_chunks():
chunks.append(chunk) chunks.append(chunk)
@@ -228,19 +230,32 @@ def main(args: argparse.Namespace):
np.random.seed(args.seed) np.random.seed(args.seed)
api_url = f"http://{args.host}:{args.port}/generate" api_url = f"http://{args.host}:{args.port}/generate"
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code) tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
)
if args.dataset: if args.dataset:
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
else: else:
input_lens = np.random.randint( input_lens = np.random.randint(
int(args.input_len * args.range_ratio), args.input_len + 1, size=args.num_prompts) int(args.input_len * args.range_ratio),
args.input_len + 1,
size=args.num_prompts,
)
output_lens = np.random.randint( output_lens = np.random.randint(
int(args.output_len * args.range_ratio), args.output_len + 1, size=args.num_prompts) int(args.output_len * args.range_ratio),
args.output_len + 1,
size=args.num_prompts,
)
offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts) offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
input_requests = [] input_requests = []
for i in range(args.num_prompts): for i in range(args.num_prompts):
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) prompt = tokenizer.decode(
[
(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])
]
)
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
@@ -287,16 +302,15 @@ if __name__ == "__main__":
) )
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=30000) parser.add_argument("--port", type=int, default=30000)
parser.add_argument( parser.add_argument("--dataset", type=str, help="Path to the dataset.")
"--dataset", type=str, help="Path to the dataset."
)
parser.add_argument("--input-len", type=int, default=2048) parser.add_argument("--input-len", type=int, default=2048)
parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--range-ratio", type=float, default=1.0) parser.add_argument("--range-ratio", type=float, default=1.0)
parser.add_argument( parser.add_argument(
"--tokenizer", type=str, "--tokenizer",
type=str,
default="NousResearch/Meta-Llama-3-8B", default="NousResearch/Meta-Llama-3-8B",
help="Name or path of the tokenizer." help="Name or path of the tokenizer.",
) )
parser.add_argument( parser.add_argument(
"--best-of", "--best-of",

View File

@@ -24,10 +24,10 @@ from sglang.api import (
# SGL Backends # SGL Backends
from sglang.backend.anthropic import Anthropic from sglang.backend.anthropic import Anthropic
from sglang.backend.litellm import LiteLLM
from sglang.backend.openai import OpenAI from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.backend.vertexai import VertexAI from sglang.backend.vertexai import VertexAI
from sglang.backend.litellm import LiteLLM
# Global Configurations # Global Configurations
from sglang.global_config import global_config from sglang.global_config import global_config

View File

@@ -33,7 +33,8 @@ class LiteLLM(BaseBackend):
self.model_name = model_name self.model_name = model_name
self.chat_template = chat_template or get_chat_template_by_model_path( self.chat_template = chat_template or get_chat_template_by_model_path(
model_name) model_name
)
self.client_params = { self.client_params = {
"api_key": api_key, "api_key": api_key,

View File

@@ -1,7 +1,7 @@
import dataclasses
import logging import logging
import time import time
import warnings import warnings
import dataclasses
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
def get_chat_template(self): def get_chat_template(self):
return self.chat_template return self.chat_template
def _prepare_spec_execution(self, sampling_params: SglSamplingParams, def _prepare_spec_execution(
num_api_spec_tokens: int, spec_var_name: str): self,
sampling_params: SglSamplingParams,
num_api_spec_tokens: int,
spec_var_name: str,
):
if "max_tokens" not in self.spec_kwargs: if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = num_api_spec_tokens self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else: else:
assert ( assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
self.spec_kwargs["max_tokens"] == num_api_spec_tokens
)
params = sampling_params.to_openai_kwargs() params = sampling_params.to_openai_kwargs()
for key, value in params.items(): for key, value in params.items():
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
) )
prompt = s.messages_ prompt = s.messages_
else: else:
return self._prepare_spec_execution(sampling_params, return self._prepare_spec_execution(
s.num_api_spec_tokens, spec_var_name) sampling_params, s.num_api_spec_tokens, spec_var_name
)
else: else:
prompt = s.text_ prompt = s.text_
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
return decision, scores, None, None return decision, scores, None, None
def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs): def openai_completion(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries): for attempt in range(retries):
try: try:
if is_chat: if is_chat:
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
return comp return comp
def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs): def openai_completion_stream(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries): for attempt in range(retries):
try: try:
if is_chat: if is_chat:
if "stop" in kwargs and kwargs["stop"] is None: if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop") kwargs.pop("stop")
generator = client.chat.completions.create( generator = client.chat.completions.create(
messages=prompt, stream=True, stream_options={"include_usage": True}, messages=prompt,
**kwargs stream=True,
stream_options={"include_usage": True},
**kwargs,
) )
for ret in generator: for ret in generator:
if len(ret.choices) == 0: if len(ret.choices) == 0:
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
yield content or "", {} yield content or "", {}
else: else:
generator = client.completions.create( generator = client.completions.create(
prompt=prompt, stream=True, stream_options={"include_usage": True}, prompt=prompt,
**kwargs stream=True,
stream_options={"include_usage": True},
**kwargs,
) )
for ret in generator: for ret in generator:
if len(ret.choices) == 0: if len(ret.choices) == 0:

View File

@@ -84,9 +84,7 @@ class SglSamplingParams:
def to_litellm_kwargs(self): def to_litellm_kwargs(self):
if self.regex is not None: if self.regex is not None:
warnings.warn( warnings.warn("Regular expression is not supported in the LiteLLM backend.")
"Regular expression is not supported in the LiteLLM backend."
)
return { return {
"max_tokens": self.max_new_tokens, "max_tokens": self.max_new_tokens,
"stop": self.stop or None, "stop": self.stop or None,

View File

@@ -1,4 +1,5 @@
"""Launch the inference server for Llava-video model.""" """Launch the inference server for Llava-video model."""
import argparse import argparse
import multiprocessing as mp import multiprocessing as mp

View File

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,4 +1,5 @@
"""Cache for the compressed finite state machine.""" """Cache for the compressed finite state machine."""
from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache from sglang.srt.constrained.base_cache import BaseCache

View File

@@ -8,11 +8,12 @@ from collections import defaultdict
import interegular import interegular
import outlines.caching import outlines.caching
from sglang.srt.constrained import ( from sglang.srt.constrained import (
FSMInfo, FSMInfo,
disk_cache, disk_cache,
make_deterministic_fsm,
make_byte_level_fsm, make_byte_level_fsm,
make_deterministic_fsm,
) )
from sglang.srt.constrained.base_cache import BaseCache from sglang.srt.constrained.base_cache import BaseCache

View File

@@ -1,4 +1,5 @@
"""Conversation templates.""" """Conversation templates."""
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses import dataclasses

View File

@@ -1,10 +1,10 @@
"""Utilities for Huggingface Transformers.""" """Utilities for Huggingface Transformers."""
import functools
import json import json
import os import os
import warnings import warnings
import functools from typing import AbstractSet, Collection, Literal, Optional, Union
from typing import Optional, Union, AbstractSet, Collection, Literal
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import ( from transformers import (
@@ -179,6 +179,7 @@ def get_processor(
class TiktokenTokenizer: class TiktokenTokenizer:
def __init__(self, tokenizer_path): def __init__(self, tokenizer_path):
import tiktoken import tiktoken
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
# Read JSON # Read JSON
@@ -190,7 +191,8 @@ class TiktokenTokenizer:
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"] bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
} }
special_tokens = { special_tokens = {
bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"] bytes(item["bytes"]).decode(): item["token"]
for item in tok_dict["special_tokens"]
} }
assert tok_dict["word_split"] == "V1" assert tok_dict["word_split"] == "V1"
@@ -202,7 +204,10 @@ class TiktokenTokenizer:
} }
if "default_allowed_special" in tok_dict: if "default_allowed_special" in tok_dict:
default_allowed_special = set( default_allowed_special = set(
[bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]] [
bytes(bytes_list).decode()
for bytes_list in tok_dict["default_allowed_special"]
]
) )
else: else:
default_allowed_special = None default_allowed_special = None
@@ -216,14 +221,20 @@ class TiktokenTokenizer:
self, self,
text: str, text: str,
*, *,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 allowed_special: Union[
Literal["all"], AbstractSet[str]
] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all", disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]: ) -> list[int]:
if isinstance(allowed_special, set): if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode( return tiktoken.Encoding.encode(
self, text, allowed_special=allowed_special, disallowed_special=disallowed_special self,
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
) )
tokenizer.encode = functools.partial(encode_patched, tokenizer) tokenizer.encode = functools.partial(encode_patched, tokenizer)
# Convert to HF interface # Convert to HF interface
@@ -237,10 +248,14 @@ class TiktokenTokenizer:
def decode(self, x): def decode(self, x):
return self.tokenizer.decode(x) return self.tokenizer.decode(x)
def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False): def batch_decode(
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
):
if isinstance(batch[0], int): if isinstance(batch[0], int):
batch = [[x] for x in batch] batch = [[x] for x in batch]
return self.tokenizer.decode_batch(batch) return self.tokenizer.decode_batch(batch)
def convert_ids_to_tokens(self, index): def convert_ids_to_tokens(self, index):
return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore") return self.tokenizer.decode_single_token_bytes(index).decode(
"utf-8", errors="ignore"
)

View File

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import is_hip
@@ -109,12 +108,16 @@ def fused_moe_kernel(
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + a_ptrs = a_ptr + (
offs_k[None, :] * stride_ak) offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
off_experts = tl.load(expert_ids_ptr + pid_m) off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + b_ptrs = (
offs_bn[None, :] * stride_bn) b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if use_fp8: if use_fp8:
a_scale = tl.load(a_scale_ptr) a_scale = tl.load(a_scale_ptr)
@@ -130,13 +133,12 @@ def fused_moe_kernel(
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the # Load the next block of A and B, generate a mask by checking the
# K dimension. # K dimension.
a = tl.load(a_ptrs, a = tl.load(
mask=token_mask[:, None] & a_ptrs,
(offs_k[None, :] < K - k * BLOCK_SIZE_K), mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0) other=0.0,
b = tl.load(b_ptrs, )
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
other=0.0)
# We accumulate along the K dimension. # We accumulate along the K dimension.
if use_fp8: if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator) accumulator = tl.dot(a, b, acc=accumulator)
@@ -147,9 +149,7 @@ def fused_moe_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT: if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
if use_fp8: if use_fp8:
@@ -159,15 +159,14 @@ def fused_moe_kernel(
# ----------------------------------------------------------- # -----------------------------------------------------------
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N) c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size( def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int, topk_ids: torch.Tensor, block_size: int, num_experts: int
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
size for matrix multiplication. size for matrix multiplication.
@@ -206,32 +205,38 @@ def moe_align_block_size(
by block_size for proper block matrix operations. by block_size for proper block matrix operations.
""" """
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ), sorted_ids = torch.empty(
dtype=torch.int32, (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
device=topk_ids.device) )
sorted_ids.fill_(topk_ids.numel()) sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ), expert_ids = torch.empty(
dtype=torch.int32, (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
device=topk_ids.device) )
num_tokens_post_pad = torch.empty((1), num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
dtype=torch.int32, ops.moe_align_block_size(
device=topk_ids.device) topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, )
expert_ids, num_tokens_post_pad)
return sorted_ids, expert_ids, num_tokens_post_pad return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor], A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, mul_routed_weight: bool,
config: Dict[str, Any], compute_type: tl.dtype, top_k: int,
use_fp8: bool) -> None: config: Dict[str, Any],
compute_type: tl.dtype,
use_fp8: bool,
) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
@@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A, A_scale = ops.scaled_fp8_quant(A, A_scale) A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ grid = lambda META: (
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)
fused_moe_kernel[grid]( fused_moe_kernel[grid](
A, A,
@@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
@functools.lru_cache @functools.lru_cache
def get_moe_configs(E: int, N: int, def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
@@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int,
json_file_name = get_config_file_name(E, N, dtype) json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join( config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path) as f: with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.", logger.info("Using configuration from %s for MoE layer.", config_file_path)
config_file_path)
# If a configuration has been found, return it # If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()} return {int(key): val for key, val in json.load(f).items()}
@@ -352,40 +358,30 @@ def fused_moe(
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
# Check constraints. # Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], ( assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
"Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [ assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
torch.float32, torch.float16, torch.bfloat16
]
M, _ = hidden_states.shape M, _ = hidden_states.shape
E, N, _ = w1.shape E, N, _ = w1.shape
if is_hip(): if is_hip():
# The MoE kernels are not yet supported on ROCm. # The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output, routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else: else:
import vllm._moe_C as moe_kernels import vllm._moe_C as moe_kernels
topk_weights = torch.empty(M, topk_weights = torch.empty(
topk, M, topk, dtype=torch.float32, device=hidden_states.device
dtype=torch.float32, )
device=hidden_states.device) topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
topk_ids = torch.empty(M, token_expert_indicies = torch.empty(
topk, M, topk, dtype=torch.int32, device=hidden_states.device
dtype=torch.int32, )
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax( moe_kernels.topk_softmax(
topk_weights, topk_weights,
topk_ids, topk_ids,
@@ -400,8 +396,7 @@ def fused_moe(
config = override_config config = override_config
else: else:
# First try to load optimal config from the file # First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2], configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
"float8" if use_fp8 else None)
if configs: if configs:
# If an optimal configuration map has been found, look up the # If an optimal configuration map has been found, look up the
@@ -415,7 +410,7 @@ def fused_moe(
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 4 "num_stages": 4,
} }
if M <= E: if M <= E:
@@ -425,25 +420,32 @@ def fused_moe(
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16, "GROUP_SIZE_M": 16,
"num_warps": 8, "num_warps": 8,
"num_stages": 4 "num_stages": 4,
} }
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype,
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), )
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype,
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), )
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E) topk_ids, config["BLOCK_SIZE_M"], E
compute_type = (tl.bfloat16 )
if hidden_states.dtype == torch.bfloat16 else tl.float16) compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(hidden_states, invoke_fused_moe_kernel(
hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
a1_scale, a1_scale,
@@ -457,11 +459,13 @@ def fused_moe(
topk_ids.shape[1], topk_ids.shape[1],
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8) use_fp8=use_fp8,
)
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2, invoke_fused_moe_kernel(
intermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
a2_scale, a2_scale,
@@ -475,11 +479,13 @@ def fused_moe(
1, 1,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8) use_fp8=use_fp8,
)
if inplace: if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1, dim=1,
out=hidden_states) out=hidden_states,
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), )
dim=1) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)

View File

@@ -1,4 +1,5 @@
"""Logits processing.""" """Logits processing."""
import torch import torch
from torch import nn from torch import nn
from vllm.distributed import ( from vllm.distributed import (

View File

@@ -1,6 +1,7 @@
"""Radix attention.""" """Radix attention."""
import torch
import numpy as np import numpy as np
import torch
from torch import nn from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
class RadixAttention(nn.Module): class RadixAttention(nn.Module):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1): def __init__(
self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
):
super().__init__() super().__init__()
self.tp_q_head_num = num_heads self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads self.tp_k_head_num = num_kv_heads

View File

@@ -4,7 +4,7 @@ import asyncio
import logging import logging
import queue import queue
import threading import threading
from typing import List, Callable from typing import Callable, List
import uvloop import uvloop
import zmq import zmq
@@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
# async sleep for receiving the subsequent request and avoiding cache miss # async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0: if len(out_pyobjs) != 0:
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs]) has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished: if has_finished:
await asyncio.sleep(self.request_dependency_delay) await asyncio.sleep(self.request_dependency_delay)
await asyncio.sleep(global_config.wait_for_new_request_delay) await asyncio.sleep(global_config.wait_for_new_request_delay)

View File

@@ -1,17 +1,17 @@
"""Meta data for requests and batches""" """Meta data for requests and batches"""
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import List from typing import List
import warnings
import numpy as np import numpy as np
import torch import torch
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.constrained import RegexGuide
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

View File

@@ -13,15 +13,15 @@ import zmq
import zmq.asyncio import zmq.asyncio
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
FlushCacheReq, FlushCacheReq,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback

View File

@@ -1,4 +1,5 @@
"""A controller that manages a group of tensor parallel workers.""" """A controller that manages a group of tensor parallel workers."""
import asyncio import asyncio
import logging import logging
import time import time
@@ -49,7 +50,9 @@ class ControllerSingle:
# async sleep for receiving the subsequent request and avoiding cache miss # async sleep for receiving the subsequent request and avoiding cache miss
slept = False slept = False
if len(out_pyobjs) != 0: if len(out_pyobjs) != 0:
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs]) has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished: if has_finished:
if self.request_dependency_delay > 0: if self.request_dependency_delay > 0:
slept = True slept = True

View File

@@ -1,4 +1,5 @@
"""ModelRunner runs the forward passes of the models.""" """ModelRunner runs the forward passes of the models."""
import importlib import importlib
import importlib.resources import importlib.resources
import logging import logging
@@ -12,15 +13,18 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import initialize_model_parallel, init_distributed_environment from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check from sglang.srt.utils import (
get_available_gpu_memory,
is_multimodal_model,
monkey_patch_vllm_p2p_access_check,
)
logger = logging.getLogger("srt.model_runner") logger = logging.getLogger("srt.model_runner")
@@ -441,7 +445,9 @@ def import_model_classes():
module = importlib.import_module(name) module = importlib.import_module(name)
if hasattr(module, "EntryClass"): if hasattr(module, "EntryClass"):
entry = module.EntryClass entry = module.EntryClass
if isinstance(entry, list): # To support multiple model classes in one module if isinstance(
entry, list
): # To support multiple model classes in one module
for tmp in entry: for tmp in entry:
model_arch_name_to_cls[tmp.__name__] = tmp model_arch_name_to_cls[tmp.__name__] = tmp
else: else:
@@ -449,7 +455,9 @@ def import_model_classes():
# compat: some models such as chatglm has incorrect class set in config.json # compat: some models such as chatglm has incorrect class set in config.json
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ] # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list): if hasattr(module, "EntryClassRemapping") and isinstance(
module.EntryClassRemapping, list
):
for remap in module.EntryClassRemapping: for remap in module.EntryClassRemapping:
if isinstance(remap, tuple) and len(remap) == 2: if isinstance(remap, tuple) and len(remap) == 2:
model_arch_name_to_cls[remap[0]] = remap[1] model_arch_name_to_cls[remap[0]] = remap[1]

View File

@@ -1,6 +1,7 @@
""" """
The radix tree data structure for managing the KV cache. The radix tree data structure for managing the KV cache.
""" """
import heapq import heapq
import time import time
from collections import defaultdict from collections import defaultdict

View File

@@ -1,4 +1,5 @@
"""Request scheduler heuristic.""" """Request scheduler heuristic."""
import random import random
from collections import defaultdict from collections import defaultdict

View File

@@ -15,22 +15,22 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import ( from sglang.srt.managers.controller.infer_batch import (
FINISH_ABORT,
BaseFinishReason, BaseFinishReason,
Batch, Batch,
FINISH_ABORT,
ForwardMode, ForwardMode,
Req, Req,
) )
from sglang.srt.managers.controller.model_runner import ModelRunner from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.model_config import ModelConfig from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
@@ -96,13 +96,13 @@ class ModelTpServer:
trust_remote_code=server_args.trust_remote_code, trust_remote_code=server_args.trust_remote_code,
) )
self.max_total_num_tokens = self.model_runner.max_total_num_tokens self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = max( self.max_prefill_tokens = (
max(
self.model_config.context_len, self.model_config.context_len,
( min(self.max_total_num_tokens // 6, 65536),
min(self.max_total_num_tokens // 6, 65536) )
if server_args.max_prefill_tokens is None if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens else server_args.max_prefill_tokens
),
) )
self.max_running_requests = ( self.max_running_requests = (
self.max_total_num_tokens // 2 self.max_total_num_tokens // 2

View File

@@ -1,4 +1,5 @@
"""DetokenizerManager is a process that detokenizes the token ids.""" """DetokenizerManager is a process that detokenizes the token ids."""
import asyncio import asyncio
import inspect import inspect
@@ -7,10 +8,10 @@ import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback, graceful_registry from sglang.utils import get_exception_traceback, graceful_registry
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

View File

@@ -7,8 +7,8 @@ import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.managers.controller.infer_batch import BaseFinishReason from sglang.srt.managers.controller.infer_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
@dataclass @dataclass

View File

@@ -1,11 +1,12 @@
"""TokenizerManager is a process that tokenizes the text.""" """TokenizerManager is a process that tokenizes the text."""
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import dataclasses import dataclasses
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
from typing import List, Dict from typing import Dict, List
import numpy as np import numpy as np
import transformers import transformers
@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchStrOut, BatchStrOut,
BatchTokenIDOut,
FlushCacheReq, FlushCacheReq,
GenerateReqInput, GenerateReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
@@ -322,7 +323,6 @@ class TokenizerManager:
state.finished = recv_obj.finished_reason[i] is not None state.finished = recv_obj.finished_reason[i] is not None
state.event.set() state.event.set()
def convert_logprob_style( def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
): ):

View File

@@ -1,8 +1,9 @@
from typing import Optional from typing import Optional
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
class ModelConfig: class ModelConfig:
def __init__( def __init__(
@@ -17,8 +18,12 @@ class ModelConfig:
self.trust_remote_code = trust_remote_code self.trust_remote_code = trust_remote_code
self.revision = revision self.revision = revision
self.model_overide_args = model_overide_args self.model_overide_args = model_overide_args
self.hf_config = get_config(self.path, trust_remote_code, revision, self.hf_config = get_config(
model_overide_args=model_overide_args) self.path,
trust_remote_code,
revision,
model_overide_args=model_overide_args,
)
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
if context_length is not None: if context_length is not None:
self.context_len = context_length self.context_len = context_length
@@ -56,17 +61,22 @@ class ModelConfig:
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = ( new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False)) and getattr(self.hf_config, "new_decoder_architecture", False)
if not new_decoder_arch_falcon and getattr(self.hf_text_config, )
"multi_query", False): if not new_decoder_arch_falcon and getattr(
self.hf_text_config, "multi_query", False
):
# Multi-query attention, only one KV head. # Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case. # Currently, tensor parallelism is not supported in this case.
return 1 return 1
# For DBRX and MPT # For DBRX and MPT
if self.hf_config.model_type in ["dbrx", "mpt"]: if self.hf_config.model_type in ["dbrx", "mpt"]:
return getattr(self.hf_config.attn_config, "kv_n_heads", return getattr(
self.hf_config.num_attention_heads) self.hf_config.attn_config,
"kv_n_heads",
self.hf_config.num_attention_heads,
)
attributes = [ attributes = [
# For Falcon: # For Falcon:
@@ -94,8 +104,7 @@ class ModelConfig:
# the tensor parallel size. We will replicate the KV heads in the # the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor # case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head. # parallel size so each GPU has at least one KV head.
return max(1, return max(1, total_num_kv_heads // tensor_parallel_size)
total_num_kv_heads // tensor_parallel_size)
def get_hf_text_config(config: PretrainedConfig): def get_hf_text_config(config: PretrainedConfig):

View File

@@ -5,30 +5,32 @@
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.layers.logits_processor import LogitsProcessor
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear,
from vllm.model_executor.layers.quantization.base_config import ( )
QuantizationConfig) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
LoraConfig = None LoraConfig = None
@@ -49,9 +51,11 @@ class GLMAttention(nn.Module):
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (config.multi_query_group_num self.total_num_kv_heads = (
if config.multi_query_attention else config.multi_query_group_num
config.num_attention_heads) if config.multi_query_attention
else config.num_attention_heads
)
if self.total_num_kv_heads >= tp_size: if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition # Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs. # the KV heads across multiple tensor parallel GPUs.
@@ -91,11 +95,13 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio, base=10000 * rope_ratio,
is_neox_style=False, is_neox_style=False,
) )
self.attn = RadixAttention(self.num_heads, self.attn = RadixAttention(
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id) layer_id=layer_id,
)
def forward( def forward(
self, self,
@@ -176,14 +182,16 @@ class GLMBlock(nn.Module):
): ):
super().__init__() super().__init__()
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm) config.apply_residual_connection_post_layernorm
)
self.fp32_residual_connection = config.fp32_residual_connection self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data. # Layernorm on the input data.
self.input_layernorm = layer_norm_func(config.hidden_size, self.input_layernorm = layer_norm_func(
eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon
)
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config) self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
@@ -191,7 +199,8 @@ class GLMBlock(nn.Module):
# Layernorm on the attention output # Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func( self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon
)
# MLP # MLP
self.mlp = GLMMLP(config, quant_config) self.mlp = GLMMLP(config, quant_config)
@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList([ self.layers = nn.ModuleList(
[
GLMBlock(config, i, cache_config, quant_config) GLMBlock(config, i, cache_config, quant_config)
for i in range(self.num_layers) for i in range(self.num_layers)
]) ]
)
if self.post_layer_norm: if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = layer_norm_func( self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon
)
def forward( def forward(
self, self,
@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module):
): ):
super().__init__() super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size, self.embedding = VocabParallelEmbedding(
config.hidden_size) config.padded_vocab_size, config.hidden_size
)
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config) self.encoder = GLMTransformer(config, cache_config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size, self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
config.hidden_size)
def forward( def forward(
self, self,
@@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module): class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = { packed_modules_mapping = {
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"],
} }
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [ supported_lora_modules = [
@@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config) self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head = self.transformer.output_layer self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
@@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, hidden_states = self.transformer(input_ids, positions, input_metadata)
input_metadata)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
@@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader", default_weight_loader)
default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
EntryClass = ChatGLMForCausalLM EntryClass = ChatGLMForCausalLM
# compat: glm model.config class == ChatGLMModel # compat: glm model.config class == ChatGLMModel
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)] EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]

View File

@@ -23,7 +23,7 @@
# This file is based on the LLama model definition file in transformers # This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model.""" """PyTorch Cohere model."""
from typing import Optional, Tuple, Iterable from typing import Iterable, Optional, Tuple
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention

View File

@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor

View File

@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig, CacheConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm

View File

@@ -1,7 +1,7 @@
# Adapted from # Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model.""" """Inference-only Grok1 model."""
from typing import Iterable, Optional, Tuple, List from typing import Iterable, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@@ -9,7 +9,6 @@ import torch.nn.functional as F
import tqdm import tqdm
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import ( from vllm.distributed import (
@@ -35,12 +34,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.fused_moe import fused_moe from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.managers.controller.model_runner import InputMetadata
use_fused = True use_fused = True
@@ -134,9 +132,12 @@ class Grok1MoEUnfused(nn.Module):
final_hidden_states = torch.zeros( final_hidden_states = torch.zeros(
(hidden_states.shape[0], hidden_dim), (hidden_states.shape[0], hidden_dim),
dtype=hidden_states.dtype, device=hidden_states.device dtype=hidden_states.dtype,
device=hidden_states.device,
) )
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_total_experts).permute(2, 1, 0) expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_total_experts
).permute(2, 1, 0)
for expert_idx in self.expert_indicies: for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx] expert_layer = self.experts[expert_idx]
@@ -153,7 +154,10 @@ class Grok1MoEUnfused(nn.Module):
# the current expert. We need to make sure to multiply the output hidden # the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2) # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] current_hidden_states = (
expert_layer(current_state)
* routing_weights[top_x_list, idx_list, None]
)
# However `index_add_` only support torch tensors for indexing so we'll use # However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here. # the `top_x` tensor here.
@@ -198,32 +202,46 @@ class Grok1MoE(nn.Module):
self.params_dtype = params_dtype self.params_dtype = params_dtype
# Gate always runs at half / full precision for now. # Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size, self.gate = ReplicatedLinear(
self.hidden_size,
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
quant_config=None) quant_config=None,
)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter( self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(
self.num_total_experts,
2 * self.intermediate_size, 2 * self.intermediate_size,
self.hidden_size, self.hidden_size,
dtype=params_dtype)) dtype=params_dtype,
)
)
self.w2_weight = nn.Parameter( self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(
self.num_total_experts,
self.hidden_size, self.hidden_size,
self.intermediate_size, self.intermediate_size,
dtype=params_dtype)) dtype=params_dtype,
)
)
set_weight_attrs(self.w13_weight, { set_weight_attrs(
self.w13_weight,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
set_weight_attrs(self.w2_weight, { )
set_weight_attrs(
self.w2_weight,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
# Used for fp8. # Used for fp8.
self.w13_scale = None self.w13_scale = None
@@ -233,46 +251,69 @@ class Grok1MoE(nn.Module):
if self.use_fp8: if self.use_fp8:
# WEIGHT_SCALE (for fp8) # WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, self.w13_scale = nn.Parameter(
dtype=torch.float32), torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False) requires_grad=False,
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, )
dtype=torch.float32), self.w2_scale = nn.Parameter(
requires_grad=False) torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
# If loading fp8 checkpoint, pass the weight loaders. # If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in # If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading() # process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized: if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, { set_weight_attrs(
self.w13_scale,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
set_weight_attrs(self.w2_scale, { )
set_weight_attrs(
self.w2_scale,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
# ACT_SCALE (for fp8) # ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static": if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized: if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError( raise ValueError(
"Found static activation scheme for checkpoint that " "Found static activation scheme for checkpoint that "
"was not serialized fp8.") "was not serialized fp8."
self.a13_scale = nn.Parameter(torch.zeros( )
self.num_total_experts, dtype=torch.float32), self.a13_scale = nn.Parameter(
requires_grad=False) torch.zeros(self.num_total_experts, dtype=torch.float32),
self.a2_scale = nn.Parameter(torch.zeros( requires_grad=False,
self.num_total_experts, dtype=torch.float32), )
requires_grad=False) self.a2_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(self.a13_scale, { set_weight_attrs(
self.a13_scale,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
set_weight_attrs(self.a2_scale, { )
set_weight_attrs(
self.a2_scale,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def weight_loader(
weight_name: str, expert_id: int, pre_sharded: bool): self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
expert_id: int,
pre_sharded: bool,
):
param_data = param.data param_data = param.data
shard_size = self.intermediate_size shard_size = self.intermediate_size
if pre_sharded: if pre_sharded:
@@ -284,8 +325,9 @@ class Grok1MoE(nn.Module):
if weight_name.endswith("w1.weight"): if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"): if weight_name.endswith("w3.weight"):
param_data[expert_id, param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
shard_size:2 * shard_size, :] = loaded_weight[shard, :] shard, :
]
if weight_name.endswith("w2.weight"): if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard] param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name: if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -298,17 +340,17 @@ class Grok1MoE(nn.Module):
# If checkpoint is fp16, quantize here. # If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data, w13_weight = torch.empty_like(
dtype=torch.float8_e4m3fn) self.w13_weight.data, dtype=torch.float8_e4m3fn
w2_weight = torch.empty_like(self.w2_weight.data, )
dtype=torch.float8_e4m3fn) w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts): for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[ w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
expert] = ops.scaled_fp8_quant( self.w13_weight.data[expert, :, :]
self.w13_weight.data[expert, :, :]) )
w2_weight[expert, :, :], self.w2_scale[ w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
expert] = ops.scaled_fp8_quant( self.w2_weight.data[expert, :, :]
self.w2_weight.data[expert, :, :]) )
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
@@ -319,25 +361,25 @@ class Grok1MoE(nn.Module):
if self.a13_scale is None or self.a2_scale is None: if self.a13_scale is None or self.a2_scale is None:
raise ValueError( raise ValueError(
"QuantConfig has static quantization, but found " "QuantConfig has static quantization, but found "
"activation scales are None.") "activation scales are None."
)
if (not all_close_1d(self.a13_scale) if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
or not all_close_1d(self.a2_scale)):
print_warning_once( print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. " "Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ") "Using the maximum across experts for each layer. "
)
self.a13_scale = nn.Parameter(self.a13_scale.max(), self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
requires_grad=False) self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size) hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states, final_hidden_states = fused_moe(
hidden_states,
self.w13_weight, self.w13_weight,
self.w2_weight, self.w2_weight,
router_logits, router_logits,
@@ -348,11 +390,11 @@ class Grok1MoE(nn.Module):
w1_scale=self.w13_scale, w1_scale=self.w13_scale,
w2_scale=self.w2_scale, w2_scale=self.w2_scale,
a1_scale=self.a13_scale, a1_scale=self.a13_scale,
a2_scale=self.a2_scale) a2_scale=self.a2_scale,
)
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size) return final_hidden_states.view(num_tokens, hidden_size)
@@ -462,10 +504,12 @@ class Grok1DecoderLayer(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config) quant_config=quant_config,
)
else: else:
self.block_sparse_moe = Grok1MoEUnfused( self.block_sparse_moe = Grok1MoEUnfused(
config=config, quant_config=quant_config) config=config, quant_config=quant_config
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -478,12 +522,21 @@ class Grok1DecoderLayer(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.post_attn_norm(self.self_attn( hidden_states = (
positions=positions, hidden_states=self.pre_attn_norm(hidden_states), self.post_attn_norm(
self.self_attn(
positions=positions,
hidden_states=self.pre_attn_norm(hidden_states),
input_metadata=input_metadata, input_metadata=input_metadata,
)) + hidden_states )
)
+ hidden_states
)
hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states hidden_states = (
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
+ hidden_states
)
return hidden_states return hidden_states
@@ -525,9 +578,7 @@ class Grok1Model(nn.Module):
hidden_states.mul_(self.config.embedding_multiplier_scale) hidden_states.mul_(self.config.embedding_multiplier_scale)
for i in range(len(self.layers)): for i in range(len(self.layers)):
hidden_states = self.layers[i]( hidden_states = self.layers[i](positions, hidden_states, input_metadata)
positions, hidden_states, input_metadata
)
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
hidden_states.mul_(self.config.output_multiplier_scale) hidden_states.mul_(self.config.output_multiplier_scale)
@@ -572,28 +623,41 @@ class Grok1ModelForCausalLM(nn.Module):
] ]
if use_fused: if use_fused:
expert_params_mapping = [ expert_params_mapping = (
[
# These are the weight scales for the experts # These are the weight scales for the experts
# (param_name, weight_name, expert_id) # (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", (
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
for expert_id in range(self.config.num_local_experts) f"experts.{expert_id}.{weight_name}.weight_scale",
for weight_name in ["w1", "w2", "w3"] expert_id,
] + [ )
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts) for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"] for weight_name in ["w1", "w2", "w3"]
] ]
+ [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
)
else: else:
expert_params_mapping = [] expert_params_mapping = []
@@ -605,7 +669,7 @@ class Grok1ModelForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
@@ -623,19 +687,22 @@ class Grok1ModelForCausalLM(nn.Module):
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(
param,
loaded_weight, loaded_weight,
weight_name, weight_name,
expert_id=expert_id, expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1) pre_sharded=get_tensor_model_parallel_world_size() > 1,
)
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(
default_weight_loader) param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
@@ -645,10 +712,11 @@ def all_close_1d(x: torch.Tensor) -> bool:
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
def _prepare_presharded_weights(self,
model_name_or_path: str,
revision: Optional[str], def _prepare_presharded_weights(
fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
) -> Tuple[str, List[str], bool]:
import glob import glob
import os import os

View File

@@ -1,7 +1,7 @@
# Adapted from # Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional, Tuple, Iterable from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
import tqdm import tqdm
@@ -10,7 +10,7 @@ from transformers import LlamaConfig
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr( if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None): config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = ( rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings) config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention( self.self_attn = LlamaAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,

View File

@@ -1,11 +1,17 @@
"""Inference-only LLaVa model compatible with HuggingFace weights.""" """Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig from transformers import (
CLIPVisionConfig,
CLIPVisionModel,
LlavaConfig,
MistralConfig,
Qwen2Config,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
class LlavaLlamaForCausalLM(nn.Module): class LlavaLlamaForCausalLM(nn.Module):
@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
first_call = True first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0] batch_size = pixel_values.shape[0]
@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
) )
EntryClass = [ EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
LlavaLlamaForCausalLM,
LlavaQwenForCausalLM,
LlavaMistralForCausalLM
]

View File

@@ -1,6 +1,6 @@
"""Inference-only LLaVa video model compatible with HuggingFace weights.""" """Inference-only LLaVa video model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch

View File

@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.managers.controller.model_runner import InputMetadata
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert """A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks. across all ranks.
@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
self.params_dtype = params_dtype self.params_dtype = params_dtype
# Gate always runs at half / full precision for now. # Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size, self.gate = ReplicatedLinear(
self.hidden_size,
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
quant_config=None) quant_config=None,
)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter( self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(
self.num_total_experts,
2 * self.intermediate_size, 2 * self.intermediate_size,
self.hidden_size, self.hidden_size,
dtype=params_dtype)) dtype=params_dtype,
)
)
self.w2_weight = nn.Parameter( self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(
self.num_total_experts,
self.hidden_size, self.hidden_size,
self.intermediate_size, self.intermediate_size,
dtype=params_dtype)) dtype=params_dtype,
)
)
set_weight_attrs(self.w13_weight, { set_weight_attrs(
self.w13_weight,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
set_weight_attrs(self.w2_weight, { )
set_weight_attrs(
self.w2_weight,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
# Used for fp8. # Used for fp8.
self.w13_scale = None self.w13_scale = None
@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
if self.use_fp8: if self.use_fp8:
# WEIGHT_SCALE (for fp8) # WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, self.w13_scale = nn.Parameter(
dtype=torch.float32), torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False) requires_grad=False,
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, )
dtype=torch.float32), self.w2_scale = nn.Parameter(
requires_grad=False) torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
# If loading fp8 checkpoint, pass the weight loaders. # If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in # If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading() # process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized: if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, { set_weight_attrs(
self.w13_scale,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
set_weight_attrs(self.w2_scale, { )
set_weight_attrs(
self.w2_scale,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
# ACT_SCALE (for fp8) # ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static": if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized: if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError( raise ValueError(
"Found static activation scheme for checkpoint that " "Found static activation scheme for checkpoint that "
"was not serialized fp8.") "was not serialized fp8."
self.a13_scale = nn.Parameter(torch.zeros( )
self.num_total_experts, dtype=torch.float32), self.a13_scale = nn.Parameter(
requires_grad=False) torch.zeros(self.num_total_experts, dtype=torch.float32),
self.a2_scale = nn.Parameter(torch.zeros( requires_grad=False,
self.num_total_experts, dtype=torch.float32), )
requires_grad=False) self.a2_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(self.a13_scale, { set_weight_attrs(
self.a13_scale,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
set_weight_attrs(self.a2_scale, { )
set_weight_attrs(
self.a2_scale,
{
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def weight_loader(
weight_name: str, expert_id: int): self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
expert_id: int,
):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
param_data = param.data param_data = param.data
shard_size = self.intermediate_size shard_size = self.intermediate_size
@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
if weight_name.endswith("w1.weight"): if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"): if weight_name.endswith("w3.weight"):
param_data[expert_id, param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
shard_size:2 * shard_size, :] = loaded_weight[shard, :] shard, :
]
if weight_name.endswith("w2.weight"): if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard] param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name: if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
# If checkpoint is fp16, quantize here. # If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data, w13_weight = torch.empty_like(
dtype=torch.float8_e4m3fn) self.w13_weight.data, dtype=torch.float8_e4m3fn
w2_weight = torch.empty_like(self.w2_weight.data, )
dtype=torch.float8_e4m3fn) w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts): for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[ w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
expert] = ops.scaled_fp8_quant( self.w13_weight.data[expert, :, :]
self.w13_weight.data[expert, :, :]) )
w2_weight[expert, :, :], self.w2_scale[ w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
expert] = ops.scaled_fp8_quant( self.w2_weight.data[expert, :, :]
self.w2_weight.data[expert, :, :]) )
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
@@ -193,25 +228,25 @@ class MixtralMoE(nn.Module):
if self.a13_scale is None or self.a2_scale is None: if self.a13_scale is None or self.a2_scale is None:
raise ValueError( raise ValueError(
"QuantConfig has static quantization, but found " "QuantConfig has static quantization, but found "
"activation scales are None.") "activation scales are None."
)
if (not all_close_1d(self.a13_scale) if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
or not all_close_1d(self.a2_scale)):
print_warning_once( print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. " "Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ") "Using the maximum across experts for each layer. "
)
self.a13_scale = nn.Parameter(self.a13_scale.max(), self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
requires_grad=False) self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size) hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states, final_hidden_states = fused_moe(
hidden_states,
self.w13_weight, self.w13_weight,
self.w2_weight, self.w2_weight,
router_logits, router_logits,
@@ -222,11 +257,11 @@ class MixtralMoE(nn.Module):
w1_scale=self.w13_scale, w1_scale=self.w13_scale,
w2_scale=self.w2_scale, w2_scale=self.w2_scale,
a1_scale=self.a13_scale, a1_scale=self.a13_scale,
a2_scale=self.a2_scale) a2_scale=self.a2_scale,
)
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size) return final_hidden_states.view(num_tokens, hidden_size)
@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config) quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
expert_params_mapping = [ expert_params_mapping = (
[
# These are the weight scales for the experts # These are the weight scales for the experts
# (param_name, weight_name, expert_id) # (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", (
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
for expert_id in range(self.config.num_local_experts) f"experts.{expert_id}.{weight_name}.weight_scale",
for weight_name in ["w1", "w2", "w3"] expert_id,
] + [ )
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts) for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"] for weight_name in ["w1", "w2", "w3"]
] ]
+ [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(
loaded_weight, param, loaded_weight, weight_name, expert_id=expert_id
weight_name, )
expert_id=expert_id)
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(
default_weight_loader) param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)

View File

@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.managers.controller.model_runner import InputMetadata

View File

@@ -1,6 +1,6 @@
# Adapted from # Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
from typing import Any, Dict, Optional, Iterable, Tuple from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn

View File

@@ -1,7 +1,7 @@
# Adapted from llama2.py # Adapted from llama2.py
# Modify details for the adaptation of Qwen2 model. # Modify details for the adaptation of Qwen2 model.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" """Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional, Tuple, Iterable from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn

View File

@@ -2,7 +2,7 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b) """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
model compatible with HuggingFace weights.""" model compatible with HuggingFace weights."""
from typing import Optional, Tuple, Iterable from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn

View File

@@ -1,14 +1,14 @@
"""Inference-only Yi-VL model.""" """Inference-only Yi-VL model."""
from typing import Tuple, Iterable, Optional from typing import Iterable, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig from transformers import CLIPVisionModel, LlavaConfig
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.llava import ( from sglang.srt.models.llava import (
LlavaLlamaForCausalLM, LlavaLlamaForCausalLM,
monkey_path_clip_vision_embed_forward, monkey_path_clip_vision_embed_forward,

View File

@@ -6,7 +6,7 @@ import os
from http import HTTPStatus from http import HTTPStatus
from fastapi import Request from fastapi import Request
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import JSONResponse, StreamingResponse
from sglang.srt.conversation import ( from sglang.srt.conversation import (
Conversation, Conversation,
@@ -40,21 +40,18 @@ chat_template_name = None
def create_error_response( def create_error_response(
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST): status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
error = ErrorResponse(message=message, ):
type=err_type, error = ErrorResponse(message=message, type=err_type, code=status_code.value)
code=status_code.value) return JSONResponse(content=error.model_dump(), status_code=error.code)
return JSONResponse(content=error.model_dump(),
status_code=error.code)
def create_streaming_error_response( def create_streaming_error_response(
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
error = ErrorResponse(message=message, ) -> str:
type=err_type, error = ErrorResponse(message=message, type=err_type, code=status_code.value)
code=status_code.value)
json_str = json.dumps({"error": error.model_dump()}) json_str = json.dumps({"error": error.model_dump()})
return json_str return json_str
@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
n_prev_token = 0 n_prev_token = 0
try: try:
async for content in tokenizer_manager.generate_request( async for content in tokenizer_manager.generate_request(
adapted_request, raw_request): adapted_request, raw_request
):
text = content["text"] text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"] prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"] completion_tokens = content["meta_info"]["completion_tokens"]
@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
decode_token_logprobs=content["meta_info"][ decode_token_logprobs=content["meta_info"][
"decode_token_logprobs" "decode_token_logprobs"
][n_prev_token:], ][n_prev_token:],
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][ decode_top_logprobs=content["meta_info"][
n_prev_token: "decode_top_logprobs"
], ][n_prev_token:],
) )
n_prev_token = len(content["meta_info"]["decode_token_logprobs"]) n_prev_token = len(
content["meta_info"]["decode_token_logprobs"]
)
else: else:
logprobs = None logprobs = None
@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
yield f"data: {error}\n\n" yield f"data: {error}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream", return StreamingResponse(
background=tokenizer_manager.create_abort_task(adapted_request)) generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response. # Non-streaming response.
try: try:
ret = await tokenizer_manager.generate_request( ret = await tokenizer_manager.generate_request(
adapted_request, raw_request).__anext__() adapted_request, raw_request
).__anext__()
except ValueError as e: except ValueError as e:
return create_error_response(str(e)) return create_error_response(str(e))
@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
stream_buffer = "" stream_buffer = ""
try: try:
async for content in tokenizer_manager.generate_request(adapted_request, raw_request): async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
if is_first: if is_first:
# First chunk with role # First chunk with role
is_first = False is_first = False
@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
yield f"data: {error}\n\n" yield f"data: {error}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream", return StreamingResponse(
background=tokenizer_manager.create_abort_task(adapted_request)) generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response. # Non-streaming response.
try: try:
ret = await tokenizer_manager.generate_request( ret = await tokenizer_manager.generate_request(
adapted_request, raw_request).__anext__() adapted_request, raw_request
).__anext__()
except ValueError as e: except ValueError as e:
return create_error_response(str(e)) return create_error_response(str(e))

View File

@@ -13,7 +13,7 @@ import sys
import threading import threading
import time import time
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Dict from typing import Dict, Optional
# Fix a bug of Python threading # Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None) setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -29,10 +29,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.manager_multi import (
start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller.manager_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import ( from sglang.srt.openai_api_adapter import (
load_chat_template_for_openai_api, load_chat_template_for_openai_api,
@@ -97,8 +101,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream", return StreamingResponse(
background=tokenizer_manager.create_abort_task(obj)) stream_results(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(obj),
)
else: else:
try: try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__() ret = await tokenizer_manager.generate_request(obj, request).__anext__()

View File

@@ -1,8 +1,8 @@
"""Common utilities.""" """Common utilities."""
import base64 import base64
import multiprocessing
import logging import logging
import multiprocessing
import os import os
import random import random
import socket import socket
@@ -17,12 +17,11 @@ import requests
import rpyc import rpyc
import torch import torch
import triton import triton
from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from rpyc.utils.server import ThreadedServer
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int):
protocol_config={ protocol_config={
"allow_public_attrs": True, "allow_public_attrs": True,
"allow_pickle": True, "allow_pickle": True,
"sync_request_timeout": 3600 "sync_request_timeout": 3600,
}, },
) )
t.logger.setLevel(logging.WARN) t.logger.setLevel(logging.WARN)
@@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"):
config={ config={
"allow_public_attrs": True, "allow_public_attrs": True,
"allow_pickle": True, "allow_pickle": True,
"sync_request_timeout": 3600 "sync_request_timeout": 3600,
}, },
) )
break break
@@ -423,7 +422,9 @@ def suppress_other_loggers():
vllm_default_logger.setLevel(logging.WARN) vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.config").setLevel(logging.ERROR) logging.getLogger("vllm.config").setLevel(logging.ERROR)
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN) logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN) logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN)
@@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
device_name = torch.cuda.get_device_name(gpu_id) device_name = torch.cuda.get_device_name(gpu_id)
if "RTX 40" not in device_name: if "RTX 40" not in device_name:
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
@@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
) )
response = await call_next(request) response = await call_next(request)
return response return response

View File

@@ -356,16 +356,25 @@ def test_completion_speculative():
s += "Construct a character within the following format:\n" s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
s += "\nPlease generate new Name, Birthday and Job.\n" s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") s += (
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
)
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n" s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
@sgl.function @sgl.function
def gen_character_no_spec(s): def gen_character_no_spec(s):
s += "Construct a character within the following format:\n" s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
s += "\nPlease generate new Name, Birthday and Job.\n" s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") s += (
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
)
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n" s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
token_usage = sgl.global_config.default_backend.token_usage token_usage = sgl.global_config.default_backend.token_usage
@@ -378,7 +387,9 @@ def test_completion_speculative():
gen_character_no_spec().sync() gen_character_no_spec().sync()
usage_with_no_spec = token_usage.prompt_tokens usage_with_no_spec = token_usage.prompt_tokens
assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}" assert (
usage_with_spec < usage_with_no_spec
), f"{usage_with_spec} vs {usage_with_no_spec}"
def test_chat_completion_speculative(): def test_chat_completion_speculative():
@@ -386,8 +397,17 @@ def test_chat_completion_speculative():
def gen_character_spec(s): def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.") s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:") s += sgl.user("Construct a character within the following format:")
s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n") s += sgl.assistant(
"Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
)
s += sgl.user("Please generate new Name, Birthday and Job.\n") s += sgl.user("Please generate new Name, Birthday and Job.\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) s += sgl.assistant(
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
+ "\nJob:"
+ sgl.gen("job", stop="\n")
)
gen_character_spec().sync() gen_character_spec().sync()

View File

@@ -15,7 +15,6 @@ from json import dumps
import numpy as np import numpy as np
import requests import requests
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -255,7 +254,9 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
def graceful_registry(sub_module_name): def graceful_registry(sub_module_name):
def graceful_shutdown(signum, frame): def graceful_shutdown(signum, frame):
logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...") logger.info(
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
)
if signum == signal.SIGTERM: if signum == signal.SIGTERM:
logger.info(f"{sub_module_name} recive sigterm") logger.info(f"{sub_module_name} recive sigterm")

View File

@@ -2,6 +2,8 @@ import unittest
from sglang import OpenAI, set_default_backend from sglang import OpenAI, set_default_backend
from sglang.test.test_programs import ( from sglang.test.test_programs import (
test_chat_completion_speculative,
test_completion_speculative,
test_decode_int, test_decode_int,
test_decode_json, test_decode_json,
test_expert_answer, test_expert_answer,
@@ -14,8 +16,6 @@ from sglang.test.test_programs import (
test_select, test_select,
test_stream, test_stream,
test_tool_use, test_tool_use,
test_completion_speculative,
test_chat_completion_speculative
) )