Higher priority for user input of max_prefill_tokens & format (#540)
This commit is contained in:
@@ -24,10 +24,10 @@ from sglang.api import (
|
||||
|
||||
# SGL Backends
|
||||
from sglang.backend.anthropic import Anthropic
|
||||
from sglang.backend.litellm import LiteLLM
|
||||
from sglang.backend.openai import OpenAI
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.backend.vertexai import VertexAI
|
||||
from sglang.backend.litellm import LiteLLM
|
||||
|
||||
# Global Configurations
|
||||
from sglang.global_config import global_config
|
||||
|
||||
@@ -33,7 +33,8 @@ class LiteLLM(BaseBackend):
|
||||
self.model_name = model_name
|
||||
|
||||
self.chat_template = chat_template or get_chat_template_by_model_path(
|
||||
model_name)
|
||||
model_name
|
||||
)
|
||||
|
||||
self.client_params = {
|
||||
"api_key": api_key,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
import dataclasses
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
|
||||
def get_chat_template(self):
|
||||
return self.chat_template
|
||||
|
||||
def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
|
||||
num_api_spec_tokens: int, spec_var_name: str):
|
||||
def _prepare_spec_execution(
|
||||
self,
|
||||
sampling_params: SglSamplingParams,
|
||||
num_api_spec_tokens: int,
|
||||
spec_var_name: str,
|
||||
):
|
||||
if "max_tokens" not in self.spec_kwargs:
|
||||
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
|
||||
else:
|
||||
assert (
|
||||
self.spec_kwargs["max_tokens"] == num_api_spec_tokens
|
||||
)
|
||||
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
|
||||
|
||||
params = sampling_params.to_openai_kwargs()
|
||||
for key, value in params.items():
|
||||
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
|
||||
)
|
||||
prompt = s.messages_
|
||||
else:
|
||||
return self._prepare_spec_execution(sampling_params,
|
||||
s.num_api_spec_tokens, spec_var_name)
|
||||
return self._prepare_spec_execution(
|
||||
sampling_params, s.num_api_spec_tokens, spec_var_name
|
||||
)
|
||||
else:
|
||||
prompt = s.text_
|
||||
|
||||
@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
|
||||
ret_str = ret.choices[0].text
|
||||
ret_token = self.tokenizer.encode(ret_str)[0]
|
||||
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
|
||||
self.token_usage.completion_tokens= ret.usage.completion_tokens
|
||||
self.token_usage.completion_tokens = ret.usage.completion_tokens
|
||||
|
||||
# TODO:
|
||||
# 1. return logits as the scores
|
||||
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
|
||||
return decision, scores, None, None
|
||||
|
||||
|
||||
def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
|
||||
def openai_completion(
|
||||
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||
):
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
if is_chat:
|
||||
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
|
||||
return comp
|
||||
|
||||
|
||||
def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
|
||||
def openai_completion_stream(
|
||||
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||
):
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
if is_chat:
|
||||
if "stop" in kwargs and kwargs["stop"] is None:
|
||||
kwargs.pop("stop")
|
||||
generator = client.chat.completions.create(
|
||||
messages=prompt, stream=True, stream_options={"include_usage": True},
|
||||
**kwargs
|
||||
messages=prompt,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
**kwargs,
|
||||
)
|
||||
for ret in generator:
|
||||
if len(ret.choices) == 0:
|
||||
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
|
||||
yield content or "", {}
|
||||
else:
|
||||
generator = client.completions.create(
|
||||
prompt=prompt, stream=True, stream_options={"include_usage": True},
|
||||
**kwargs
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
**kwargs,
|
||||
)
|
||||
for ret in generator:
|
||||
if len(ret.choices) == 0:
|
||||
|
||||
@@ -507,7 +507,7 @@ class StreamExecutor:
|
||||
)
|
||||
return
|
||||
|
||||
else: # Speculative execution on models with completion interface
|
||||
else: # Speculative execution on models with completion interface
|
||||
comp, meta_info = self._spec_gen(sampling_params)
|
||||
|
||||
self.text_ += comp
|
||||
|
||||
@@ -81,12 +81,10 @@ class SglSamplingParams:
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
|
||||
|
||||
def to_litellm_kwargs(self):
|
||||
if self.regex is not None:
|
||||
warnings.warn(
|
||||
"Regular expression is not supported in the LiteLLM backend."
|
||||
)
|
||||
warnings.warn("Regular expression is not supported in the LiteLLM backend.")
|
||||
return {
|
||||
"max_tokens": self.max_new_tokens,
|
||||
"stop": self.stop or None,
|
||||
|
||||
@@ -10,4 +10,4 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
|
||||
launch_server(server_args, None)
|
||||
launch_server(server_args, None)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Launch the inference server for Llava-video model."""
|
||||
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Union
|
||||
from outlines.caching import cache as disk_cache
|
||||
from outlines.caching import disable_cache
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
|
||||
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
||||
from outlines.models.transformers import TransformerTokenizer
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Cache for the compressed finite state machine."""
|
||||
|
||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
|
||||
|
||||
@@ -8,11 +8,12 @@ from collections import defaultdict
|
||||
|
||||
import interegular
|
||||
import outlines.caching
|
||||
|
||||
from sglang.srt.constrained import (
|
||||
FSMInfo,
|
||||
disk_cache,
|
||||
make_deterministic_fsm,
|
||||
make_byte_level_fsm,
|
||||
make_deterministic_fsm,
|
||||
)
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Conversation templates."""
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
import dataclasses
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Utilities for Huggingface Transformers."""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
import functools
|
||||
from typing import Optional, Union, AbstractSet, Collection, Literal
|
||||
from typing import AbstractSet, Collection, Literal, Optional, Union
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import (
|
||||
@@ -179,6 +179,7 @@ def get_processor(
|
||||
class TiktokenTokenizer:
|
||||
def __init__(self, tokenizer_path):
|
||||
import tiktoken
|
||||
|
||||
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||
|
||||
# Read JSON
|
||||
@@ -190,7 +191,8 @@ class TiktokenTokenizer:
|
||||
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
|
||||
}
|
||||
special_tokens = {
|
||||
bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
|
||||
bytes(item["bytes"]).decode(): item["token"]
|
||||
for item in tok_dict["special_tokens"]
|
||||
}
|
||||
assert tok_dict["word_split"] == "V1"
|
||||
|
||||
@@ -202,7 +204,10 @@ class TiktokenTokenizer:
|
||||
}
|
||||
if "default_allowed_special" in tok_dict:
|
||||
default_allowed_special = set(
|
||||
[bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
|
||||
[
|
||||
bytes(bytes_list).decode()
|
||||
for bytes_list in tok_dict["default_allowed_special"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
default_allowed_special = None
|
||||
@@ -216,14 +221,20 @@ class TiktokenTokenizer:
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
|
||||
allowed_special: Union[
|
||||
Literal["all"], AbstractSet[str]
|
||||
] = set(), # noqa: B006
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
) -> list[int]:
|
||||
if isinstance(allowed_special, set):
|
||||
allowed_special |= self._default_allowed_special
|
||||
return tiktoken.Encoding.encode(
|
||||
self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
|
||||
self,
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
|
||||
tokenizer.encode = functools.partial(encode_patched, tokenizer)
|
||||
|
||||
# Convert to HF interface
|
||||
@@ -237,10 +248,14 @@ class TiktokenTokenizer:
|
||||
def decode(self, x):
|
||||
return self.tokenizer.decode(x)
|
||||
|
||||
def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
|
||||
def batch_decode(
|
||||
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
|
||||
):
|
||||
if isinstance(batch[0], int):
|
||||
batch = [[x] for x in batch]
|
||||
return self.tokenizer.decode_batch(batch)
|
||||
|
||||
def convert_ids_to_tokens(self, index):
|
||||
return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
|
||||
return self.tokenizer.decode_single_token_bytes(index).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
@@ -109,12 +108,16 @@ def fused_moe_kernel(
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
|
||||
offs_k[None, :] * stride_ak)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
||||
)
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m)
|
||||
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
|
||||
offs_bn[None, :] * stride_bn)
|
||||
b_ptrs = (
|
||||
b_ptr
|
||||
+ off_experts * stride_be
|
||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
)
|
||||
|
||||
if use_fp8:
|
||||
a_scale = tl.load(a_scale_ptr)
|
||||
@@ -130,13 +133,12 @@ def fused_moe_kernel(
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the
|
||||
# K dimension.
|
||||
a = tl.load(a_ptrs,
|
||||
mask=token_mask[:, None] &
|
||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
if use_fp8:
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
@@ -147,9 +149,7 @@ def fused_moe_kernel(
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
||||
mask=token_mask,
|
||||
other=0)
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
|
||||
if use_fp8:
|
||||
@@ -159,15 +159,14 @@ def fused_moe_kernel(
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
|
||||
None, :]
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
topk_ids: torch.Tensor, block_size: int,
|
||||
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns the token distribution across experts to be compatible with block
|
||||
size for matrix multiplication.
|
||||
@@ -206,32 +205,38 @@ def moe_align_block_size(
|
||||
by block_size for proper block matrix operations.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
sorted_ids = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||
expert_ids = torch.empty((max_num_m_blocks, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
num_tokens_post_pad = torch.empty((1),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
|
||||
expert_ids, num_tokens_post_pad)
|
||||
expert_ids = torch.empty(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
ops.moe_align_block_size(
|
||||
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
|
||||
)
|
||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool, top_k: int,
|
||||
config: Dict[str, Any], compute_type: tl.dtype,
|
||||
use_fp8: bool) -> None:
|
||||
def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: Dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_fp8: bool,
|
||||
) -> None:
|
||||
assert topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
@@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||
assert B_scale is not None
|
||||
|
||||
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
||||
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
|
||||
grid = lambda META: (
|
||||
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
@@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_moe_configs(E: int, N: int,
|
||||
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
||||
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the fused MoE kernel.
|
||||
|
||||
@@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int,
|
||||
json_file_name = get_config_file_name(E, N, dtype)
|
||||
|
||||
config_file_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
|
||||
)
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path) as f:
|
||||
logger.info("Using configuration from %s for MoE layer.",
|
||||
config_file_path)
|
||||
logger.info("Using configuration from %s for MoE layer.", config_file_path)
|
||||
# If a configuration has been found, return it
|
||||
return {int(key): val for key, val in json.load(f).items()}
|
||||
|
||||
@@ -352,40 +358,30 @@ def fused_moe(
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
||||
M, _ = hidden_states.shape
|
||||
E, N, _ = w1.shape
|
||||
|
||||
if is_hip():
|
||||
# The MoE kernels are not yet supported on ROCm.
|
||||
routing_weights = torch.softmax(gating_output,
|
||||
dim=-1,
|
||||
dtype=torch.float32)
|
||||
routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
|
||||
else:
|
||||
import vllm._moe_C as moe_kernels
|
||||
|
||||
topk_weights = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
token_expert_indicies = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
topk_weights = torch.empty(
|
||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||
token_expert_indicies = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
moe_kernels.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
@@ -400,8 +396,7 @@ def fused_moe(
|
||||
config = override_config
|
||||
else:
|
||||
# First try to load optimal config from the file
|
||||
configs = get_moe_configs(E, w2.shape[2],
|
||||
"float8" if use_fp8 else None)
|
||||
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
|
||||
|
||||
if configs:
|
||||
# If an optimal configuration map has been found, look up the
|
||||
@@ -415,7 +410,7 @@ def fused_moe(
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
"num_stages": 4,
|
||||
}
|
||||
|
||||
if M <= E:
|
||||
@@ -425,61 +420,72 @@ def fused_moe(
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
"num_stages": 4,
|
||||
}
|
||||
|
||||
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
intermediate_cache1 = torch.empty(
|
||||
(M, topk_ids.shape[1], N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * topk_ids.shape[1], N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache3 = torch.empty(
|
||||
(M, topk_ids.shape[1], w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
topk_ids, config['BLOCK_SIZE_M'], E)
|
||||
compute_type = (tl.bfloat16
|
||||
if hidden_states.dtype == torch.bfloat16 else tl.float16)
|
||||
topk_ids, config["BLOCK_SIZE_M"], E
|
||||
)
|
||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||
|
||||
invoke_fused_moe_kernel(hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
w1_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
topk_ids.shape[1],
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8=use_fp8)
|
||||
invoke_fused_moe_kernel(
|
||||
hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
w1_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
topk_ids.shape[1],
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8=use_fp8,
|
||||
)
|
||||
|
||||
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||
|
||||
invoke_fused_moe_kernel(intermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
True,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8=use_fp8)
|
||||
invoke_fused_moe_kernel(
|
||||
intermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
True,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8=use_fp8,
|
||||
)
|
||||
|
||||
if inplace:
|
||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=hidden_states)
|
||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1)
|
||||
return torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=hidden_states,
|
||||
)
|
||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Logits processing."""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.distributed import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Radix attention."""
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||
@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
|
||||
|
||||
|
||||
class RadixAttention(nn.Module):
|
||||
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
|
||||
def __init__(
|
||||
self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_q_head_num = num_heads
|
||||
self.tp_k_head_num = num_kv_heads
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import List, Callable
|
||||
from typing import Callable, List
|
||||
|
||||
import uvloop
|
||||
import zmq
|
||||
@@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
|
||||
|
||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
||||
if len(out_pyobjs) != 0:
|
||||
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
|
||||
has_finished = any(
|
||||
[obj.finished_reason is not None for obj in out_pyobjs]
|
||||
)
|
||||
if has_finished:
|
||||
await asyncio.sleep(self.request_dependency_delay)
|
||||
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
||||
@@ -108,4 +110,4 @@ def start_data_parallel_worker(
|
||||
step_func=model_tp_client.step,
|
||||
)
|
||||
worker_thread.start()
|
||||
return worker_thread
|
||||
return worker_thread
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
"""Meta data for requests and batches"""
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import List
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
|
||||
@@ -13,15 +13,15 @@ import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.managers.controller.dp_worker import (
|
||||
DataParallelWorkerThread,
|
||||
start_data_parallel_worker,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
FlushCacheReq,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.controller.dp_worker import (
|
||||
DataParallelWorkerThread,
|
||||
start_data_parallel_worker,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
@@ -136,7 +136,7 @@ class Controller:
|
||||
self.recv_reqs = []
|
||||
if next_step_input:
|
||||
await self.dispatching(next_step_input)
|
||||
#else:
|
||||
# else:
|
||||
# logger.error("There is no live worker.")
|
||||
|
||||
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""A controller that manages a group of tensor parallel workers."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
@@ -49,7 +50,9 @@ class ControllerSingle:
|
||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
||||
slept = False
|
||||
if len(out_pyobjs) != 0:
|
||||
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
|
||||
has_finished = any(
|
||||
[obj.finished_reason is not None for obj in out_pyobjs]
|
||||
)
|
||||
if has_finished:
|
||||
if self.request_dependency_delay > 0:
|
||||
slept = True
|
||||
@@ -94,4 +97,4 @@ def start_controller_process(
|
||||
except Exception:
|
||||
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
|
||||
finally:
|
||||
kill_parent_process()
|
||||
kill_parent_process()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
import importlib
|
||||
import importlib.resources
|
||||
import logging
|
||||
@@ -12,15 +13,18 @@ import torch
|
||||
import torch.nn as nn
|
||||
from vllm.config import DeviceConfig, LoadConfig
|
||||
from vllm.config import ModelConfig as VllmModelConfig
|
||||
from vllm.distributed import initialize_model_parallel, init_distributed_environment
|
||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
|
||||
|
||||
from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
is_multimodal_model,
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("srt.model_runner")
|
||||
|
||||
@@ -441,7 +445,9 @@ def import_model_classes():
|
||||
module = importlib.import_module(name)
|
||||
if hasattr(module, "EntryClass"):
|
||||
entry = module.EntryClass
|
||||
if isinstance(entry, list): # To support multiple model classes in one module
|
||||
if isinstance(
|
||||
entry, list
|
||||
): # To support multiple model classes in one module
|
||||
for tmp in entry:
|
||||
model_arch_name_to_cls[tmp.__name__] = tmp
|
||||
else:
|
||||
@@ -449,7 +455,9 @@ def import_model_classes():
|
||||
|
||||
# compat: some models such as chatglm has incorrect class set in config.json
|
||||
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
|
||||
if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list):
|
||||
if hasattr(module, "EntryClassRemapping") and isinstance(
|
||||
module.EntryClassRemapping, list
|
||||
):
|
||||
for remap in module.EntryClassRemapping:
|
||||
if isinstance(remap, tuple) and len(remap) == 2:
|
||||
model_arch_name_to_cls[remap[0]] = remap[1]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
The radix tree data structure for managing the KV cache.
|
||||
"""
|
||||
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Request scheduler heuristic."""
|
||||
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
@@ -15,22 +15,22 @@ from sglang.global_config import global_config
|
||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchTokenIDOut,
|
||||
FlushCacheReq,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.controller.infer_batch import (
|
||||
FINISH_ABORT,
|
||||
BaseFinishReason,
|
||||
Batch,
|
||||
FINISH_ABORT,
|
||||
ForwardMode,
|
||||
Req,
|
||||
)
|
||||
from sglang.srt.managers.controller.model_runner import ModelRunner
|
||||
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchTokenIDOut,
|
||||
FlushCacheReq,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.server_args import ModelPortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
@@ -96,13 +96,13 @@ class ModelTpServer:
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
)
|
||||
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
||||
self.max_prefill_tokens = max(
|
||||
self.model_config.context_len,
|
||||
(
|
||||
min(self.max_total_num_tokens // 6, 65536)
|
||||
if server_args.max_prefill_tokens is None
|
||||
else server_args.max_prefill_tokens
|
||||
),
|
||||
self.max_prefill_tokens = (
|
||||
max(
|
||||
self.model_config.context_len,
|
||||
min(self.max_total_num_tokens // 6, 65536),
|
||||
)
|
||||
if server_args.max_prefill_tokens is None
|
||||
else server_args.max_prefill_tokens
|
||||
)
|
||||
self.max_running_requests = (
|
||||
self.max_total_num_tokens // 2
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
@@ -7,10 +8,10 @@ import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
|
||||
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.utils import get_exception_traceback, graceful_registry
|
||||
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""TokenizerManager is a process that tokenizes the text."""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import transformers
|
||||
@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import BatchTokenIDOut
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
@@ -91,7 +92,7 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
self.to_create_loop = True
|
||||
self.rid_to_state: Dict[str, ReqState] = {}
|
||||
self.rid_to_state: Dict[str, ReqState] = {}
|
||||
|
||||
async def get_pixel_values(self, image_data):
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
@@ -322,7 +323,6 @@ class TokenizerManager:
|
||||
state.finished = recv_obj.finished_reason[i] is not None
|
||||
state.event.set()
|
||||
|
||||
|
||||
def convert_logprob_style(
|
||||
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
||||
):
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(
|
||||
@@ -17,8 +18,12 @@ class ModelConfig:
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.revision = revision
|
||||
self.model_overide_args = model_overide_args
|
||||
self.hf_config = get_config(self.path, trust_remote_code, revision,
|
||||
model_overide_args=model_overide_args)
|
||||
self.hf_config = get_config(
|
||||
self.path,
|
||||
trust_remote_code,
|
||||
revision,
|
||||
model_overide_args=model_overide_args,
|
||||
)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
if context_length is not None:
|
||||
self.context_len = context_length
|
||||
@@ -55,18 +60,23 @@ class ModelConfig:
|
||||
# KV heads.
|
||||
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
||||
new_decoder_arch_falcon = (
|
||||
self.hf_config.model_type in falcon_model_types
|
||||
and getattr(self.hf_config, "new_decoder_architecture", False))
|
||||
if not new_decoder_arch_falcon and getattr(self.hf_text_config,
|
||||
"multi_query", False):
|
||||
self.hf_config.model_type in falcon_model_types
|
||||
and getattr(self.hf_config, "new_decoder_architecture", False)
|
||||
)
|
||||
if not new_decoder_arch_falcon and getattr(
|
||||
self.hf_text_config, "multi_query", False
|
||||
):
|
||||
# Multi-query attention, only one KV head.
|
||||
# Currently, tensor parallelism is not supported in this case.
|
||||
return 1
|
||||
|
||||
# For DBRX and MPT
|
||||
if self.hf_config.model_type in ["dbrx", "mpt"]:
|
||||
return getattr(self.hf_config.attn_config, "kv_n_heads",
|
||||
self.hf_config.num_attention_heads)
|
||||
return getattr(
|
||||
self.hf_config.attn_config,
|
||||
"kv_n_heads",
|
||||
self.hf_config.num_attention_heads,
|
||||
)
|
||||
|
||||
attributes = [
|
||||
# For Falcon:
|
||||
@@ -94,13 +104,12 @@ class ModelConfig:
|
||||
# the tensor parallel size. We will replicate the KV heads in the
|
||||
# case where the number of KV heads is smaller than the tensor
|
||||
# parallel size so each GPU has at least one KV head.
|
||||
return max(1,
|
||||
total_num_kv_heads // tensor_parallel_size)
|
||||
return max(1, total_num_kv_heads // tensor_parallel_size)
|
||||
|
||||
|
||||
def get_hf_text_config(config: PretrainedConfig):
|
||||
"""Get the "sub" config relevant to llm for multi modal models.
|
||||
No op for pure text models.
|
||||
No op for pure text models.
|
||||
"""
|
||||
if hasattr(config, "text_config"):
|
||||
# The code operates under the assumption that text_config should have
|
||||
|
||||
@@ -5,30 +5,32 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||
|
||||
LoraConfig = None
|
||||
|
||||
@@ -49,9 +51,11 @@ class GLMAttention(nn.Module):
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.multi_query_attention = config.multi_query_attention
|
||||
self.total_num_kv_heads = (config.multi_query_group_num
|
||||
if config.multi_query_attention else
|
||||
config.num_attention_heads)
|
||||
self.total_num_kv_heads = (
|
||||
config.multi_query_group_num
|
||||
if config.multi_query_attention
|
||||
else config.num_attention_heads
|
||||
)
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than or equal to TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
@@ -91,11 +95,13 @@ class GLMAttention(nn.Module):
|
||||
base=10000 * rope_ratio,
|
||||
is_neox_style=False,
|
||||
)
|
||||
self.attn = RadixAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -176,14 +182,16 @@ class GLMBlock(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
config.apply_residual_connection_post_layernorm
|
||||
)
|
||||
|
||||
self.fp32_residual_connection = config.fp32_residual_connection
|
||||
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = layer_norm_func(config.hidden_size,
|
||||
eps=config.layernorm_epsilon)
|
||||
self.input_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon
|
||||
)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
|
||||
@@ -191,7 +199,8 @@ class GLMBlock(nn.Module):
|
||||
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
config.hidden_size, eps=config.layernorm_epsilon
|
||||
)
|
||||
|
||||
# MLP
|
||||
self.mlp = GLMMLP(config, quant_config)
|
||||
@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module):
|
||||
self.num_layers = config.num_layers
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList([
|
||||
GLMBlock(config, i, cache_config, quant_config)
|
||||
for i in range(self.num_layers)
|
||||
])
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GLMBlock(config, i, cache_config, quant_config)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if self.post_layer_norm:
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
config.hidden_size, eps=config.layernorm_epsilon
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
|
||||
config.hidden_size)
|
||||
self.embedding = VocabParallelEmbedding(
|
||||
config.padded_vocab_size, config.hidden_size
|
||||
)
|
||||
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
||||
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||
config.hidden_size)
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module):
|
||||
class ChatGLMForCausalLM(nn.Module):
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
"dense_h_to_4h": ["dense_h_to_4h"]
|
||||
"dense_h_to_4h": ["dense_h_to_4h"],
|
||||
}
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
@@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
super().__init__()
|
||||
self.config: ChatGLMConfig = config
|
||||
self.quant_config = quant_config
|
||||
self.max_position_embeddings = getattr(config, "max_sequence_length",
|
||||
8192)
|
||||
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
|
||||
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
||||
self.lm_head = self.transformer.output_layer
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
@@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
input_metadata)
|
||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
)
|
||||
@@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = ChatGLMForCausalLM
|
||||
# compat: glm model.config class == ChatGLMModel
|
||||
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
|
||||
# This file is based on the LLama model definition file in transformers
|
||||
"""PyTorch Cohere model."""
|
||||
from typing import Optional, Tuple, Iterable
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
|
||||
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig, CacheConfig
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
||||
"""Inference-only Grok1 model."""
|
||||
from typing import Iterable, Optional, Tuple, List
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -9,7 +9,6 @@ import torch.nn.functional as F
|
||||
import tqdm
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (
|
||||
@@ -35,12 +34,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.fused_moe import fused_moe
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||
|
||||
|
||||
use_fused = True
|
||||
|
||||
|
||||
@@ -134,9 +132,12 @@ class Grok1MoEUnfused(nn.Module):
|
||||
|
||||
final_hidden_states = torch.zeros(
|
||||
(hidden_states.shape[0], hidden_dim),
|
||||
dtype=hidden_states.dtype, device=hidden_states.device
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_total_experts).permute(2, 1, 0)
|
||||
expert_mask = torch.nn.functional.one_hot(
|
||||
selected_experts, num_classes=self.num_total_experts
|
||||
).permute(2, 1, 0)
|
||||
|
||||
for expert_idx in self.expert_indicies:
|
||||
expert_layer = self.experts[expert_idx]
|
||||
@@ -153,7 +154,10 @@ class Grok1MoEUnfused(nn.Module):
|
||||
# the current expert. We need to make sure to multiply the output hidden
|
||||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
||||
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
|
||||
current_hidden_states = (
|
||||
expert_layer(current_state)
|
||||
* routing_weights[top_x_list, idx_list, None]
|
||||
)
|
||||
|
||||
# However, `index_add_` only supports torch tensors for indexing, so we'll use
|
||||
# the `top_x` tensor here.
|
||||
@@ -198,32 +202,46 @@ class Grok1MoE(nn.Module):
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=None)
|
||||
self.gate = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=None,
|
||||
)
|
||||
|
||||
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
self.w13_weight = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
dtype=params_dtype))
|
||||
torch.empty(
|
||||
self.num_total_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
dtype=params_dtype,
|
||||
)
|
||||
)
|
||||
self.w2_weight = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
dtype=params_dtype))
|
||||
torch.empty(
|
||||
self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
dtype=params_dtype,
|
||||
)
|
||||
)
|
||||
|
||||
set_weight_attrs(self.w13_weight, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2_weight, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(
|
||||
self.w13_weight,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.w2_weight,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
|
||||
# Used for fp8.
|
||||
self.w13_scale = None
|
||||
@@ -233,46 +251,69 @@ class Grok1MoE(nn.Module):
|
||||
|
||||
if self.use_fp8:
|
||||
# WEIGHT_SCALE (for fp8)
|
||||
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.w13_scale = nn.Parameter(
|
||||
torch.ones(self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.w2_scale = nn.Parameter(
|
||||
torch.ones(self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()).
|
||||
if quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(self.w13_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(
|
||||
self.w13_scale,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.w2_scale,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
|
||||
# ACT_SCALE (for fp8)
|
||||
if quant_config.activation_scheme == "static":
|
||||
if not quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8.")
|
||||
self.a13_scale = nn.Parameter(torch.zeros(
|
||||
self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(torch.zeros(
|
||||
self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
"was not serialized fp8."
|
||||
)
|
||||
self.a13_scale = nn.Parameter(
|
||||
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.a2_scale = nn.Parameter(
|
||||
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
set_weight_attrs(self.a13_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.a2_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(
|
||||
self.a13_scale,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.a2_scale,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
weight_name: str, expert_id: int, pre_sharded: bool):
|
||||
def weight_loader(
|
||||
self,
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
expert_id: int,
|
||||
pre_sharded: bool,
|
||||
):
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
if pre_sharded:
|
||||
@@ -284,8 +325,9 @@ class Grok1MoE(nn.Module):
|
||||
if weight_name.endswith("w1.weight"):
|
||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w3.weight"):
|
||||
param_data[expert_id,
|
||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
|
||||
shard, :
|
||||
]
|
||||
if weight_name.endswith("w2.weight"):
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
||||
@@ -298,17 +340,17 @@ class Grok1MoE(nn.Module):
|
||||
|
||||
# If checkpoint is fp16, quantize here.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
w13_weight = torch.empty_like(self.w13_weight.data,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_weight = torch.empty_like(self.w2_weight.data,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w13_weight = torch.empty_like(
|
||||
self.w13_weight.data, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
|
||||
for expert in range(self.num_total_experts):
|
||||
w13_weight[expert, :, :], self.w13_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
self.w13_weight.data[expert, :, :])
|
||||
w2_weight[expert, :, :], self.w2_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
self.w2_weight.data[expert, :, :])
|
||||
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
|
||||
self.w13_weight.data[expert, :, :]
|
||||
)
|
||||
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
self.w2_weight.data[expert, :, :]
|
||||
)
|
||||
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
||||
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
||||
|
||||
@@ -319,40 +361,40 @@ class Grok1MoE(nn.Module):
|
||||
if self.a13_scale is None or self.a2_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
"activation scales are None."
|
||||
)
|
||||
|
||||
if (not all_close_1d(self.a13_scale)
|
||||
or not all_close_1d(self.a2_scale)):
|
||||
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
|
||||
print_warning_once(
|
||||
"Found act_scales that are not equal for fp8 MoE layer. "
|
||||
"Using the maximum across experts for each layer. ")
|
||||
"Using the maximum across experts for each layer. "
|
||||
)
|
||||
|
||||
self.a13_scale = nn.Parameter(self.a13_scale.max(),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(self.a2_scale.max(),
|
||||
requires_grad=False)
|
||||
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.w13_weight,
|
||||
self.w2_weight,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=False,
|
||||
inplace=True,
|
||||
use_fp8=self.use_fp8,
|
||||
w1_scale=self.w13_scale,
|
||||
w2_scale=self.w2_scale,
|
||||
a1_scale=self.a13_scale,
|
||||
a2_scale=self.a2_scale)
|
||||
final_hidden_states = fused_moe(
|
||||
hidden_states,
|
||||
self.w13_weight,
|
||||
self.w2_weight,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=False,
|
||||
inplace=True,
|
||||
use_fp8=self.use_fp8,
|
||||
w1_scale=self.w13_scale,
|
||||
w2_scale=self.w2_scale,
|
||||
a1_scale=self.a13_scale,
|
||||
a2_scale=self.a2_scale,
|
||||
)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_size)
|
||||
|
||||
@@ -462,10 +504,12 @@ class Grok1DecoderLayer(nn.Module):
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.block_sparse_moe = Grok1MoEUnfused(
|
||||
config=config, quant_config=quant_config)
|
||||
config=config, quant_config=quant_config
|
||||
)
|
||||
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -478,12 +522,21 @@ class Grok1DecoderLayer(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.post_attn_norm(self.self_attn(
|
||||
positions=positions, hidden_states=self.pre_attn_norm(hidden_states),
|
||||
input_metadata=input_metadata,
|
||||
)) + hidden_states
|
||||
hidden_states = (
|
||||
self.post_attn_norm(
|
||||
self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=self.pre_attn_norm(hidden_states),
|
||||
input_metadata=input_metadata,
|
||||
)
|
||||
)
|
||||
+ hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states
|
||||
hidden_states = (
|
||||
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
|
||||
+ hidden_states
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -525,9 +578,7 @@ class Grok1Model(nn.Module):
|
||||
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
hidden_states = self.layers[i](
|
||||
positions, hidden_states, input_metadata
|
||||
)
|
||||
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states.mul_(self.config.output_multiplier_scale)
|
||||
@@ -572,28 +623,41 @@ class Grok1ModelForCausalLM(nn.Module):
|
||||
]
|
||||
|
||||
if use_fused:
|
||||
expert_params_mapping = [
|
||||
# These are the weight scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
] + [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
] + [
|
||||
# These are the activation scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
expert_params_mapping = (
|
||||
[
|
||||
# These are the weight scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
(
|
||||
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.weight_scale",
|
||||
expert_id,
|
||||
)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
+ [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
(
|
||||
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight",
|
||||
expert_id,
|
||||
)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
+ [
|
||||
# These are the activation scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
(
|
||||
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.act_scale",
|
||||
expert_id,
|
||||
)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
expert_params_mapping = []
|
||||
|
||||
@@ -601,11 +665,11 @@ class Grok1ModelForCausalLM(nn.Module):
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
|
||||
for name, loaded_weight in weights:
|
||||
#print(get_tensor_model_parallel_rank(), name)
|
||||
# print(get_tensor_model_parallel_rank(), name)
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
@@ -623,19 +687,22 @@ class Grok1ModelForCausalLM(nn.Module):
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
expert_id=expert_id,
|
||||
pre_sharded=get_tensor_model_parallel_world_size() > 1)
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
expert_id=expert_id,
|
||||
pre_sharded=get_tensor_model_parallel_world_size() > 1,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
@@ -645,10 +712,11 @@ def all_close_1d(x: torch.Tensor) -> bool:
|
||||
|
||||
|
||||
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
|
||||
def _prepare_presharded_weights(self,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
|
||||
|
||||
|
||||
def _prepare_presharded_weights(
|
||||
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
import glob
|
||||
import os
|
||||
|
||||
@@ -668,4 +736,4 @@ def _prepare_presharded_weights(self,
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
EntryClass = Grok1ModelForCausalLM
|
||||
EntryClass = Grok1ModelForCausalLM
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Optional, Tuple, Iterable
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
@@ -10,7 +10,7 @@ from transformers import LlamaConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
config, "original_max_position_embeddings", None
|
||||
):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.self_attn = LlamaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import List, Iterable, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
|
||||
from transformers import (
|
||||
CLIPVisionConfig,
|
||||
CLIPVisionModel,
|
||||
LlavaConfig,
|
||||
MistralConfig,
|
||||
Qwen2Config,
|
||||
)
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
|
||||
unpad_image_shape,
|
||||
)
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
|
||||
|
||||
class LlavaLlamaForCausalLM(nn.Module):
|
||||
@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
|
||||
|
||||
first_call = True
|
||||
|
||||
|
||||
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
|
||||
)
|
||||
|
||||
|
||||
EntryClass = [
|
||||
LlavaLlamaForCausalLM,
|
||||
LlavaQwenForCausalLM,
|
||||
LlavaMistralForCausalLM
|
||||
]
|
||||
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import List, Iterable, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
|
||||
across all ranks.
|
||||
@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=None)
|
||||
self.gate = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=None,
|
||||
)
|
||||
|
||||
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
self.w13_weight = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
dtype=params_dtype))
|
||||
torch.empty(
|
||||
self.num_total_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
dtype=params_dtype,
|
||||
)
|
||||
)
|
||||
self.w2_weight = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
dtype=params_dtype))
|
||||
torch.empty(
|
||||
self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
dtype=params_dtype,
|
||||
)
|
||||
)
|
||||
|
||||
set_weight_attrs(self.w13_weight, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2_weight, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(
|
||||
self.w13_weight,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.w2_weight,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
|
||||
# Used for fp8.
|
||||
self.w13_scale = None
|
||||
@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
|
||||
|
||||
if self.use_fp8:
|
||||
# WEIGHT_SCALE (for fp8)
|
||||
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.w13_scale = nn.Parameter(
|
||||
torch.ones(self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.w2_scale = nn.Parameter(
|
||||
torch.ones(self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()
|
||||
if quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(self.w13_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(
|
||||
self.w13_scale,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.w2_scale,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
|
||||
# ACT_SCALE (for fp8)
|
||||
if quant_config.activation_scheme == "static":
|
||||
if not quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8.")
|
||||
self.a13_scale = nn.Parameter(torch.zeros(
|
||||
self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(torch.zeros(
|
||||
self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
"was not serialized fp8."
|
||||
)
|
||||
self.a13_scale = nn.Parameter(
|
||||
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.a2_scale = nn.Parameter(
|
||||
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
set_weight_attrs(self.a13_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.a2_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(
|
||||
self.a13_scale,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.a2_scale,
|
||||
{
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
weight_name: str, expert_id: int):
|
||||
def weight_loader(
|
||||
self,
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
expert_id: int,
|
||||
):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
|
||||
if weight_name.endswith("w1.weight"):
|
||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w3.weight"):
|
||||
param_data[expert_id,
|
||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
|
||||
shard, :
|
||||
]
|
||||
if weight_name.endswith("w2.weight"):
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
||||
@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
|
||||
|
||||
# If checkpoint is fp16, quantize here.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
w13_weight = torch.empty_like(self.w13_weight.data,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_weight = torch.empty_like(self.w2_weight.data,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w13_weight = torch.empty_like(
|
||||
self.w13_weight.data, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
|
||||
for expert in range(self.num_total_experts):
|
||||
w13_weight[expert, :, :], self.w13_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
self.w13_weight.data[expert, :, :])
|
||||
w2_weight[expert, :, :], self.w2_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
self.w2_weight.data[expert, :, :])
|
||||
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
|
||||
self.w13_weight.data[expert, :, :]
|
||||
)
|
||||
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
self.w2_weight.data[expert, :, :]
|
||||
)
|
||||
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
||||
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
||||
|
||||
@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
|
||||
if self.a13_scale is None or self.a2_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
"activation scales are None."
|
||||
)
|
||||
|
||||
if (not all_close_1d(self.a13_scale)
|
||||
or not all_close_1d(self.a2_scale)):
|
||||
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
|
||||
print_warning_once(
|
||||
"Found act_scales that are not equal for fp8 MoE layer. "
|
||||
"Using the maximum across experts for each layer. ")
|
||||
"Using the maximum across experts for each layer. "
|
||||
)
|
||||
|
||||
self.a13_scale = nn.Parameter(self.a13_scale.max(),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(self.a2_scale.max(),
|
||||
requires_grad=False)
|
||||
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.w13_weight,
|
||||
self.w2_weight,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8=self.use_fp8,
|
||||
w1_scale=self.w13_scale,
|
||||
w2_scale=self.w2_scale,
|
||||
a1_scale=self.a13_scale,
|
||||
a2_scale=self.a2_scale)
|
||||
final_hidden_states = fused_moe(
|
||||
hidden_states,
|
||||
self.w13_weight,
|
||||
self.w2_weight,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8=self.use_fp8,
|
||||
w1_scale=self.w13_scale,
|
||||
w2_scale=self.w2_scale,
|
||||
a1_scale=self.a13_scale,
|
||||
a2_scale=self.a2_scale,
|
||||
)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_size)
|
||||
|
||||
@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
expert_params_mapping = [
|
||||
# These are the weight scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
] + [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
] + [
|
||||
# These are the activation scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
expert_params_mapping = (
|
||||
[
|
||||
# These are the weight scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
(
|
||||
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.weight_scale",
|
||||
expert_id,
|
||||
)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
+ [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
(
|
||||
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight",
|
||||
expert_id,
|
||||
)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
+ [
|
||||
# These are the activation scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
(
|
||||
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.act_scale",
|
||||
expert_id,
|
||||
)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
expert_id=expert_id)
|
||||
weight_loader(
|
||||
param, loaded_weight, weight_name, expert_id=expert_id
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
|
||||
from typing import Any, Dict, Optional, Iterable, Tuple
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Adapted from llama2.py
|
||||
# Modify details for the adaptation of Qwen2 model.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Optional, Tuple, Iterable
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
|
||||
"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
|
||||
model compatible with HuggingFace weights."""
|
||||
from typing import Optional, Tuple, Iterable
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Inference-only Yi-VL model."""
|
||||
|
||||
from typing import Tuple, Iterable, Optional
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPVisionModel, LlavaConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.models.llava import (
|
||||
LlavaLlamaForCausalLM,
|
||||
monkey_path_clip_vision_embed_forward,
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from sglang.srt.conversation import (
|
||||
Conversation,
|
||||
@@ -40,21 +40,18 @@ chat_template_name = None
|
||||
def create_error_response(
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
|
||||
error = ErrorResponse(message=message,
|
||||
type=err_type,
|
||||
code=status_code.value)
|
||||
return JSONResponse(content=error.model_dump(),
|
||||
status_code=error.code)
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
):
|
||||
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
|
||||
return JSONResponse(content=error.model_dump(), status_code=error.code)
|
||||
|
||||
|
||||
def create_streaming_error_response(
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
|
||||
error = ErrorResponse(message=message,
|
||||
type=err_type,
|
||||
code=status_code.value)
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
) -> str:
|
||||
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
|
||||
json_str = json.dumps({"error": error.model_dump()})
|
||||
return json_str
|
||||
|
||||
@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
n_prev_token = 0
|
||||
try:
|
||||
async for content in tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request):
|
||||
adapted_request, raw_request
|
||||
):
|
||||
text = content["text"]
|
||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||
@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
decode_token_logprobs=content["meta_info"][
|
||||
"decode_token_logprobs"
|
||||
][n_prev_token:],
|
||||
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
|
||||
n_prev_token:
|
||||
],
|
||||
decode_top_logprobs=content["meta_info"][
|
||||
"decode_top_logprobs"
|
||||
][n_prev_token:],
|
||||
)
|
||||
|
||||
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
|
||||
n_prev_token = len(
|
||||
content["meta_info"]["decode_token_logprobs"]
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
yield f"data: {error}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
|
||||
background=tokenizer_manager.create_abort_task(adapted_request))
|
||||
return StreamingResponse(
|
||||
generate_stream_resp(),
|
||||
media_type="text/event-stream",
|
||||
background=tokenizer_manager.create_abort_task(adapted_request),
|
||||
)
|
||||
|
||||
# Non-streaming response.
|
||||
try:
|
||||
ret = await tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request).__anext__()
|
||||
adapted_request, raw_request
|
||||
).__anext__()
|
||||
except ValueError as e:
|
||||
return create_error_response(str(e))
|
||||
|
||||
@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
|
||||
stream_buffer = ""
|
||||
try:
|
||||
async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
|
||||
async for content in tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request
|
||||
):
|
||||
if is_first:
|
||||
# First chunk with role
|
||||
is_first = False
|
||||
@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
yield f"data: {error}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
|
||||
background=tokenizer_manager.create_abort_task(adapted_request))
|
||||
return StreamingResponse(
|
||||
generate_stream_resp(),
|
||||
media_type="text/event-stream",
|
||||
background=tokenizer_manager.create_abort_task(adapted_request),
|
||||
)
|
||||
|
||||
# Non-streaming response.
|
||||
try:
|
||||
ret = await tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request).__anext__()
|
||||
adapted_request, raw_request
|
||||
).__anext__()
|
||||
except ValueError as e:
|
||||
return create_error_response(str(e))
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -29,10 +29,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.srt.constrained import disable_cache
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.controller.manager_multi import (
|
||||
start_controller_process as start_controller_process_multi,
|
||||
)
|
||||
from sglang.srt.managers.controller.manager_single import (
|
||||
start_controller_process as start_controller_process_single,
|
||||
)
|
||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
|
||||
from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.openai_api_adapter import (
|
||||
load_chat_template_for_openai_api,
|
||||
@@ -97,8 +101,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(stream_results(), media_type="text/event-stream",
|
||||
background=tokenizer_manager.create_abort_task(obj))
|
||||
return StreamingResponse(
|
||||
stream_results(),
|
||||
media_type="text/event-stream",
|
||||
background=tokenizer_manager.create_abort_task(obj),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Common utilities."""
|
||||
|
||||
import base64
|
||||
import multiprocessing
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
@@ -17,12 +17,11 @@ import requests
|
||||
import rpyc
|
||||
import torch
|
||||
import triton
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int):
|
||||
protocol_config={
|
||||
"allow_public_attrs": True,
|
||||
"allow_pickle": True,
|
||||
"sync_request_timeout": 3600
|
||||
"sync_request_timeout": 3600,
|
||||
},
|
||||
)
|
||||
t.logger.setLevel(logging.WARN)
|
||||
@@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"):
|
||||
config={
|
||||
"allow_public_attrs": True,
|
||||
"allow_pickle": True,
|
||||
"sync_request_timeout": 3600
|
||||
"sync_request_timeout": 3600,
|
||||
},
|
||||
)
|
||||
break
|
||||
@@ -423,7 +422,9 @@ def suppress_other_loggers():
|
||||
|
||||
vllm_default_logger.setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.config").setLevel(logging.ERROR)
|
||||
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
|
||||
logging.WARN
|
||||
)
|
||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
||||
|
||||
@@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
||||
device_name = torch.cuda.get_device_name(gpu_id)
|
||||
if "RTX 40" not in device_name:
|
||||
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
|
||||
|
||||
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
||||
|
||||
|
||||
@@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@@ -356,16 +356,25 @@ def test_completion_speculative():
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
|
||||
s += (
|
||||
"Name:"
|
||||
+ sgl.gen("name", stop="\n")
|
||||
+ "\nBirthday:"
|
||||
+ sgl.gen("birthday", stop="\n")
|
||||
)
|
||||
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
|
||||
|
||||
|
||||
@sgl.function
|
||||
def gen_character_no_spec(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
|
||||
s += (
|
||||
"Name:"
|
||||
+ sgl.gen("name", stop="\n")
|
||||
+ "\nBirthday:"
|
||||
+ sgl.gen("birthday", stop="\n")
|
||||
)
|
||||
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
|
||||
|
||||
token_usage = sgl.global_config.default_backend.token_usage
|
||||
@@ -378,7 +387,9 @@ def test_completion_speculative():
|
||||
gen_character_no_spec().sync()
|
||||
usage_with_no_spec = token_usage.prompt_tokens
|
||||
|
||||
assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
|
||||
assert (
|
||||
usage_with_spec < usage_with_no_spec
|
||||
), f"{usage_with_spec} vs {usage_with_no_spec}"
|
||||
|
||||
|
||||
def test_chat_completion_speculative():
|
||||
@@ -386,8 +397,17 @@ def test_chat_completion_speculative():
|
||||
def gen_character_spec(s):
|
||||
s += sgl.system("You are a helpful assistant.")
|
||||
s += sgl.user("Construct a character within the following format:")
|
||||
s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
|
||||
s += sgl.assistant(
|
||||
"Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||
)
|
||||
s += sgl.user("Please generate new Name, Birthday and Job.\n")
|
||||
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
|
||||
s += sgl.assistant(
|
||||
"Name:"
|
||||
+ sgl.gen("name", stop="\n")
|
||||
+ "\nBirthday:"
|
||||
+ sgl.gen("birthday", stop="\n")
|
||||
+ "\nJob:"
|
||||
+ sgl.gen("job", stop="\n")
|
||||
)
|
||||
|
||||
gen_character_spec().sync()
|
||||
gen_character_spec().sync()
|
||||
|
||||
@@ -15,7 +15,6 @@ from json import dumps
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -255,8 +254,10 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
|
||||
|
||||
def graceful_registry(sub_module_name):
|
||||
def graceful_shutdown(signum, frame):
|
||||
logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
|
||||
logger.info(
|
||||
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
|
||||
)
|
||||
if signum == signal.SIGTERM:
|
||||
logger.info(f"{sub_module_name} recive sigterm")
|
||||
|
||||
signal.signal(signal.SIGTERM, graceful_shutdown)
|
||||
signal.signal(signal.SIGTERM, graceful_shutdown)
|
||||
|
||||
Reference in New Issue
Block a user