Higher priority for user input of max_prefill_tokens & format (#540)
This commit is contained in:
@@ -65,7 +65,7 @@ def main(args):
|
|||||||
def get_one_answer(i):
|
def get_one_answer(i):
|
||||||
answer = call_generate(
|
answer = call_generate(
|
||||||
prompt=few_shot_examples + questions[i],
|
prompt=few_shot_examples + questions[i],
|
||||||
#prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
|
# prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=256,
|
max_tokens=256,
|
||||||
stop="Question",
|
stop="Question",
|
||||||
|
|||||||
@@ -158,7 +158,9 @@ async def send_request(
|
|||||||
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
while True:
|
while True:
|
||||||
async with session.post(api_url, headers=headers, json=pload) as response:
|
async with session.post(
|
||||||
|
api_url, headers=headers, json=pload
|
||||||
|
) as response:
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk, _ in response.content.iter_chunks():
|
async for chunk, _ in response.content.iter_chunks():
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
@@ -228,19 +230,32 @@ def main(args: argparse.Namespace):
|
|||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
api_url = f"http://{args.host}:{args.port}/generate"
|
api_url = f"http://{args.host}:{args.port}/generate"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
args.tokenizer, trust_remote_code=args.trust_remote_code
|
||||||
|
)
|
||||||
|
|
||||||
if args.dataset:
|
if args.dataset:
|
||||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
else:
|
else:
|
||||||
input_lens = np.random.randint(
|
input_lens = np.random.randint(
|
||||||
int(args.input_len * args.range_ratio), args.input_len + 1, size=args.num_prompts)
|
int(args.input_len * args.range_ratio),
|
||||||
|
args.input_len + 1,
|
||||||
|
size=args.num_prompts,
|
||||||
|
)
|
||||||
output_lens = np.random.randint(
|
output_lens = np.random.randint(
|
||||||
int(args.output_len * args.range_ratio), args.output_len + 1, size=args.num_prompts)
|
int(args.output_len * args.range_ratio),
|
||||||
|
args.output_len + 1,
|
||||||
|
size=args.num_prompts,
|
||||||
|
)
|
||||||
offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
|
offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
|
||||||
input_requests = []
|
input_requests = []
|
||||||
for i in range(args.num_prompts):
|
for i in range(args.num_prompts):
|
||||||
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])])
|
prompt = tokenizer.decode(
|
||||||
|
[
|
||||||
|
(offsets[i] + i + j) % tokenizer.vocab_size
|
||||||
|
for j in range(input_lens[i])
|
||||||
|
]
|
||||||
|
)
|
||||||
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
|
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
benchmark_start_time = time.perf_counter()
|
||||||
@@ -287,16 +302,15 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
parser.add_argument("--port", type=int, default=30000)
|
parser.add_argument("--port", type=int, default=30000)
|
||||||
parser.add_argument(
|
parser.add_argument("--dataset", type=str, help="Path to the dataset.")
|
||||||
"--dataset", type=str, help="Path to the dataset."
|
|
||||||
)
|
|
||||||
parser.add_argument("--input-len", type=int, default=2048)
|
parser.add_argument("--input-len", type=int, default=2048)
|
||||||
parser.add_argument("--output-len", type=int, default=256)
|
parser.add_argument("--output-len", type=int, default=256)
|
||||||
parser.add_argument("--range-ratio", type=float, default=1.0)
|
parser.add_argument("--range-ratio", type=float, default=1.0)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer", type=str,
|
"--tokenizer",
|
||||||
|
type=str,
|
||||||
default="NousResearch/Meta-Llama-3-8B",
|
default="NousResearch/Meta-Llama-3-8B",
|
||||||
help="Name or path of the tokenizer."
|
help="Name or path of the tokenizer.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--best-of",
|
"--best-of",
|
||||||
|
|||||||
@@ -170,4 +170,4 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--data_dir", type=str, default="data")
|
parser.add_argument("--data_dir", type=str, default="data")
|
||||||
parser.add_argument("--nsub", type=int, default=60)
|
parser.add_argument("--nsub", type=int, default=60)
|
||||||
args = add_common_other_args_and_parse(parser)
|
args = add_common_other_args_and_parse(parser)
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ from sglang.api import (
|
|||||||
|
|
||||||
# SGL Backends
|
# SGL Backends
|
||||||
from sglang.backend.anthropic import Anthropic
|
from sglang.backend.anthropic import Anthropic
|
||||||
|
from sglang.backend.litellm import LiteLLM
|
||||||
from sglang.backend.openai import OpenAI
|
from sglang.backend.openai import OpenAI
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.backend.vertexai import VertexAI
|
from sglang.backend.vertexai import VertexAI
|
||||||
from sglang.backend.litellm import LiteLLM
|
|
||||||
|
|
||||||
# Global Configurations
|
# Global Configurations
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ class LiteLLM(BaseBackend):
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
self.chat_template = chat_template or get_chat_template_by_model_path(
|
self.chat_template = chat_template or get_chat_template_by_model_path(
|
||||||
model_name)
|
model_name
|
||||||
|
)
|
||||||
|
|
||||||
self.client_params = {
|
self.client_params = {
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
import dataclasses
|
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
|
|||||||
def get_chat_template(self):
|
def get_chat_template(self):
|
||||||
return self.chat_template
|
return self.chat_template
|
||||||
|
|
||||||
def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
|
def _prepare_spec_execution(
|
||||||
num_api_spec_tokens: int, spec_var_name: str):
|
self,
|
||||||
|
sampling_params: SglSamplingParams,
|
||||||
|
num_api_spec_tokens: int,
|
||||||
|
spec_var_name: str,
|
||||||
|
):
|
||||||
if "max_tokens" not in self.spec_kwargs:
|
if "max_tokens" not in self.spec_kwargs:
|
||||||
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
|
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
|
||||||
else:
|
else:
|
||||||
assert (
|
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
|
||||||
self.spec_kwargs["max_tokens"] == num_api_spec_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
params = sampling_params.to_openai_kwargs()
|
params = sampling_params.to_openai_kwargs()
|
||||||
for key, value in params.items():
|
for key, value in params.items():
|
||||||
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
|
|||||||
)
|
)
|
||||||
prompt = s.messages_
|
prompt = s.messages_
|
||||||
else:
|
else:
|
||||||
return self._prepare_spec_execution(sampling_params,
|
return self._prepare_spec_execution(
|
||||||
s.num_api_spec_tokens, spec_var_name)
|
sampling_params, s.num_api_spec_tokens, spec_var_name
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prompt = s.text_
|
prompt = s.text_
|
||||||
|
|
||||||
@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
|
|||||||
ret_str = ret.choices[0].text
|
ret_str = ret.choices[0].text
|
||||||
ret_token = self.tokenizer.encode(ret_str)[0]
|
ret_token = self.tokenizer.encode(ret_str)[0]
|
||||||
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
|
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
|
||||||
self.token_usage.completion_tokens= ret.usage.completion_tokens
|
self.token_usage.completion_tokens = ret.usage.completion_tokens
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
# 1. return logits as the scores
|
# 1. return logits as the scores
|
||||||
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
|
|||||||
return decision, scores, None, None
|
return decision, scores, None, None
|
||||||
|
|
||||||
|
|
||||||
def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
|
def openai_completion(
|
||||||
|
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||||
|
):
|
||||||
for attempt in range(retries):
|
for attempt in range(retries):
|
||||||
try:
|
try:
|
||||||
if is_chat:
|
if is_chat:
|
||||||
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
|
|||||||
return comp
|
return comp
|
||||||
|
|
||||||
|
|
||||||
def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
|
def openai_completion_stream(
|
||||||
|
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||||
|
):
|
||||||
for attempt in range(retries):
|
for attempt in range(retries):
|
||||||
try:
|
try:
|
||||||
if is_chat:
|
if is_chat:
|
||||||
if "stop" in kwargs and kwargs["stop"] is None:
|
if "stop" in kwargs and kwargs["stop"] is None:
|
||||||
kwargs.pop("stop")
|
kwargs.pop("stop")
|
||||||
generator = client.chat.completions.create(
|
generator = client.chat.completions.create(
|
||||||
messages=prompt, stream=True, stream_options={"include_usage": True},
|
messages=prompt,
|
||||||
**kwargs
|
stream=True,
|
||||||
|
stream_options={"include_usage": True},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
for ret in generator:
|
for ret in generator:
|
||||||
if len(ret.choices) == 0:
|
if len(ret.choices) == 0:
|
||||||
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
|
|||||||
yield content or "", {}
|
yield content or "", {}
|
||||||
else:
|
else:
|
||||||
generator = client.completions.create(
|
generator = client.completions.create(
|
||||||
prompt=prompt, stream=True, stream_options={"include_usage": True},
|
prompt=prompt,
|
||||||
**kwargs
|
stream=True,
|
||||||
|
stream_options={"include_usage": True},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
for ret in generator:
|
for ret in generator:
|
||||||
if len(ret.choices) == 0:
|
if len(ret.choices) == 0:
|
||||||
|
|||||||
@@ -507,7 +507,7 @@ class StreamExecutor:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
else: # Speculative execution on models with completion interface
|
else: # Speculative execution on models with completion interface
|
||||||
comp, meta_info = self._spec_gen(sampling_params)
|
comp, meta_info = self._spec_gen(sampling_params)
|
||||||
|
|
||||||
self.text_ += comp
|
self.text_ += comp
|
||||||
|
|||||||
@@ -81,12 +81,10 @@ class SglSamplingParams:
|
|||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_litellm_kwargs(self):
|
def to_litellm_kwargs(self):
|
||||||
if self.regex is not None:
|
if self.regex is not None:
|
||||||
warnings.warn(
|
warnings.warn("Regular expression is not supported in the LiteLLM backend.")
|
||||||
"Regular expression is not supported in the LiteLLM backend."
|
|
||||||
)
|
|
||||||
return {
|
return {
|
||||||
"max_tokens": self.max_new_tokens,
|
"max_tokens": self.max_new_tokens,
|
||||||
"stop": self.stop or None,
|
"stop": self.stop or None,
|
||||||
|
|||||||
@@ -10,4 +10,4 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
server_args = ServerArgs.from_cli_args(args)
|
server_args = ServerArgs.from_cli_args(args)
|
||||||
|
|
||||||
launch_server(server_args, None)
|
launch_server(server_args, None)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Launch the inference server for Llava-video model."""
|
"""Launch the inference server for Llava-video model."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Union
|
|||||||
from outlines.caching import cache as disk_cache
|
from outlines.caching import cache as disk_cache
|
||||||
from outlines.caching import disable_cache
|
from outlines.caching import disable_cache
|
||||||
from outlines.fsm.guide import RegexGuide
|
from outlines.fsm.guide import RegexGuide
|
||||||
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
|
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
||||||
from outlines.models.transformers import TransformerTokenizer
|
from outlines.models.transformers import TransformerTokenizer
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Cache for the compressed finite state machine."""
|
"""Cache for the compressed finite state machine."""
|
||||||
|
|
||||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||||
from sglang.srt.constrained.base_cache import BaseCache
|
from sglang.srt.constrained.base_cache import BaseCache
|
||||||
|
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import interegular
|
import interegular
|
||||||
import outlines.caching
|
import outlines.caching
|
||||||
|
|
||||||
from sglang.srt.constrained import (
|
from sglang.srt.constrained import (
|
||||||
FSMInfo,
|
FSMInfo,
|
||||||
disk_cache,
|
disk_cache,
|
||||||
make_deterministic_fsm,
|
|
||||||
make_byte_level_fsm,
|
make_byte_level_fsm,
|
||||||
|
make_deterministic_fsm,
|
||||||
)
|
)
|
||||||
from sglang.srt.constrained.base_cache import BaseCache
|
from sglang.srt.constrained.base_cache import BaseCache
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Conversation templates."""
|
"""Conversation templates."""
|
||||||
|
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
"""Utilities for Huggingface Transformers."""
|
"""Utilities for Huggingface Transformers."""
|
||||||
|
|
||||||
|
import functools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
import functools
|
from typing import AbstractSet, Collection, Literal, Optional, Union
|
||||||
from typing import Optional, Union, AbstractSet, Collection, Literal
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -179,6 +179,7 @@ def get_processor(
|
|||||||
class TiktokenTokenizer:
|
class TiktokenTokenizer:
|
||||||
def __init__(self, tokenizer_path):
|
def __init__(self, tokenizer_path):
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||||
|
|
||||||
# Read JSON
|
# Read JSON
|
||||||
@@ -190,7 +191,8 @@ class TiktokenTokenizer:
|
|||||||
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
|
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
|
||||||
}
|
}
|
||||||
special_tokens = {
|
special_tokens = {
|
||||||
bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
|
bytes(item["bytes"]).decode(): item["token"]
|
||||||
|
for item in tok_dict["special_tokens"]
|
||||||
}
|
}
|
||||||
assert tok_dict["word_split"] == "V1"
|
assert tok_dict["word_split"] == "V1"
|
||||||
|
|
||||||
@@ -202,7 +204,10 @@ class TiktokenTokenizer:
|
|||||||
}
|
}
|
||||||
if "default_allowed_special" in tok_dict:
|
if "default_allowed_special" in tok_dict:
|
||||||
default_allowed_special = set(
|
default_allowed_special = set(
|
||||||
[bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
|
[
|
||||||
|
bytes(bytes_list).decode()
|
||||||
|
for bytes_list in tok_dict["default_allowed_special"]
|
||||||
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
default_allowed_special = None
|
default_allowed_special = None
|
||||||
@@ -216,14 +221,20 @@ class TiktokenTokenizer:
|
|||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
*,
|
*,
|
||||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
|
allowed_special: Union[
|
||||||
|
Literal["all"], AbstractSet[str]
|
||||||
|
] = set(), # noqa: B006
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
if isinstance(allowed_special, set):
|
if isinstance(allowed_special, set):
|
||||||
allowed_special |= self._default_allowed_special
|
allowed_special |= self._default_allowed_special
|
||||||
return tiktoken.Encoding.encode(
|
return tiktoken.Encoding.encode(
|
||||||
self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
|
self,
|
||||||
|
text,
|
||||||
|
allowed_special=allowed_special,
|
||||||
|
disallowed_special=disallowed_special,
|
||||||
)
|
)
|
||||||
|
|
||||||
tokenizer.encode = functools.partial(encode_patched, tokenizer)
|
tokenizer.encode = functools.partial(encode_patched, tokenizer)
|
||||||
|
|
||||||
# Convert to HF interface
|
# Convert to HF interface
|
||||||
@@ -237,10 +248,14 @@ class TiktokenTokenizer:
|
|||||||
def decode(self, x):
|
def decode(self, x):
|
||||||
return self.tokenizer.decode(x)
|
return self.tokenizer.decode(x)
|
||||||
|
|
||||||
def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
|
def batch_decode(
|
||||||
|
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
|
||||||
|
):
|
||||||
if isinstance(batch[0], int):
|
if isinstance(batch[0], int):
|
||||||
batch = [[x] for x in batch]
|
batch = [[x] for x in batch]
|
||||||
return self.tokenizer.decode_batch(batch)
|
return self.tokenizer.decode_batch(batch)
|
||||||
|
|
||||||
def convert_ids_to_tokens(self, index):
|
def convert_ids_to_tokens(self, index):
|
||||||
return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
|
return self.tokenizer.decode_single_token_bytes(index).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import is_hip
|
from vllm.utils import is_hip
|
||||||
@@ -109,12 +108,16 @@ def fused_moe_kernel(
|
|||||||
|
|
||||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
|
a_ptrs = a_ptr + (
|
||||||
offs_k[None, :] * stride_ak)
|
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
||||||
|
)
|
||||||
|
|
||||||
off_experts = tl.load(expert_ids_ptr + pid_m)
|
off_experts = tl.load(expert_ids_ptr + pid_m)
|
||||||
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
|
b_ptrs = (
|
||||||
offs_bn[None, :] * stride_bn)
|
b_ptr
|
||||||
|
+ off_experts * stride_be
|
||||||
|
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||||
|
)
|
||||||
|
|
||||||
if use_fp8:
|
if use_fp8:
|
||||||
a_scale = tl.load(a_scale_ptr)
|
a_scale = tl.load(a_scale_ptr)
|
||||||
@@ -130,13 +133,12 @@ def fused_moe_kernel(
|
|||||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||||
# Load the next block of A and B, generate a mask by checking the
|
# Load the next block of A and B, generate a mask by checking the
|
||||||
# K dimension.
|
# K dimension.
|
||||||
a = tl.load(a_ptrs,
|
a = tl.load(
|
||||||
mask=token_mask[:, None] &
|
a_ptrs,
|
||||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||||
other=0.0)
|
other=0.0,
|
||||||
b = tl.load(b_ptrs,
|
)
|
||||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
other=0.0)
|
|
||||||
# We accumulate along the K dimension.
|
# We accumulate along the K dimension.
|
||||||
if use_fp8:
|
if use_fp8:
|
||||||
accumulator = tl.dot(a, b, acc=accumulator)
|
accumulator = tl.dot(a, b, acc=accumulator)
|
||||||
@@ -147,9 +149,7 @@ def fused_moe_kernel(
|
|||||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
if MUL_ROUTED_WEIGHT:
|
if MUL_ROUTED_WEIGHT:
|
||||||
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||||
mask=token_mask,
|
|
||||||
other=0)
|
|
||||||
accumulator = accumulator * moe_weight[:, None]
|
accumulator = accumulator * moe_weight[:, None]
|
||||||
|
|
||||||
if use_fp8:
|
if use_fp8:
|
||||||
@@ -159,15 +159,14 @@ def fused_moe_kernel(
|
|||||||
# -----------------------------------------------------------
|
# -----------------------------------------------------------
|
||||||
# Write back the block of the output
|
# Write back the block of the output
|
||||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
|
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||||
None, :]
|
|
||||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size(
|
def moe_align_block_size(
|
||||||
topk_ids: torch.Tensor, block_size: int,
|
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
||||||
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Aligns the token distribution across experts to be compatible with block
|
Aligns the token distribution across experts to be compatible with block
|
||||||
size for matrix multiplication.
|
size for matrix multiplication.
|
||||||
@@ -206,32 +205,38 @@ def moe_align_block_size(
|
|||||||
by block_size for proper block matrix operations.
|
by block_size for proper block matrix operations.
|
||||||
"""
|
"""
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
sorted_ids = torch.empty(
|
||||||
dtype=torch.int32,
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||||
device=topk_ids.device)
|
)
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||||
expert_ids = torch.empty((max_num_m_blocks, ),
|
expert_ids = torch.empty(
|
||||||
dtype=torch.int32,
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
device=topk_ids.device)
|
)
|
||||||
num_tokens_post_pad = torch.empty((1),
|
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||||
dtype=torch.int32,
|
ops.moe_align_block_size(
|
||||||
device=topk_ids.device)
|
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
|
||||||
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
|
)
|
||||||
expert_ids, num_tokens_post_pad)
|
|
||||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||||
|
|
||||||
|
|
||||||
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
def invoke_fused_moe_kernel(
|
||||||
A_scale: Optional[torch.Tensor],
|
A: torch.Tensor,
|
||||||
B_scale: Optional[torch.Tensor],
|
B: torch.Tensor,
|
||||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
C: torch.Tensor,
|
||||||
sorted_token_ids: torch.Tensor,
|
A_scale: Optional[torch.Tensor],
|
||||||
expert_ids: torch.Tensor,
|
B_scale: Optional[torch.Tensor],
|
||||||
num_tokens_post_padded: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
mul_routed_weight: bool, top_k: int,
|
topk_ids: torch.Tensor,
|
||||||
config: Dict[str, Any], compute_type: tl.dtype,
|
sorted_token_ids: torch.Tensor,
|
||||||
use_fp8: bool) -> None:
|
expert_ids: torch.Tensor,
|
||||||
|
num_tokens_post_padded: torch.Tensor,
|
||||||
|
mul_routed_weight: bool,
|
||||||
|
top_k: int,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
compute_type: tl.dtype,
|
||||||
|
use_fp8: bool,
|
||||||
|
) -> None:
|
||||||
assert topk_weights.stride(1) == 1
|
assert topk_weights.stride(1) == 1
|
||||||
assert sorted_token_ids.stride(0) == 1
|
assert sorted_token_ids.stride(0) == 1
|
||||||
|
|
||||||
@@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
|||||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||||
assert B_scale is not None
|
assert B_scale is not None
|
||||||
|
|
||||||
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
grid = lambda META: (
|
||||||
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
|
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
|
||||||
|
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
||||||
|
)
|
||||||
|
|
||||||
fused_moe_kernel[grid](
|
fused_moe_kernel[grid](
|
||||||
A,
|
A,
|
||||||
@@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def get_moe_configs(E: int, N: int,
|
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
||||||
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
|
||||||
"""
|
"""
|
||||||
Return optimized configurations for the fused MoE kernel.
|
Return optimized configurations for the fused MoE kernel.
|
||||||
|
|
||||||
@@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int,
|
|||||||
json_file_name = get_config_file_name(E, N, dtype)
|
json_file_name = get_config_file_name(E, N, dtype)
|
||||||
|
|
||||||
config_file_path = os.path.join(
|
config_file_path = os.path.join(
|
||||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
|
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
|
||||||
|
)
|
||||||
if os.path.exists(config_file_path):
|
if os.path.exists(config_file_path):
|
||||||
with open(config_file_path) as f:
|
with open(config_file_path) as f:
|
||||||
logger.info("Using configuration from %s for MoE layer.",
|
logger.info("Using configuration from %s for MoE layer.", config_file_path)
|
||||||
config_file_path)
|
|
||||||
# If a configuration has been found, return it
|
# If a configuration has been found, return it
|
||||||
return {int(key): val for key, val in json.load(f).items()}
|
return {int(key): val for key, val in json.load(f).items()}
|
||||||
|
|
||||||
@@ -352,40 +358,30 @@ def fused_moe(
|
|||||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
"Number of tokens mismatch")
|
|
||||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||||
assert hidden_states.dtype in [
|
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
||||||
torch.float32, torch.float16, torch.bfloat16
|
|
||||||
]
|
|
||||||
M, _ = hidden_states.shape
|
M, _ = hidden_states.shape
|
||||||
E, N, _ = w1.shape
|
E, N, _ = w1.shape
|
||||||
|
|
||||||
if is_hip():
|
if is_hip():
|
||||||
# The MoE kernels are not yet supported on ROCm.
|
# The MoE kernels are not yet supported on ROCm.
|
||||||
routing_weights = torch.softmax(gating_output,
|
routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
|
||||||
dim=-1,
|
|
||||||
dtype=torch.float32)
|
|
||||||
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
|
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
|
||||||
else:
|
else:
|
||||||
import vllm._moe_C as moe_kernels
|
import vllm._moe_C as moe_kernels
|
||||||
|
|
||||||
topk_weights = torch.empty(M,
|
topk_weights = torch.empty(
|
||||||
topk,
|
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||||
dtype=torch.float32,
|
)
|
||||||
device=hidden_states.device)
|
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||||
topk_ids = torch.empty(M,
|
token_expert_indicies = torch.empty(
|
||||||
topk,
|
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||||
dtype=torch.int32,
|
)
|
||||||
device=hidden_states.device)
|
|
||||||
token_expert_indicies = torch.empty(M,
|
|
||||||
topk,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=hidden_states.device)
|
|
||||||
moe_kernels.topk_softmax(
|
moe_kernels.topk_softmax(
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -400,8 +396,7 @@ def fused_moe(
|
|||||||
config = override_config
|
config = override_config
|
||||||
else:
|
else:
|
||||||
# First try to load optimal config from the file
|
# First try to load optimal config from the file
|
||||||
configs = get_moe_configs(E, w2.shape[2],
|
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
|
||||||
"float8" if use_fp8 else None)
|
|
||||||
|
|
||||||
if configs:
|
if configs:
|
||||||
# If an optimal configuration map has been found, look up the
|
# If an optimal configuration map has been found, look up the
|
||||||
@@ -415,7 +410,7 @@ def fused_moe(
|
|||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4,
|
||||||
}
|
}
|
||||||
|
|
||||||
if M <= E:
|
if M <= E:
|
||||||
@@ -425,61 +420,72 @@ def fused_moe(
|
|||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 16,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 8,
|
"num_warps": 8,
|
||||||
"num_stages": 4
|
"num_stages": 4,
|
||||||
}
|
}
|
||||||
|
|
||||||
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
|
intermediate_cache1 = torch.empty(
|
||||||
device=hidden_states.device,
|
(M, topk_ids.shape[1], N),
|
||||||
dtype=hidden_states.dtype)
|
device=hidden_states.device,
|
||||||
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
|
dtype=hidden_states.dtype,
|
||||||
device=hidden_states.device,
|
)
|
||||||
dtype=hidden_states.dtype)
|
intermediate_cache2 = torch.empty(
|
||||||
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
|
(M * topk_ids.shape[1], N // 2),
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype)
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
intermediate_cache3 = torch.empty(
|
||||||
|
(M, topk_ids.shape[1], w2.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||||
topk_ids, config['BLOCK_SIZE_M'], E)
|
topk_ids, config["BLOCK_SIZE_M"], E
|
||||||
compute_type = (tl.bfloat16
|
)
|
||||||
if hidden_states.dtype == torch.bfloat16 else tl.float16)
|
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||||
|
|
||||||
invoke_fused_moe_kernel(hidden_states,
|
invoke_fused_moe_kernel(
|
||||||
w1,
|
hidden_states,
|
||||||
intermediate_cache1,
|
w1,
|
||||||
a1_scale,
|
intermediate_cache1,
|
||||||
w1_scale,
|
a1_scale,
|
||||||
topk_weights,
|
w1_scale,
|
||||||
topk_ids,
|
topk_weights,
|
||||||
sorted_token_ids,
|
topk_ids,
|
||||||
expert_ids,
|
sorted_token_ids,
|
||||||
num_tokens_post_padded,
|
expert_ids,
|
||||||
False,
|
num_tokens_post_padded,
|
||||||
topk_ids.shape[1],
|
False,
|
||||||
config,
|
topk_ids.shape[1],
|
||||||
compute_type=compute_type,
|
config,
|
||||||
use_fp8=use_fp8)
|
compute_type=compute_type,
|
||||||
|
use_fp8=use_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||||
|
|
||||||
invoke_fused_moe_kernel(intermediate_cache2,
|
invoke_fused_moe_kernel(
|
||||||
w2,
|
intermediate_cache2,
|
||||||
intermediate_cache3,
|
w2,
|
||||||
a2_scale,
|
intermediate_cache3,
|
||||||
w2_scale,
|
a2_scale,
|
||||||
topk_weights,
|
w2_scale,
|
||||||
topk_ids,
|
topk_weights,
|
||||||
sorted_token_ids,
|
topk_ids,
|
||||||
expert_ids,
|
sorted_token_ids,
|
||||||
num_tokens_post_padded,
|
expert_ids,
|
||||||
True,
|
num_tokens_post_padded,
|
||||||
1,
|
True,
|
||||||
config,
|
1,
|
||||||
compute_type=compute_type,
|
config,
|
||||||
use_fp8=use_fp8)
|
compute_type=compute_type,
|
||||||
|
use_fp8=use_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
if inplace:
|
if inplace:
|
||||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
return torch.sum(
|
||||||
dim=1,
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
out=hidden_states)
|
dim=1,
|
||||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
out=hidden_states,
|
||||||
dim=1)
|
)
|
||||||
|
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Logits processing."""
|
"""Logits processing."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Radix attention."""
|
"""Radix attention."""
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||||
@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
|
|||||||
|
|
||||||
|
|
||||||
class RadixAttention(nn.Module):
|
class RadixAttention(nn.Module):
|
||||||
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
|
def __init__(
|
||||||
|
self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_q_head_num = num_heads
|
self.tp_q_head_num = num_heads
|
||||||
self.tp_k_head_num = num_kv_heads
|
self.tp_k_head_num = num_kv_heads
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
from typing import List, Callable
|
from typing import Callable, List
|
||||||
|
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
@@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
|
|||||||
|
|
||||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
# async sleep for receiving the subsequent request and avoiding cache miss
|
||||||
if len(out_pyobjs) != 0:
|
if len(out_pyobjs) != 0:
|
||||||
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
|
has_finished = any(
|
||||||
|
[obj.finished_reason is not None for obj in out_pyobjs]
|
||||||
|
)
|
||||||
if has_finished:
|
if has_finished:
|
||||||
await asyncio.sleep(self.request_dependency_delay)
|
await asyncio.sleep(self.request_dependency_delay)
|
||||||
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
||||||
@@ -108,4 +110,4 @@ def start_data_parallel_worker(
|
|||||||
step_func=model_tp_client.step,
|
step_func=model_tp_client.step,
|
||||||
)
|
)
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
return worker_thread
|
return worker_thread
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
"""Meta data for requests and batches"""
|
"""Meta data for requests and batches"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import List
|
from typing import List
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.constrained import RegexGuide
|
||||||
|
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||||
from sglang.srt.managers.controller.radix_cache import RadixCache
|
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
|
||||||
from sglang.srt.constrained import RegexGuide
|
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
|
|||||||
@@ -13,15 +13,15 @@ import zmq
|
|||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
|
from sglang.srt.managers.controller.dp_worker import (
|
||||||
|
DataParallelWorkerThread,
|
||||||
|
start_data_parallel_worker,
|
||||||
|
)
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.controller.dp_worker import (
|
|
||||||
DataParallelWorkerThread,
|
|
||||||
start_data_parallel_worker,
|
|
||||||
)
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
@@ -136,7 +136,7 @@ class Controller:
|
|||||||
self.recv_reqs = []
|
self.recv_reqs = []
|
||||||
if next_step_input:
|
if next_step_input:
|
||||||
await self.dispatching(next_step_input)
|
await self.dispatching(next_step_input)
|
||||||
#else:
|
# else:
|
||||||
# logger.error("There is no live worker.")
|
# logger.error("There is no live worker.")
|
||||||
|
|
||||||
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""A controller that manages a group of tensor parallel workers."""
|
"""A controller that manages a group of tensor parallel workers."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@@ -49,7 +50,9 @@ class ControllerSingle:
|
|||||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
# async sleep for receiving the subsequent request and avoiding cache miss
|
||||||
slept = False
|
slept = False
|
||||||
if len(out_pyobjs) != 0:
|
if len(out_pyobjs) != 0:
|
||||||
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
|
has_finished = any(
|
||||||
|
[obj.finished_reason is not None for obj in out_pyobjs]
|
||||||
|
)
|
||||||
if has_finished:
|
if has_finished:
|
||||||
if self.request_dependency_delay > 0:
|
if self.request_dependency_delay > 0:
|
||||||
slept = True
|
slept = True
|
||||||
@@ -94,4 +97,4 @@ def start_controller_process(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
|
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
|
||||||
finally:
|
finally:
|
||||||
kill_parent_process()
|
kill_parent_process()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
import logging
|
import logging
|
||||||
@@ -12,15 +13,18 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.config import DeviceConfig, LoadConfig
|
from vllm.config import DeviceConfig, LoadConfig
|
||||||
from vllm.config import ModelConfig as VllmModelConfig
|
from vllm.config import ModelConfig as VllmModelConfig
|
||||||
from vllm.distributed import initialize_model_parallel, init_distributed_environment
|
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
|
from sglang.srt.utils import (
|
||||||
|
get_available_gpu_memory,
|
||||||
|
is_multimodal_model,
|
||||||
|
monkey_patch_vllm_p2p_access_check,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger("srt.model_runner")
|
logger = logging.getLogger("srt.model_runner")
|
||||||
|
|
||||||
@@ -441,7 +445,9 @@ def import_model_classes():
|
|||||||
module = importlib.import_module(name)
|
module = importlib.import_module(name)
|
||||||
if hasattr(module, "EntryClass"):
|
if hasattr(module, "EntryClass"):
|
||||||
entry = module.EntryClass
|
entry = module.EntryClass
|
||||||
if isinstance(entry, list): # To support multiple model classes in one module
|
if isinstance(
|
||||||
|
entry, list
|
||||||
|
): # To support multiple model classes in one module
|
||||||
for tmp in entry:
|
for tmp in entry:
|
||||||
model_arch_name_to_cls[tmp.__name__] = tmp
|
model_arch_name_to_cls[tmp.__name__] = tmp
|
||||||
else:
|
else:
|
||||||
@@ -449,7 +455,9 @@ def import_model_classes():
|
|||||||
|
|
||||||
# compat: some models such as chatglm has incorrect class set in config.json
|
# compat: some models such as chatglm has incorrect class set in config.json
|
||||||
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
|
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
|
||||||
if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list):
|
if hasattr(module, "EntryClassRemapping") and isinstance(
|
||||||
|
module.EntryClassRemapping, list
|
||||||
|
):
|
||||||
for remap in module.EntryClassRemapping:
|
for remap in module.EntryClassRemapping:
|
||||||
if isinstance(remap, tuple) and len(remap) == 2:
|
if isinstance(remap, tuple) and len(remap) == 2:
|
||||||
model_arch_name_to_cls[remap[0]] = remap[1]
|
model_arch_name_to_cls[remap[0]] = remap[1]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
The radix tree data structure for managing the KV cache.
|
The radix tree data structure for managing the KV cache.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import heapq
|
import heapq
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Request scheduler heuristic."""
|
"""Request scheduler heuristic."""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|||||||
@@ -15,22 +15,22 @@ from sglang.global_config import global_config
|
|||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import (
|
|
||||||
AbortReq,
|
|
||||||
BatchTokenIDOut,
|
|
||||||
FlushCacheReq,
|
|
||||||
TokenizedGenerateReqInput,
|
|
||||||
)
|
|
||||||
from sglang.srt.managers.controller.infer_batch import (
|
from sglang.srt.managers.controller.infer_batch import (
|
||||||
|
FINISH_ABORT,
|
||||||
BaseFinishReason,
|
BaseFinishReason,
|
||||||
Batch,
|
Batch,
|
||||||
FINISH_ABORT,
|
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
Req,
|
Req,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.controller.model_runner import ModelRunner
|
from sglang.srt.managers.controller.model_runner import ModelRunner
|
||||||
from sglang.srt.managers.controller.radix_cache import RadixCache
|
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||||
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
|
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
|
||||||
|
from sglang.srt.managers.io_struct import (
|
||||||
|
AbortReq,
|
||||||
|
BatchTokenIDOut,
|
||||||
|
FlushCacheReq,
|
||||||
|
TokenizedGenerateReqInput,
|
||||||
|
)
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.server_args import ModelPortArgs, ServerArgs
|
from sglang.srt.server_args import ModelPortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -96,13 +96,13 @@ class ModelTpServer:
|
|||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
)
|
)
|
||||||
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
||||||
self.max_prefill_tokens = max(
|
self.max_prefill_tokens = (
|
||||||
self.model_config.context_len,
|
max(
|
||||||
(
|
self.model_config.context_len,
|
||||||
min(self.max_total_num_tokens // 6, 65536)
|
min(self.max_total_num_tokens // 6, 65536),
|
||||||
if server_args.max_prefill_tokens is None
|
)
|
||||||
else server_args.max_prefill_tokens
|
if server_args.max_prefill_tokens is None
|
||||||
),
|
else server_args.max_prefill_tokens
|
||||||
)
|
)
|
||||||
self.max_running_requests = (
|
self.max_running_requests = (
|
||||||
self.max_total_num_tokens // 2
|
self.max_total_num_tokens // 2
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
@@ -7,10 +8,10 @@ import zmq
|
|||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
|
||||||
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.utils import get_exception_traceback, graceful_registry
|
from sglang.utils import get_exception_traceback, graceful_registry
|
||||||
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
|
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import uuid
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
|
||||||
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
|
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
|
||||||
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""TokenizerManager is a process that tokenizes the text."""
|
"""TokenizerManager is a process that tokenizes the text."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from typing import List, Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import transformers
|
import transformers
|
||||||
@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
|
BatchTokenIDOut,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import BatchTokenIDOut
|
|
||||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
@@ -91,7 +92,7 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.to_create_loop = True
|
self.to_create_loop = True
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
|
||||||
async def get_pixel_values(self, image_data):
|
async def get_pixel_values(self, image_data):
|
||||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||||
@@ -322,7 +323,6 @@ class TokenizerManager:
|
|||||||
state.finished = recv_obj.finished_reason[i] is not None
|
state.finished = recv_obj.finished_reason[i] is not None
|
||||||
state.event.set()
|
state.event.set()
|
||||||
|
|
||||||
|
|
||||||
def convert_logprob_style(
|
def convert_logprob_style(
|
||||||
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -17,8 +18,12 @@ class ModelConfig:
|
|||||||
self.trust_remote_code = trust_remote_code
|
self.trust_remote_code = trust_remote_code
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
self.model_overide_args = model_overide_args
|
self.model_overide_args = model_overide_args
|
||||||
self.hf_config = get_config(self.path, trust_remote_code, revision,
|
self.hf_config = get_config(
|
||||||
model_overide_args=model_overide_args)
|
self.path,
|
||||||
|
trust_remote_code,
|
||||||
|
revision,
|
||||||
|
model_overide_args=model_overide_args,
|
||||||
|
)
|
||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
self.context_len = context_length
|
self.context_len = context_length
|
||||||
@@ -55,18 +60,23 @@ class ModelConfig:
|
|||||||
# KV heads.
|
# KV heads.
|
||||||
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
||||||
new_decoder_arch_falcon = (
|
new_decoder_arch_falcon = (
|
||||||
self.hf_config.model_type in falcon_model_types
|
self.hf_config.model_type in falcon_model_types
|
||||||
and getattr(self.hf_config, "new_decoder_architecture", False))
|
and getattr(self.hf_config, "new_decoder_architecture", False)
|
||||||
if not new_decoder_arch_falcon and getattr(self.hf_text_config,
|
)
|
||||||
"multi_query", False):
|
if not new_decoder_arch_falcon and getattr(
|
||||||
|
self.hf_text_config, "multi_query", False
|
||||||
|
):
|
||||||
# Multi-query attention, only one KV head.
|
# Multi-query attention, only one KV head.
|
||||||
# Currently, tensor parallelism is not supported in this case.
|
# Currently, tensor parallelism is not supported in this case.
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
# For DBRX and MPT
|
# For DBRX and MPT
|
||||||
if self.hf_config.model_type in ["dbrx", "mpt"]:
|
if self.hf_config.model_type in ["dbrx", "mpt"]:
|
||||||
return getattr(self.hf_config.attn_config, "kv_n_heads",
|
return getattr(
|
||||||
self.hf_config.num_attention_heads)
|
self.hf_config.attn_config,
|
||||||
|
"kv_n_heads",
|
||||||
|
self.hf_config.num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
attributes = [
|
attributes = [
|
||||||
# For Falcon:
|
# For Falcon:
|
||||||
@@ -94,13 +104,12 @@ class ModelConfig:
|
|||||||
# the tensor parallel size. We will replicate the KV heads in the
|
# the tensor parallel size. We will replicate the KV heads in the
|
||||||
# case where the number of KV heads is smaller than the tensor
|
# case where the number of KV heads is smaller than the tensor
|
||||||
# parallel size so each GPU has at least one KV head.
|
# parallel size so each GPU has at least one KV head.
|
||||||
return max(1,
|
return max(1, total_num_kv_heads // tensor_parallel_size)
|
||||||
total_num_kv_heads // tensor_parallel_size)
|
|
||||||
|
|
||||||
|
|
||||||
def get_hf_text_config(config: PretrainedConfig):
|
def get_hf_text_config(config: PretrainedConfig):
|
||||||
"""Get the "sub" config relevant to llm for multi modal models.
|
"""Get the "sub" config relevant to llm for multi modal models.
|
||||||
No op for pure text models.
|
No op for pure text models.
|
||||||
"""
|
"""
|
||||||
if hasattr(config, "text_config"):
|
if hasattr(config, "text_config"):
|
||||||
# The code operates under the assumption that text_config should have
|
# The code operates under the assumption that text_config should have
|
||||||
|
|||||||
@@ -5,30 +5,32 @@
|
|||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
|
||||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import LayerNorm
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (
|
||||||
QKVParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
RowParallelLinear)
|
QKVParallelLinear,
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
RowParallelLinear,
|
||||||
QuantizationConfig)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import SamplerOutput
|
from vllm.sequence import SamplerOutput
|
||||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||||
|
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
LoraConfig = None
|
LoraConfig = None
|
||||||
|
|
||||||
@@ -49,9 +51,11 @@ class GLMAttention(nn.Module):
|
|||||||
assert self.total_num_heads % tp_size == 0
|
assert self.total_num_heads % tp_size == 0
|
||||||
self.num_heads = self.total_num_heads // tp_size
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
self.multi_query_attention = config.multi_query_attention
|
self.multi_query_attention = config.multi_query_attention
|
||||||
self.total_num_kv_heads = (config.multi_query_group_num
|
self.total_num_kv_heads = (
|
||||||
if config.multi_query_attention else
|
config.multi_query_group_num
|
||||||
config.num_attention_heads)
|
if config.multi_query_attention
|
||||||
|
else config.num_attention_heads
|
||||||
|
)
|
||||||
if self.total_num_kv_heads >= tp_size:
|
if self.total_num_kv_heads >= tp_size:
|
||||||
# Number of KV heads is greater than TP size, so we partition
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
# the KV heads across multiple tensor parallel GPUs.
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
@@ -91,11 +95,13 @@ class GLMAttention(nn.Module):
|
|||||||
base=10000 * rope_ratio,
|
base=10000 * rope_ratio,
|
||||||
is_neox_style=False,
|
is_neox_style=False,
|
||||||
)
|
)
|
||||||
self.attn = RadixAttention(self.num_heads,
|
self.attn = RadixAttention(
|
||||||
self.head_dim,
|
self.num_heads,
|
||||||
self.scaling,
|
self.head_dim,
|
||||||
num_kv_heads=self.num_kv_heads,
|
self.scaling,
|
||||||
layer_id=layer_id)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
layer_id=layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -176,14 +182,16 @@ class GLMBlock(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.apply_residual_connection_post_layernorm = (
|
self.apply_residual_connection_post_layernorm = (
|
||||||
config.apply_residual_connection_post_layernorm)
|
config.apply_residual_connection_post_layernorm
|
||||||
|
)
|
||||||
|
|
||||||
self.fp32_residual_connection = config.fp32_residual_connection
|
self.fp32_residual_connection = config.fp32_residual_connection
|
||||||
|
|
||||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||||
# Layernorm on the input data.
|
# Layernorm on the input data.
|
||||||
self.input_layernorm = layer_norm_func(config.hidden_size,
|
self.input_layernorm = layer_norm_func(
|
||||||
eps=config.layernorm_epsilon)
|
config.hidden_size, eps=config.layernorm_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
# Self attention.
|
# Self attention.
|
||||||
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
|
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
|
||||||
@@ -191,7 +199,8 @@ class GLMBlock(nn.Module):
|
|||||||
|
|
||||||
# Layernorm on the attention output
|
# Layernorm on the attention output
|
||||||
self.post_attention_layernorm = layer_norm_func(
|
self.post_attention_layernorm = layer_norm_func(
|
||||||
config.hidden_size, eps=config.layernorm_epsilon)
|
config.hidden_size, eps=config.layernorm_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
# MLP
|
# MLP
|
||||||
self.mlp = GLMMLP(config, quant_config)
|
self.mlp = GLMMLP(config, quant_config)
|
||||||
@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module):
|
|||||||
self.num_layers = config.num_layers
|
self.num_layers = config.num_layers
|
||||||
|
|
||||||
# Transformer layers.
|
# Transformer layers.
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList(
|
||||||
GLMBlock(config, i, cache_config, quant_config)
|
[
|
||||||
for i in range(self.num_layers)
|
GLMBlock(config, i, cache_config, quant_config)
|
||||||
])
|
for i in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if self.post_layer_norm:
|
if self.post_layer_norm:
|
||||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||||
# Final layer norm before output.
|
# Final layer norm before output.
|
||||||
self.final_layernorm = layer_norm_func(
|
self.final_layernorm = layer_norm_func(
|
||||||
config.hidden_size, eps=config.layernorm_epsilon)
|
config.hidden_size, eps=config.layernorm_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
|
self.embedding = VocabParallelEmbedding(
|
||||||
config.hidden_size)
|
config.padded_vocab_size, config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
self.num_layers = config.num_layers
|
self.num_layers = config.num_layers
|
||||||
self.multi_query_group_num = config.multi_query_group_num
|
self.multi_query_group_num = config.multi_query_group_num
|
||||||
self.kv_channels = config.kv_channels
|
self.kv_channels = config.kv_channels
|
||||||
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
||||||
|
|
||||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
|
||||||
config.hidden_size)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module):
|
|||||||
class ChatGLMForCausalLM(nn.Module):
|
class ChatGLMForCausalLM(nn.Module):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"query_key_value": ["query_key_value"],
|
"query_key_value": ["query_key_value"],
|
||||||
"dense_h_to_4h": ["dense_h_to_4h"]
|
"dense_h_to_4h": ["dense_h_to_4h"],
|
||||||
}
|
}
|
||||||
# LoRA specific attributes
|
# LoRA specific attributes
|
||||||
supported_lora_modules = [
|
supported_lora_modules = [
|
||||||
@@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config: ChatGLMConfig = config
|
self.config: ChatGLMConfig = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.max_position_embeddings = getattr(config, "max_sequence_length",
|
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
|
||||||
8192)
|
|
||||||
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
||||||
self.lm_head = self.transformer.output_layer
|
self.lm_head = self.transformer.output_layer
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
@@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions,
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
input_metadata)
|
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
@@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = ChatGLMForCausalLM
|
EntryClass = ChatGLMForCausalLM
|
||||||
# compat: glm model.config class == ChatGLMModel
|
# compat: glm model.config class == ChatGLMModel
|
||||||
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
|
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
|
||||||
|
|||||||
@@ -23,7 +23,7 @@
|
|||||||
|
|
||||||
# This file is based on the LLama model definition file in transformers
|
# This file is based on the LLama model definition file in transformers
|
||||||
"""PyTorch Cohere model."""
|
"""PyTorch Cohere model."""
|
||||||
from typing import Optional, Tuple, Iterable
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.config import LoRAConfig, CacheConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import GeluAndMul
|
from vllm.model_executor.layers.activation import GeluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
||||||
"""Inference-only Grok1 model."""
|
"""Inference-only Grok1 model."""
|
||||||
from typing import Iterable, Optional, Tuple, List
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -9,7 +9,6 @@ import torch.nn.functional as F
|
|||||||
import tqdm
|
import tqdm
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
@@ -35,12 +34,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.utils import print_warning_once
|
from vllm.utils import print_warning_once
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
||||||
from sglang.srt.layers.fused_moe import fused_moe
|
from sglang.srt.layers.fused_moe import fused_moe
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
use_fused = True
|
use_fused = True
|
||||||
|
|
||||||
|
|
||||||
@@ -134,9 +132,12 @@ class Grok1MoEUnfused(nn.Module):
|
|||||||
|
|
||||||
final_hidden_states = torch.zeros(
|
final_hidden_states = torch.zeros(
|
||||||
(hidden_states.shape[0], hidden_dim),
|
(hidden_states.shape[0], hidden_dim),
|
||||||
dtype=hidden_states.dtype, device=hidden_states.device
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
)
|
)
|
||||||
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_total_experts).permute(2, 1, 0)
|
expert_mask = torch.nn.functional.one_hot(
|
||||||
|
selected_experts, num_classes=self.num_total_experts
|
||||||
|
).permute(2, 1, 0)
|
||||||
|
|
||||||
for expert_idx in self.expert_indicies:
|
for expert_idx in self.expert_indicies:
|
||||||
expert_layer = self.experts[expert_idx]
|
expert_layer = self.experts[expert_idx]
|
||||||
@@ -153,7 +154,10 @@ class Grok1MoEUnfused(nn.Module):
|
|||||||
# the current expert. We need to make sure to multiply the output hidden
|
# the current expert. We need to make sure to multiply the output hidden
|
||||||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
||||||
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
||||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
|
current_hidden_states = (
|
||||||
|
expert_layer(current_state)
|
||||||
|
* routing_weights[top_x_list, idx_list, None]
|
||||||
|
)
|
||||||
|
|
||||||
# However `index_add_` only support torch tensors for indexing so we'll use
|
# However `index_add_` only support torch tensors for indexing so we'll use
|
||||||
# the `top_x` tensor here.
|
# the `top_x` tensor here.
|
||||||
@@ -198,32 +202,46 @@ class Grok1MoE(nn.Module):
|
|||||||
self.params_dtype = params_dtype
|
self.params_dtype = params_dtype
|
||||||
|
|
||||||
# Gate always runs at half / full precision for now.
|
# Gate always runs at half / full precision for now.
|
||||||
self.gate = ReplicatedLinear(self.hidden_size,
|
self.gate = ReplicatedLinear(
|
||||||
self.num_total_experts,
|
self.hidden_size,
|
||||||
bias=False,
|
self.num_total_experts,
|
||||||
params_dtype=self.params_dtype,
|
bias=False,
|
||||||
quant_config=None)
|
params_dtype=self.params_dtype,
|
||||||
|
quant_config=None,
|
||||||
|
)
|
||||||
|
|
||||||
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = torch.float8_e4m3fn
|
params_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
self.w13_weight = nn.Parameter(
|
self.w13_weight = nn.Parameter(
|
||||||
torch.empty(self.num_total_experts,
|
torch.empty(
|
||||||
2 * self.intermediate_size,
|
self.num_total_experts,
|
||||||
self.hidden_size,
|
2 * self.intermediate_size,
|
||||||
dtype=params_dtype))
|
self.hidden_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
)
|
||||||
|
)
|
||||||
self.w2_weight = nn.Parameter(
|
self.w2_weight = nn.Parameter(
|
||||||
torch.empty(self.num_total_experts,
|
torch.empty(
|
||||||
self.hidden_size,
|
self.num_total_experts,
|
||||||
self.intermediate_size,
|
self.hidden_size,
|
||||||
dtype=params_dtype))
|
self.intermediate_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
set_weight_attrs(self.w13_weight, {
|
set_weight_attrs(
|
||||||
"weight_loader": self.weight_loader,
|
self.w13_weight,
|
||||||
})
|
{
|
||||||
set_weight_attrs(self.w2_weight, {
|
"weight_loader": self.weight_loader,
|
||||||
"weight_loader": self.weight_loader,
|
},
|
||||||
})
|
)
|
||||||
|
set_weight_attrs(
|
||||||
|
self.w2_weight,
|
||||||
|
{
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Used for fp8.
|
# Used for fp8.
|
||||||
self.w13_scale = None
|
self.w13_scale = None
|
||||||
@@ -233,46 +251,69 @@ class Grok1MoE(nn.Module):
|
|||||||
|
|
||||||
if self.use_fp8:
|
if self.use_fp8:
|
||||||
# WEIGHT_SCALE (for fp8)
|
# WEIGHT_SCALE (for fp8)
|
||||||
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
self.w13_scale = nn.Parameter(
|
||||||
dtype=torch.float32),
|
torch.ones(self.num_total_experts, dtype=torch.float32),
|
||||||
requires_grad=False)
|
requires_grad=False,
|
||||||
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
)
|
||||||
dtype=torch.float32),
|
self.w2_scale = nn.Parameter(
|
||||||
requires_grad=False)
|
torch.ones(self.num_total_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
# If loading fp8 checkpoint, pass the weight loaders.
|
# If loading fp8 checkpoint, pass the weight loaders.
|
||||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||||
# process_weights_after_loading()
|
# process_weights_after_loading()
|
||||||
if quant_config.is_checkpoint_fp8_serialized:
|
if quant_config.is_checkpoint_fp8_serialized:
|
||||||
set_weight_attrs(self.w13_scale, {
|
set_weight_attrs(
|
||||||
"weight_loader": self.weight_loader,
|
self.w13_scale,
|
||||||
})
|
{
|
||||||
set_weight_attrs(self.w2_scale, {
|
"weight_loader": self.weight_loader,
|
||||||
"weight_loader": self.weight_loader,
|
},
|
||||||
})
|
)
|
||||||
|
set_weight_attrs(
|
||||||
|
self.w2_scale,
|
||||||
|
{
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# ACT_SCALE (for fp8)
|
# ACT_SCALE (for fp8)
|
||||||
if quant_config.activation_scheme == "static":
|
if quant_config.activation_scheme == "static":
|
||||||
if not quant_config.is_checkpoint_fp8_serialized:
|
if not quant_config.is_checkpoint_fp8_serialized:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Found static activation scheme for checkpoint that "
|
"Found static activation scheme for checkpoint that "
|
||||||
"was not serialized fp8.")
|
"was not serialized fp8."
|
||||||
self.a13_scale = nn.Parameter(torch.zeros(
|
)
|
||||||
self.num_total_experts, dtype=torch.float32),
|
self.a13_scale = nn.Parameter(
|
||||||
requires_grad=False)
|
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
||||||
self.a2_scale = nn.Parameter(torch.zeros(
|
requires_grad=False,
|
||||||
self.num_total_experts, dtype=torch.float32),
|
)
|
||||||
requires_grad=False)
|
self.a2_scale = nn.Parameter(
|
||||||
|
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
set_weight_attrs(self.a13_scale, {
|
set_weight_attrs(
|
||||||
"weight_loader": self.weight_loader,
|
self.a13_scale,
|
||||||
})
|
{
|
||||||
set_weight_attrs(self.a2_scale, {
|
"weight_loader": self.weight_loader,
|
||||||
"weight_loader": self.weight_loader,
|
},
|
||||||
})
|
)
|
||||||
|
set_weight_attrs(
|
||||||
|
self.a2_scale,
|
||||||
|
{
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
def weight_loader(
|
||||||
weight_name: str, expert_id: int, pre_sharded: bool):
|
self,
|
||||||
|
param: nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
expert_id: int,
|
||||||
|
pre_sharded: bool,
|
||||||
|
):
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
shard_size = self.intermediate_size
|
shard_size = self.intermediate_size
|
||||||
if pre_sharded:
|
if pre_sharded:
|
||||||
@@ -284,8 +325,9 @@ class Grok1MoE(nn.Module):
|
|||||||
if weight_name.endswith("w1.weight"):
|
if weight_name.endswith("w1.weight"):
|
||||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||||
if weight_name.endswith("w3.weight"):
|
if weight_name.endswith("w3.weight"):
|
||||||
param_data[expert_id,
|
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
|
||||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
shard, :
|
||||||
|
]
|
||||||
if weight_name.endswith("w2.weight"):
|
if weight_name.endswith("w2.weight"):
|
||||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||||
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
||||||
@@ -298,17 +340,17 @@ class Grok1MoE(nn.Module):
|
|||||||
|
|
||||||
# If checkpoint is fp16, quantize here.
|
# If checkpoint is fp16, quantize here.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
w13_weight = torch.empty_like(self.w13_weight.data,
|
w13_weight = torch.empty_like(
|
||||||
dtype=torch.float8_e4m3fn)
|
self.w13_weight.data, dtype=torch.float8_e4m3fn
|
||||||
w2_weight = torch.empty_like(self.w2_weight.data,
|
)
|
||||||
dtype=torch.float8_e4m3fn)
|
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
|
||||||
for expert in range(self.num_total_experts):
|
for expert in range(self.num_total_experts):
|
||||||
w13_weight[expert, :, :], self.w13_scale[
|
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
|
||||||
expert] = ops.scaled_fp8_quant(
|
self.w13_weight.data[expert, :, :]
|
||||||
self.w13_weight.data[expert, :, :])
|
)
|
||||||
w2_weight[expert, :, :], self.w2_scale[
|
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
|
||||||
expert] = ops.scaled_fp8_quant(
|
self.w2_weight.data[expert, :, :]
|
||||||
self.w2_weight.data[expert, :, :])
|
)
|
||||||
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
||||||
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
|
||||||
@@ -319,40 +361,40 @@ class Grok1MoE(nn.Module):
|
|||||||
if self.a13_scale is None or self.a2_scale is None:
|
if self.a13_scale is None or self.a2_scale is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"QuantConfig has static quantization, but found "
|
"QuantConfig has static quantization, but found "
|
||||||
"activation scales are None.")
|
"activation scales are None."
|
||||||
|
)
|
||||||
|
|
||||||
if (not all_close_1d(self.a13_scale)
|
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
|
||||||
or not all_close_1d(self.a2_scale)):
|
|
||||||
print_warning_once(
|
print_warning_once(
|
||||||
"Found act_scales that are not equal for fp8 MoE layer. "
|
"Found act_scales that are not equal for fp8 MoE layer. "
|
||||||
"Using the maximum across experts for each layer. ")
|
"Using the maximum across experts for each layer. "
|
||||||
|
)
|
||||||
|
|
||||||
self.a13_scale = nn.Parameter(self.a13_scale.max(),
|
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
|
||||||
requires_grad=False)
|
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
|
||||||
self.a2_scale = nn.Parameter(self.a2_scale.max(),
|
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_size = hidden_states.shape
|
num_tokens, hidden_size = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = fused_moe(hidden_states,
|
final_hidden_states = fused_moe(
|
||||||
self.w13_weight,
|
hidden_states,
|
||||||
self.w2_weight,
|
self.w13_weight,
|
||||||
router_logits,
|
self.w2_weight,
|
||||||
self.top_k,
|
router_logits,
|
||||||
renormalize=False,
|
self.top_k,
|
||||||
inplace=True,
|
renormalize=False,
|
||||||
use_fp8=self.use_fp8,
|
inplace=True,
|
||||||
w1_scale=self.w13_scale,
|
use_fp8=self.use_fp8,
|
||||||
w2_scale=self.w2_scale,
|
w1_scale=self.w13_scale,
|
||||||
a1_scale=self.a13_scale,
|
w2_scale=self.w2_scale,
|
||||||
a2_scale=self.a2_scale)
|
a1_scale=self.a13_scale,
|
||||||
|
a2_scale=self.a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
final_hidden_states)
|
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_size)
|
return final_hidden_states.view(num_tokens, hidden_size)
|
||||||
|
|
||||||
@@ -462,10 +504,12 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.block_sparse_moe = Grok1MoEUnfused(
|
self.block_sparse_moe = Grok1MoEUnfused(
|
||||||
config=config, quant_config=quant_config)
|
config=config, quant_config=quant_config
|
||||||
|
)
|
||||||
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@@ -478,12 +522,21 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
hidden_states = self.post_attn_norm(self.self_attn(
|
hidden_states = (
|
||||||
positions=positions, hidden_states=self.pre_attn_norm(hidden_states),
|
self.post_attn_norm(
|
||||||
input_metadata=input_metadata,
|
self.self_attn(
|
||||||
)) + hidden_states
|
positions=positions,
|
||||||
|
hidden_states=self.pre_attn_norm(hidden_states),
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
+ hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states
|
hidden_states = (
|
||||||
|
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
|
||||||
|
+ hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -525,9 +578,7 @@ class Grok1Model(nn.Module):
|
|||||||
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
||||||
|
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
hidden_states = self.layers[i](
|
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
||||||
positions, hidden_states, input_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
hidden_states.mul_(self.config.output_multiplier_scale)
|
hidden_states.mul_(self.config.output_multiplier_scale)
|
||||||
@@ -572,28 +623,41 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
if use_fused:
|
if use_fused:
|
||||||
expert_params_mapping = [
|
expert_params_mapping = (
|
||||||
# These are the weight scales for the experts
|
[
|
||||||
# (param_name, weight_name, expert_id)
|
# These are the weight scales for the experts
|
||||||
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
# (param_name, weight_name, expert_id)
|
||||||
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
|
(
|
||||||
for expert_id in range(self.config.num_local_experts)
|
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
f"experts.{expert_id}.{weight_name}.weight_scale",
|
||||||
] + [
|
expert_id,
|
||||||
# These are the weights for the experts
|
)
|
||||||
# (param_name, weight_name, expert_id)
|
for expert_id in range(self.config.num_local_experts)
|
||||||
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
]
|
||||||
for expert_id in range(self.config.num_local_experts)
|
+ [
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
# These are the weights for the experts
|
||||||
] + [
|
# (param_name, weight_name, expert_id)
|
||||||
# These are the activation scales for the experts
|
(
|
||||||
# (param_name, weight_name, expert_id)
|
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||||
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
f"experts.{expert_id}.{weight_name}.weight",
|
||||||
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
|
expert_id,
|
||||||
for expert_id in range(self.config.num_local_experts)
|
)
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
for expert_id in range(self.config.num_local_experts)
|
||||||
]
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
# These are the activation scales for the experts
|
||||||
|
# (param_name, weight_name, expert_id)
|
||||||
|
(
|
||||||
|
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
||||||
|
f"experts.{expert_id}.{weight_name}.act_scale",
|
||||||
|
expert_id,
|
||||||
|
)
|
||||||
|
for expert_id in range(self.config.num_local_experts)
|
||||||
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
expert_params_mapping = []
|
expert_params_mapping = []
|
||||||
|
|
||||||
@@ -601,11 +665,11 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
if get_tensor_model_parallel_rank() == 0:
|
if get_tensor_model_parallel_rank() == 0:
|
||||||
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
|
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
#print(get_tensor_model_parallel_rank(), name)
|
# print(get_tensor_model_parallel_rank(), name)
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
@@ -623,19 +687,22 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param,
|
weight_loader(
|
||||||
loaded_weight,
|
param,
|
||||||
weight_name,
|
loaded_weight,
|
||||||
expert_id=expert_id,
|
weight_name,
|
||||||
pre_sharded=get_tensor_model_parallel_world_size() > 1)
|
expert_id=expert_id,
|
||||||
|
pre_sharded=get_tensor_model_parallel_world_size() > 1,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(
|
||||||
default_weight_loader)
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
@@ -645,10 +712,11 @@ def all_close_1d(x: torch.Tensor) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
|
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
|
||||||
def _prepare_presharded_weights(self,
|
|
||||||
model_name_or_path: str,
|
|
||||||
revision: Optional[str],
|
def _prepare_presharded_weights(
|
||||||
fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
|
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
|
||||||
|
) -> Tuple[str, List[str], bool]:
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -668,4 +736,4 @@ def _prepare_presharded_weights(self,
|
|||||||
return hf_folder, hf_weights_files, use_safetensors
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
|
|
||||||
EntryClass = Grok1ModelForCausalLM
|
EntryClass = Grok1ModelForCausalLM
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
|
||||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||||
from typing import Any, Dict, Optional, Tuple, Iterable
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
@@ -10,7 +10,7 @@ from transformers import LlamaConfig
|
|||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
if rope_scaling is not None and getattr(
|
if rope_scaling is not None and getattr(
|
||||||
config, "original_max_position_embeddings", None):
|
config, "original_max_position_embeddings", None
|
||||||
|
):
|
||||||
rope_scaling["original_max_position_embeddings"] = (
|
rope_scaling["original_max_position_embeddings"] = (
|
||||||
config.original_max_position_embeddings)
|
config.original_max_position_embeddings
|
||||||
|
)
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
self.self_attn = LlamaAttention(
|
self.self_attn = LlamaAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
|
|||||||
@@ -1,11 +1,17 @@
|
|||||||
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
from typing import List, Iterable, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
|
from transformers import (
|
||||||
|
CLIPVisionConfig,
|
||||||
|
CLIPVisionModel,
|
||||||
|
LlavaConfig,
|
||||||
|
MistralConfig,
|
||||||
|
Qwen2Config,
|
||||||
|
)
|
||||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
|
|||||||
unpad_image_shape,
|
unpad_image_shape,
|
||||||
)
|
)
|
||||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
|
||||||
from sglang.srt.models.mistral import MistralForCausalLM
|
from sglang.srt.models.mistral import MistralForCausalLM
|
||||||
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
|
|
||||||
|
|
||||||
class LlavaLlamaForCausalLM(nn.Module):
|
class LlavaLlamaForCausalLM(nn.Module):
|
||||||
@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
|
|||||||
|
|
||||||
first_call = True
|
first_call = True
|
||||||
|
|
||||||
|
|
||||||
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||||
batch_size = pixel_values.shape[0]
|
batch_size = pixel_values.shape[0]
|
||||||
|
|
||||||
@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = [
|
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
||||||
LlavaLlamaForCausalLM,
|
|
||||||
LlavaQwenForCausalLM,
|
|
||||||
LlavaMistralForCausalLM
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
|
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
from typing import List, Iterable, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.utils import print_warning_once
|
from vllm.utils import print_warning_once
|
||||||
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralMoE(nn.Module):
|
class MixtralMoE(nn.Module):
|
||||||
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
|
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
|
||||||
across all ranks.
|
across all ranks.
|
||||||
@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
|
|||||||
self.params_dtype = params_dtype
|
self.params_dtype = params_dtype
|
||||||
|
|
||||||
# Gate always runs at half / full precision for now.
|
# Gate always runs at half / full precision for now.
|
||||||
self.gate = ReplicatedLinear(self.hidden_size,
|
self.gate = ReplicatedLinear(
|
||||||
self.num_total_experts,
|
self.hidden_size,
|
||||||
bias=False,
|
self.num_total_experts,
|
||||||
params_dtype=self.params_dtype,
|
bias=False,
|
||||||
quant_config=None)
|
params_dtype=self.params_dtype,
|
||||||
|
quant_config=None,
|
||||||
|
)
|
||||||
|
|
||||||
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = torch.float8_e4m3fn
|
params_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
self.w13_weight = nn.Parameter(
|
self.w13_weight = nn.Parameter(
|
||||||
torch.empty(self.num_total_experts,
|
torch.empty(
|
||||||
2 * self.intermediate_size,
|
self.num_total_experts,
|
||||||
self.hidden_size,
|
2 * self.intermediate_size,
|
||||||
dtype=params_dtype))
|
self.hidden_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
)
|
||||||
|
)
|
||||||
self.w2_weight = nn.Parameter(
|
self.w2_weight = nn.Parameter(
|
||||||
torch.empty(self.num_total_experts,
|
torch.empty(
|
||||||
self.hidden_size,
|
self.num_total_experts,
|
||||||
self.intermediate_size,
|
self.hidden_size,
|
||||||
dtype=params_dtype))
|
self.intermediate_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
set_weight_attrs(self.w13_weight, {
|
set_weight_attrs(
|
||||||
"weight_loader": self.weight_loader,
|
self.w13_weight,
|
||||||
})
|
{
|
||||||
set_weight_attrs(self.w2_weight, {
|
"weight_loader": self.weight_loader,
|
||||||
"weight_loader": self.weight_loader,
|
},
|
||||||
})
|
)
|
||||||
|
set_weight_attrs(
|
||||||
|
self.w2_weight,
|
||||||
|
{
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Used for fp8.
|
# Used for fp8.
|
||||||
self.w13_scale = None
|
self.w13_scale = None
|
||||||
@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
|
|||||||
|
|
||||||
if self.use_fp8:
|
if self.use_fp8:
|
||||||
# WEIGHT_SCALE (for fp8)
|
# WEIGHT_SCALE (for fp8)
|
||||||
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
self.w13_scale = nn.Parameter(
|
||||||
dtype=torch.float32),
|
torch.ones(self.num_total_experts, dtype=torch.float32),
|
||||||
requires_grad=False)
|
requires_grad=False,
|
||||||
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
)
|
||||||
dtype=torch.float32),
|
self.w2_scale = nn.Parameter(
|
||||||
requires_grad=False)
|
torch.ones(self.num_total_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
# If loading fp8 checkpoint, pass the weight loaders.
|
# If loading fp8 checkpoint, pass the weight loaders.
|
||||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||||
# process_weights_after_loading()
|
# process_weights_after_loading()
|
||||||
if quant_config.is_checkpoint_fp8_serialized:
|
if quant_config.is_checkpoint_fp8_serialized:
|
||||||
set_weight_attrs(self.w13_scale, {
|
set_weight_attrs(
|
||||||
"weight_loader": self.weight_loader,
|
self.w13_scale,
|
||||||
})
|
{
|
||||||
set_weight_attrs(self.w2_scale, {
|
"weight_loader": self.weight_loader,
|
||||||
"weight_loader": self.weight_loader,
|
},
|
||||||
})
|
)
|
||||||
|
set_weight_attrs(
|
||||||
|
self.w2_scale,
|
||||||
|
{
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# ACT_SCALE (for fp8)
|
# ACT_SCALE (for fp8)
|
||||||
if quant_config.activation_scheme == "static":
|
if quant_config.activation_scheme == "static":
|
||||||
if not quant_config.is_checkpoint_fp8_serialized:
|
if not quant_config.is_checkpoint_fp8_serialized:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Found static activation scheme for checkpoint that "
|
"Found static activation scheme for checkpoint that "
|
||||||
"was not serialized fp8.")
|
"was not serialized fp8."
|
||||||
self.a13_scale = nn.Parameter(torch.zeros(
|
)
|
||||||
self.num_total_experts, dtype=torch.float32),
|
self.a13_scale = nn.Parameter(
|
||||||
requires_grad=False)
|
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
||||||
self.a2_scale = nn.Parameter(torch.zeros(
|
requires_grad=False,
|
||||||
self.num_total_experts, dtype=torch.float32),
|
)
|
||||||
requires_grad=False)
|
self.a2_scale = nn.Parameter(
|
||||||
|
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
set_weight_attrs(self.a13_scale, {
|
set_weight_attrs(
|
||||||
"weight_loader": self.weight_loader,
|
self.a13_scale,
|
||||||
})
|
{
|
||||||
set_weight_attrs(self.a2_scale, {
|
"weight_loader": self.weight_loader,
|
||||||
"weight_loader": self.weight_loader,
|
},
|
||||||
})
|
)
|
||||||
|
set_weight_attrs(
|
||||||
|
self.a2_scale,
|
||||||
|
{
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
def weight_loader(
|
||||||
weight_name: str, expert_id: int):
|
self,
|
||||||
|
param: nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
expert_id: int,
|
||||||
|
):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
shard_size = self.intermediate_size
|
shard_size = self.intermediate_size
|
||||||
@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
|
|||||||
if weight_name.endswith("w1.weight"):
|
if weight_name.endswith("w1.weight"):
|
||||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||||
if weight_name.endswith("w3.weight"):
|
if weight_name.endswith("w3.weight"):
|
||||||
param_data[expert_id,
|
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
|
||||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
shard, :
|
||||||
|
]
|
||||||
if weight_name.endswith("w2.weight"):
|
if weight_name.endswith("w2.weight"):
|
||||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||||
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
||||||
@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
|
|||||||
|
|
||||||
# If checkpoint is fp16, quantize here.
|
# If checkpoint is fp16, quantize here.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
w13_weight = torch.empty_like(self.w13_weight.data,
|
w13_weight = torch.empty_like(
|
||||||
dtype=torch.float8_e4m3fn)
|
self.w13_weight.data, dtype=torch.float8_e4m3fn
|
||||||
w2_weight = torch.empty_like(self.w2_weight.data,
|
)
|
||||||
dtype=torch.float8_e4m3fn)
|
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
|
||||||
for expert in range(self.num_total_experts):
|
for expert in range(self.num_total_experts):
|
||||||
w13_weight[expert, :, :], self.w13_scale[
|
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
|
||||||
expert] = ops.scaled_fp8_quant(
|
self.w13_weight.data[expert, :, :]
|
||||||
self.w13_weight.data[expert, :, :])
|
)
|
||||||
w2_weight[expert, :, :], self.w2_scale[
|
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
|
||||||
expert] = ops.scaled_fp8_quant(
|
self.w2_weight.data[expert, :, :]
|
||||||
self.w2_weight.data[expert, :, :])
|
)
|
||||||
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
||||||
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
|
||||||
@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
|
|||||||
if self.a13_scale is None or self.a2_scale is None:
|
if self.a13_scale is None or self.a2_scale is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"QuantConfig has static quantization, but found "
|
"QuantConfig has static quantization, but found "
|
||||||
"activation scales are None.")
|
"activation scales are None."
|
||||||
|
)
|
||||||
|
|
||||||
if (not all_close_1d(self.a13_scale)
|
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
|
||||||
or not all_close_1d(self.a2_scale)):
|
|
||||||
print_warning_once(
|
print_warning_once(
|
||||||
"Found act_scales that are not equal for fp8 MoE layer. "
|
"Found act_scales that are not equal for fp8 MoE layer. "
|
||||||
"Using the maximum across experts for each layer. ")
|
"Using the maximum across experts for each layer. "
|
||||||
|
)
|
||||||
|
|
||||||
self.a13_scale = nn.Parameter(self.a13_scale.max(),
|
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
|
||||||
requires_grad=False)
|
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
|
||||||
self.a2_scale = nn.Parameter(self.a2_scale.max(),
|
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_size = hidden_states.shape
|
num_tokens, hidden_size = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = fused_moe(hidden_states,
|
final_hidden_states = fused_moe(
|
||||||
self.w13_weight,
|
hidden_states,
|
||||||
self.w2_weight,
|
self.w13_weight,
|
||||||
router_logits,
|
self.w2_weight,
|
||||||
self.top_k,
|
router_logits,
|
||||||
renormalize=True,
|
self.top_k,
|
||||||
inplace=True,
|
renormalize=True,
|
||||||
use_fp8=self.use_fp8,
|
inplace=True,
|
||||||
w1_scale=self.w13_scale,
|
use_fp8=self.use_fp8,
|
||||||
w2_scale=self.w2_scale,
|
w1_scale=self.w13_scale,
|
||||||
a1_scale=self.a13_scale,
|
w2_scale=self.w2_scale,
|
||||||
a2_scale=self.a2_scale)
|
a1_scale=self.a13_scale,
|
||||||
|
a2_scale=self.a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
final_hidden_states)
|
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_size)
|
return final_hidden_states.view(num_tokens, hidden_size)
|
||||||
|
|
||||||
@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(
|
self.post_attention_layernorm = RMSNorm(
|
||||||
config.hidden_size, eps=config.rms_norm_eps
|
config.hidden_size, eps=config.rms_norm_eps
|
||||||
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
]
|
]
|
||||||
|
|
||||||
expert_params_mapping = [
|
expert_params_mapping = (
|
||||||
# These are the weight scales for the experts
|
[
|
||||||
# (param_name, weight_name, expert_id)
|
# These are the weight scales for the experts
|
||||||
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
# (param_name, weight_name, expert_id)
|
||||||
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
|
(
|
||||||
for expert_id in range(self.config.num_local_experts)
|
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
f"experts.{expert_id}.{weight_name}.weight_scale",
|
||||||
] + [
|
expert_id,
|
||||||
# These are the weights for the experts
|
)
|
||||||
# (param_name, weight_name, expert_id)
|
for expert_id in range(self.config.num_local_experts)
|
||||||
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
]
|
||||||
for expert_id in range(self.config.num_local_experts)
|
+ [
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
# These are the weights for the experts
|
||||||
] + [
|
# (param_name, weight_name, expert_id)
|
||||||
# These are the activation scales for the experts
|
(
|
||||||
# (param_name, weight_name, expert_id)
|
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||||
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
f"experts.{expert_id}.{weight_name}.weight",
|
||||||
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
|
expert_id,
|
||||||
for expert_id in range(self.config.num_local_experts)
|
)
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
for expert_id in range(self.config.num_local_experts)
|
||||||
]
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
# These are the activation scales for the experts
|
||||||
|
# (param_name, weight_name, expert_id)
|
||||||
|
(
|
||||||
|
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
||||||
|
f"experts.{expert_id}.{weight_name}.act_scale",
|
||||||
|
expert_id,
|
||||||
|
)
|
||||||
|
for expert_id in range(self.config.num_local_experts)
|
||||||
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param,
|
weight_loader(
|
||||||
loaded_weight,
|
param, loaded_weight, weight_name, expert_id=expert_id
|
||||||
weight_name,
|
)
|
||||||
expert_id=expert_id)
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(
|
||||||
default_weight_loader)
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
|
||||||
from typing import Any, Dict, Optional, Iterable, Tuple
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Adapted from llama2.py
|
# Adapted from llama2.py
|
||||||
# Modify details for the adaptation of Qwen2 model.
|
# Modify details for the adaptation of Qwen2 model.
|
||||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||||
from typing import Any, Dict, Optional, Tuple, Iterable
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
|
||||||
"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
|
"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
|
||||||
model compatible with HuggingFace weights."""
|
model compatible with HuggingFace weights."""
|
||||||
from typing import Optional, Tuple, Iterable
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
"""Inference-only Yi-VL model."""
|
"""Inference-only Yi-VL model."""
|
||||||
|
|
||||||
from typing import Tuple, Iterable, Optional
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import CLIPVisionModel, LlavaConfig
|
from transformers import CLIPVisionModel, LlavaConfig
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
|
||||||
from sglang.srt.models.llava import (
|
from sglang.srt.models.llava import (
|
||||||
LlavaLlamaForCausalLM,
|
LlavaLlamaForCausalLM,
|
||||||
monkey_path_clip_vision_embed_forward,
|
monkey_path_clip_vision_embed_forward,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.conversation import (
|
from sglang.srt.conversation import (
|
||||||
Conversation,
|
Conversation,
|
||||||
@@ -40,21 +40,18 @@ chat_template_name = None
|
|||||||
def create_error_response(
|
def create_error_response(
|
||||||
message: str,
|
message: str,
|
||||||
err_type: str = "BadRequestError",
|
err_type: str = "BadRequestError",
|
||||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
|
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||||
error = ErrorResponse(message=message,
|
):
|
||||||
type=err_type,
|
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
|
||||||
code=status_code.value)
|
return JSONResponse(content=error.model_dump(), status_code=error.code)
|
||||||
return JSONResponse(content=error.model_dump(),
|
|
||||||
status_code=error.code)
|
|
||||||
|
|
||||||
|
|
||||||
def create_streaming_error_response(
|
def create_streaming_error_response(
|
||||||
message: str,
|
message: str,
|
||||||
err_type: str = "BadRequestError",
|
err_type: str = "BadRequestError",
|
||||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
|
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||||
error = ErrorResponse(message=message,
|
) -> str:
|
||||||
type=err_type,
|
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
|
||||||
code=status_code.value)
|
|
||||||
json_str = json.dumps({"error": error.model_dump()})
|
json_str = json.dumps({"error": error.model_dump()})
|
||||||
return json_str
|
return json_str
|
||||||
|
|
||||||
@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
n_prev_token = 0
|
n_prev_token = 0
|
||||||
try:
|
try:
|
||||||
async for content in tokenizer_manager.generate_request(
|
async for content in tokenizer_manager.generate_request(
|
||||||
adapted_request, raw_request):
|
adapted_request, raw_request
|
||||||
|
):
|
||||||
text = content["text"]
|
text = content["text"]
|
||||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||||
@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
decode_token_logprobs=content["meta_info"][
|
decode_token_logprobs=content["meta_info"][
|
||||||
"decode_token_logprobs"
|
"decode_token_logprobs"
|
||||||
][n_prev_token:],
|
][n_prev_token:],
|
||||||
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
|
decode_top_logprobs=content["meta_info"][
|
||||||
n_prev_token:
|
"decode_top_logprobs"
|
||||||
],
|
][n_prev_token:],
|
||||||
)
|
)
|
||||||
|
|
||||||
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
|
n_prev_token = len(
|
||||||
|
content["meta_info"]["decode_token_logprobs"]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
yield f"data: {error}\n\n"
|
yield f"data: {error}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
|
return StreamingResponse(
|
||||||
background=tokenizer_manager.create_abort_task(adapted_request))
|
generate_stream_resp(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
background=tokenizer_manager.create_abort_task(adapted_request),
|
||||||
|
)
|
||||||
|
|
||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.generate_request(
|
ret = await tokenizer_manager.generate_request(
|
||||||
adapted_request, raw_request).__anext__()
|
adapted_request, raw_request
|
||||||
|
).__anext__()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(str(e))
|
return create_error_response(str(e))
|
||||||
|
|
||||||
@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
|
|
||||||
stream_buffer = ""
|
stream_buffer = ""
|
||||||
try:
|
try:
|
||||||
async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
|
async for content in tokenizer_manager.generate_request(
|
||||||
|
adapted_request, raw_request
|
||||||
|
):
|
||||||
if is_first:
|
if is_first:
|
||||||
# First chunk with role
|
# First chunk with role
|
||||||
is_first = False
|
is_first = False
|
||||||
@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
yield f"data: {error}\n\n"
|
yield f"data: {error}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
|
return StreamingResponse(
|
||||||
background=tokenizer_manager.create_abort_task(adapted_request))
|
generate_stream_resp(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
background=tokenizer_manager.create_abort_task(adapted_request),
|
||||||
|
)
|
||||||
|
|
||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.generate_request(
|
ret = await tokenizer_manager.generate_request(
|
||||||
adapted_request, raw_request).__anext__()
|
adapted_request, raw_request
|
||||||
|
).__anext__()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(str(e))
|
return create_error_response(str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Optional, Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
# Fix a bug of Python threading
|
# Fix a bug of Python threading
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
@@ -29,10 +29,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.srt.constrained import disable_cache
|
from sglang.srt.constrained import disable_cache
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.managers.controller.manager_multi import (
|
||||||
|
start_controller_process as start_controller_process_multi,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.controller.manager_single import (
|
||||||
|
start_controller_process as start_controller_process_single,
|
||||||
|
)
|
||||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
|
|
||||||
from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
|
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.openai_api_adapter import (
|
from sglang.srt.openai_api_adapter import (
|
||||||
load_chat_template_for_openai_api,
|
load_chat_template_for_openai_api,
|
||||||
@@ -97,8 +101,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(stream_results(), media_type="text/event-stream",
|
return StreamingResponse(
|
||||||
background=tokenizer_manager.create_abort_task(obj))
|
stream_results(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
background=tokenizer_manager.create_abort_task(obj),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Common utilities."""
|
"""Common utilities."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import multiprocessing
|
|
||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
@@ -17,12 +17,11 @@ import requests
|
|||||||
import rpyc
|
import rpyc
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from rpyc.utils.server import ThreadedServer
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
|
from rpyc.utils.server import ThreadedServer
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int):
|
|||||||
protocol_config={
|
protocol_config={
|
||||||
"allow_public_attrs": True,
|
"allow_public_attrs": True,
|
||||||
"allow_pickle": True,
|
"allow_pickle": True,
|
||||||
"sync_request_timeout": 3600
|
"sync_request_timeout": 3600,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
t.logger.setLevel(logging.WARN)
|
t.logger.setLevel(logging.WARN)
|
||||||
@@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"):
|
|||||||
config={
|
config={
|
||||||
"allow_public_attrs": True,
|
"allow_public_attrs": True,
|
||||||
"allow_pickle": True,
|
"allow_pickle": True,
|
||||||
"sync_request_timeout": 3600
|
"sync_request_timeout": 3600,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
@@ -423,7 +422,9 @@ def suppress_other_loggers():
|
|||||||
|
|
||||||
vllm_default_logger.setLevel(logging.WARN)
|
vllm_default_logger.setLevel(logging.WARN)
|
||||||
logging.getLogger("vllm.config").setLevel(logging.ERROR)
|
logging.getLogger("vllm.config").setLevel(logging.ERROR)
|
||||||
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN)
|
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
|
||||||
|
logging.WARN
|
||||||
|
)
|
||||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||||
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
||||||
|
|
||||||
@@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
|||||||
device_name = torch.cuda.get_device_name(gpu_id)
|
device_name = torch.cuda.get_device_name(gpu_id)
|
||||||
if "RTX 40" not in device_name:
|
if "RTX 40" not in device_name:
|
||||||
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
|
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
|
||||||
|
|
||||||
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
||||||
|
|
||||||
|
|
||||||
@@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@@ -356,16 +356,25 @@ def test_completion_speculative():
|
|||||||
s += "Construct a character within the following format:\n"
|
s += "Construct a character within the following format:\n"
|
||||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||||
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
|
s += (
|
||||||
|
"Name:"
|
||||||
|
+ sgl.gen("name", stop="\n")
|
||||||
|
+ "\nBirthday:"
|
||||||
|
+ sgl.gen("birthday", stop="\n")
|
||||||
|
)
|
||||||
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
|
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def gen_character_no_spec(s):
|
def gen_character_no_spec(s):
|
||||||
s += "Construct a character within the following format:\n"
|
s += "Construct a character within the following format:\n"
|
||||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||||
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
|
s += (
|
||||||
|
"Name:"
|
||||||
|
+ sgl.gen("name", stop="\n")
|
||||||
|
+ "\nBirthday:"
|
||||||
|
+ sgl.gen("birthday", stop="\n")
|
||||||
|
)
|
||||||
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
|
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
|
||||||
|
|
||||||
token_usage = sgl.global_config.default_backend.token_usage
|
token_usage = sgl.global_config.default_backend.token_usage
|
||||||
@@ -378,7 +387,9 @@ def test_completion_speculative():
|
|||||||
gen_character_no_spec().sync()
|
gen_character_no_spec().sync()
|
||||||
usage_with_no_spec = token_usage.prompt_tokens
|
usage_with_no_spec = token_usage.prompt_tokens
|
||||||
|
|
||||||
assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
|
assert (
|
||||||
|
usage_with_spec < usage_with_no_spec
|
||||||
|
), f"{usage_with_spec} vs {usage_with_no_spec}"
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_speculative():
|
def test_chat_completion_speculative():
|
||||||
@@ -386,8 +397,17 @@ def test_chat_completion_speculative():
|
|||||||
def gen_character_spec(s):
|
def gen_character_spec(s):
|
||||||
s += sgl.system("You are a helpful assistant.")
|
s += sgl.system("You are a helpful assistant.")
|
||||||
s += sgl.user("Construct a character within the following format:")
|
s += sgl.user("Construct a character within the following format:")
|
||||||
s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
|
s += sgl.assistant(
|
||||||
|
"Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||||
|
)
|
||||||
s += sgl.user("Please generate new Name, Birthday and Job.\n")
|
s += sgl.user("Please generate new Name, Birthday and Job.\n")
|
||||||
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
|
s += sgl.assistant(
|
||||||
|
"Name:"
|
||||||
|
+ sgl.gen("name", stop="\n")
|
||||||
|
+ "\nBirthday:"
|
||||||
|
+ sgl.gen("birthday", stop="\n")
|
||||||
|
+ "\nJob:"
|
||||||
|
+ sgl.gen("job", stop="\n")
|
||||||
|
)
|
||||||
|
|
||||||
gen_character_spec().sync()
|
gen_character_spec().sync()
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from json import dumps
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -255,8 +254,10 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
|
|||||||
|
|
||||||
def graceful_registry(sub_module_name):
|
def graceful_registry(sub_module_name):
|
||||||
def graceful_shutdown(signum, frame):
|
def graceful_shutdown(signum, frame):
|
||||||
logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
|
logger.info(
|
||||||
|
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
|
||||||
|
)
|
||||||
if signum == signal.SIGTERM:
|
if signum == signal.SIGTERM:
|
||||||
logger.info(f"{sub_module_name} recive sigterm")
|
logger.info(f"{sub_module_name} recive sigterm")
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, graceful_shutdown)
|
signal.signal(signal.SIGTERM, graceful_shutdown)
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ import unittest
|
|||||||
|
|
||||||
from sglang import OpenAI, set_default_backend
|
from sglang import OpenAI, set_default_backend
|
||||||
from sglang.test.test_programs import (
|
from sglang.test.test_programs import (
|
||||||
|
test_chat_completion_speculative,
|
||||||
|
test_completion_speculative,
|
||||||
test_decode_int,
|
test_decode_int,
|
||||||
test_decode_json,
|
test_decode_json,
|
||||||
test_expert_answer,
|
test_expert_answer,
|
||||||
@@ -14,8 +16,6 @@ from sglang.test.test_programs import (
|
|||||||
test_select,
|
test_select,
|
||||||
test_stream,
|
test_stream,
|
||||||
test_tool_use,
|
test_tool_use,
|
||||||
test_completion_speculative,
|
|
||||||
test_chat_completion_speculative
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -97,4 +97,4 @@ if __name__ == "__main__":
|
|||||||
# global_config.verbosity = 2
|
# global_config.verbosity = 2
|
||||||
# t = TestOpenAIBackend()
|
# t = TestOpenAIBackend()
|
||||||
# t.setUp()
|
# t.setUp()
|
||||||
# t.test_chat_completion_speculative()
|
# t.test_chat_completion_speculative()
|
||||||
|
|||||||
Reference in New Issue
Block a user