release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com> Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
385
python/sglang/srt/constrained/fsm.py
Normal file
385
python/sglang/srt/constrained/fsm.py
Normal file
@@ -0,0 +1,385 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
|
||||
from typing import List, NewType, Protocol
|
||||
|
||||
import interegular
|
||||
from lark import Lark
|
||||
|
||||
# from outlines.fsm.parsing import PartialLark
|
||||
from sglang.srt.constrained.regex import (
|
||||
create_fsm_index_tokenizer,
|
||||
make_deterministic_fsm,
|
||||
)
|
||||
from sglang.srt.constrained.tokenizer import Tokenizer
|
||||
|
||||
# Integer label for an FSM state; wrapped in a NewType purely for type clarity.
# State -1 is used throughout this module as the sink/final "EOS seen" state.
FSMState = NewType("FSMState", int)
||||
class FSM(Protocol):
    """Structural interface shared by the guided-decoding FSMs in this module.

    Every method takes an optional ``idx`` identifying the request within a
    batch; implementations that keep no per-request state may ignore it.
    """

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        ...

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        ...

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        ...

    def reset(self) -> None:
        ...
||||
class StopAtTokenFSM(FSM):
    """FSM that lets generation run freely until a designated stop token.

    State ``0`` means "still generating"; state ``1`` is the single final
    state, entered once ``stop_token_id`` (usually EOS) has been produced.
    """

    def __init__(
        self,
        tokenizer: "Tokenizer",
        stop_token_id: int,
    ):
        self.stop_token_id = stop_token_id
        self.num_tokens_generated = 0
        self.vocabulary = tokenizer.vocabulary.values()
        self.final_states = {1}

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Return the token ids permitted from ``state``.

        In the initial state every vocabulary token is allowed; in the final
        state only ``stop_token_id`` may be generated again.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch (unused here).

        Returns
        -------
        The list of allowed token ids.
        """
        return list(self.vocabulary) if state == 0 else [self.stop_token_id]

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Advance the FSM after ``token_id`` was generated.

        The FSM stays in state ``0`` until the stop token appears, at which
        point it moves permanently to state ``1``.  Only the batch head
        (``idx == 0``) bumps the shared token counter.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch.

        Returns
        -------
        The new state of the FSM.
        """
        if idx == 0:
            self.num_tokens_generated += 1
        return FSMState(1 if token_id == self.stop_token_id else 0)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Determine whether the current state of the FSM is a final state."""
        return state in self.final_states

    def reset(self) -> None:
        """Reset the FSM to its initial state. Here this only resets the token counter."""
        self.num_tokens_generated = 0
||||
class RegexFSM(FSM):
    """FSM to generate text that is in the language of a regular expression."""

    def __init__(
        self,
        regex_string: str,
        tokenizer: "Tokenizer",
    ):
        # Compile the regex to a reduced deterministic FSM with stable state
        # labels, then index each state by the vocabulary tokens consumable
        # from it (token id -> resulting state).
        regex_pattern = interegular.parse_pattern(regex_string)
        regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
        (
            self.states_to_token_maps,
            self.empty_token_ids,
        ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)

        # We make sure that it is possible to generate strings in the language
        # of the regular expression with the tokens present in the model's
        # vocabulary.
        if not any(
            regex_fsm.finals.intersection(v.values())
            for v in self.states_to_token_maps.values()
        ):
            raise ValueError(
                "The vocabulary does not allow us to build a sequence that matches the input regex"
            )

        self.final_states = regex_fsm.finals | {
            -1
        }  # Include the EOS token in final states
        self.num_tokens_generated = 0
        self.vocabulary = tokenizer.vocabulary.values()
        self.end_token_id = tokenizer.eos_token_id

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        The index built at init maps each FSM state to the tokens consumable
        there; the allowed tokens are that map's keys.  A state missing from
        the index is treated as final, and only EOS is allowed.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch (unused here).

        Returns
        -------
        A list that contains the allowed token ids.
        """
        next_tokens_to_end_states = self.states_to_token_maps.get(state)

        if next_tokens_to_end_states is None:
            return [self.end_token_id]
        else:
            return list(next_tokens_to_end_states.keys())

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Update the state of the FSM.

        Uses the index to find the state reached by generating ``token_id``.
        EOS — or any token without a transition from ``state`` — moves the
        FSM to the sink final state ``-1``.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch; only ``idx == 0``
            bumps the shared token counter.

        Returns
        -------
        The new state of the FSM.
        """
        if idx == 0:
            self.num_tokens_generated += 1

        if token_id == self.end_token_id:
            return FSMState(-1)

        last_token_to_end_state = self.states_to_token_maps[state]
        next_state = last_token_to_end_state.get(token_id)
        if next_state is None:
            next_state = -1

        return FSMState(next_state)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Determine whether the current state of the FSM is a final state."""
        return state in self.final_states

    def reset(self) -> None:
        """Reset the FSM to its initial state. Here this only resets the token counter."""
        self.num_tokens_generated = 0
||||
class CFGFSM(FSM):
    """FSM to generate text that is in the language of a context-free grammar.

    Uses Lark's interactive LALR parser to determine, after each completed
    terminal, which terminals may come next; each such set is compiled into a
    ``RegexFSM`` that constrains generation until the terminal is finished.
    Per-request state is kept in parallel lists indexed by ``idx``.
    """

    def __init__(
        self,
        cfg_string: str,
        tokenizer: "Tokenizer",
    ):
        # self.parser = PartialLark(cfg_string, parser="lalr")
        self.parser = Lark(
            cfg_string,
            parser="lalr",
            lexer="contextual",
            propagate_positions=False,
            maybe_placeholders=False,
            regex=True,
        )
        # Map terminal names to the regex that recognizes them; "$END" is
        # represented by the tokenizer's EOS token text.
        self.terminal_regexps = dict()
        for terminal in self.parser.terminals:
            if terminal.pattern is not None:
                self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
        self.terminal_regexps["$END"] = tokenizer.eos_token

        self.tokenizer = tokenizer
        self.num_tokens_generated = 0
        # Parallel per-request lists; all four below must always have equal
        # length (they grow together in allowed_token_ids).
        self.generations: List[str] = []
        self.regex_fsms: List[RegexFSM] = []
        self.reset_state: List[bool] = []
        self.allow_eos: List[bool] = []
        self.done: List[bool] = []

    def _set_next_regex_fsm(self, idx: int = 0) -> None:
        """Use the CFG incremental parser to set the next regex FSM.

        Check what the CFG incremental parser proposes next.  If the only
        proposal is the EOS token, mark the request done and return.
        Otherwise compile the union of the proposed terminal regexes into a
        new ``RegexFSM`` for this request.
        """
        interactive = self.parser.parse_interactive(self.generations[idx])
        interactive.exhaust_lexer()
        options = {self.terminal_regexps[x] for x in interactive.accepts()}

        if self.terminal_regexps["$END"] in options:
            options.remove(self.terminal_regexps["$END"])
            if len(options) == 0:
                self.done[idx] = True
                return
            # EOS is legal here but not mandatory: allow it once alongside
            # the other terminals (the empty alternative stands in for it).
            self.allow_eos[idx] = True
            options.add("")
            assert len(options) > 1

        regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
        args = (
            regex_string,
            self.tokenizer,
        )
        if len(self.regex_fsms) <= idx:
            self.regex_fsms.append(RegexFSM(*args))
        else:
            self.regex_fsms[idx] = RegexFSM(*args)
        # The fresh regex FSM starts from state 0 on the next step.
        self.reset_state[idx] = True

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        Upon first use for a request, the CFG incremental parser determines
        the first regex.  That regex constrains proposals until either:
        - it is exhausted and its only remaining option is EOS, in which case
          we always transition to the next regex
        - it can be exhausted but EOS is not the only remaining option, in
          which case we transition with probability P (TODO) or strip EOS
          and continue with the current regex

        EOS may be proposed from any final state; once generated, the FSM
        keeps proposing only EOS.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch.

        Returns
        -------
        A list that contains the allowed token ids.
        """
        if len(self.generations) <= idx:
            # Lazily create the per-request slots for a new batch member.
            self.generations.append("")
            self.reset_state.append(False)
            self.allow_eos.append(False)
            self.done.append(False)

        if len(self.regex_fsms) > idx:
            proposal = self.regex_fsms[idx].allowed_token_ids(state)
            if self.tokenizer.eos_token_id not in proposal:
                return proposal
            if set(proposal) != {self.tokenizer.eos_token_id}:
                if False:  # TODO: THIS NEEDS TO BE SAMPLED
                    proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
                    return proposal

        # Current terminal is finished: ask the parser for the next one.
        self._set_next_regex_fsm(idx)

        if self.done[idx]:
            return [self.tokenizer.eos_token_id]

        if self.reset_state[idx]:
            state = FSMState(0)

        proposal = self.regex_fsms[idx].allowed_token_ids(state)
        if self.allow_eos[idx]:
            self.allow_eos[idx] = False
        else:
            proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
            assert len(proposal) > 0
        return proposal

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Update the state of the FSM.

        Transitions the underlying regex FSM; EOS moves the request
        permanently to the final state.  The decoded token text is appended
        to the stored partial generation for subsequent incremental parsing.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch.

        Returns
        -------
        The new state of the FSM.
        """
        if idx == 0:
            self.num_tokens_generated += 1
        if token_id == self.tokenizer.eos_token_id:
            self.done[idx] = True
            return FSMState(-1)
        if self.reset_state[idx]:
            self.reset_state[idx] = False
            state = FSMState(0)

        self.generations[idx] += self.tokenizer.decode([token_id])[0]

        return self.regex_fsms[idx].next_state(state, token_id, idx)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Return whether the current state of the FSM is a final state."""
        return self.done[idx]

    def reset(self) -> None:
        """Reset the FSM so it can be called on a fresh batch of inputs."""
        self.num_tokens_generated = 0
        self.generations = []
        self.regex_fsms = []
        self.reset_state = []
        # Bug fix: ``allow_eos`` was previously not cleared here, leaving
        # stale flags that desynchronized the parallel per-request lists
        # (allowed_token_ids appends to all four on a fresh batch).
        self.allow_eos = []
        self.done = []
||||
41
python/sglang/srt/constrained/fsm_cache.py
Normal file
41
python/sglang/srt/constrained/fsm_cache.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import threading
|
||||
|
||||
from sglang.srt.constrained.fsm import RegexFSM
|
||||
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
||||
|
||||
|
||||
def get_fsm(regex, tokenizer, fsm_cache_entry):
    """Build a RegexFSM for ``regex`` and publish it via ``fsm_cache_entry``.

    Runs on a worker thread (see ``FSMCache.init_fsm_in_background``): stores
    the compiled FSM on the entry, then sets the entry's event so any waiter
    in ``FSMCache.get_fsm`` can proceed.
    """
    outlines_tokenizer = TransformerTokenizer(tokenizer)
    fsm = RegexFSM(regex, outlines_tokenizer)
    # Publish before signalling: waiters read ``fsm`` only after the event.
    fsm_cache_entry.fsm = fsm
    fsm_cache_entry.event.set()
||||
class FSMCacheEntry:
    """Holder for a lazily-built FSM plus the event that signals completion.

    ``fsm`` stays ``None`` until the background builder fills it in and sets
    ``event``; consumers must wait on ``event`` before reading ``fsm``.
    """

    def __init__(self):
        self.fsm = None
        self.event = threading.Event()
||||
class FSMCache:
    """Cache of compiled regex FSMs, built asynchronously per regex string."""

    def __init__(self, tokenizer):
        self.cache = {}
        self.tokenizer = tokenizer

    def init_fsm_in_background(self, regex):
        """Kick off compilation of ``regex`` on a worker thread (idempotent).

        NOTE(review): the check-then-insert below is not atomic; concurrent
        callers could race — confirm all callers run on a single thread.
        """
        if regex in self.cache:
            return
        entry = FSMCacheEntry()
        self.cache[regex] = entry
        worker = threading.Thread(
            target=get_fsm,
            args=(regex, self.tokenizer, entry),
        )
        worker.start()

    def get_fsm(self, regex):
        """Return the compiled FSM for ``regex``, blocking until it is ready."""
        self.init_fsm_in_background(regex)
        entry = self.cache[regex]
        entry.event.wait()
        return entry.fsm
586
python/sglang/srt/constrained/regex.py
Normal file
586
python/sglang/srt/constrained/regex.py
Normal file
@@ -0,0 +1,586 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
|
||||
from collections import namedtuple
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Generator, List, Sequence, Set, Tuple
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
|
||||
from numba.typed.typedobjectutils import _nonoptional
|
||||
from sglang.srt.constrained.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class BetterAlphabet(Alphabet):
    """interegular Alphabet whose lookups fall back to the wildcard key.

    Requires the symbol mapping to contain ``anything_else``; unknown symbols
    map to that wildcard's transition key instead of raising.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert anything_else in self._symbol_mapping
        # Cache the wildcard transition key for O(1) fallback in __getitem__.
        self.anything_value = self._symbol_mapping[anything_else]

    def __getitem__(self, item):
        # Unknown symbols are treated as "anything else" rather than KeyError.
        return self._symbol_mapping.get(item, self.anything_value)

    def copy(self):
        return BetterAlphabet(self._symbol_mapping.copy())
||||
class BetterFSM(FSM):
    """interegular FSM augmented with flattened transition tables and a
    cached, numba-friendly ``fsm_info`` view."""

    flat_transition_map: Dict[Tuple[int, int], int]
    # NOTE(review): annotated as lists but the constructor actually stores
    # sets of source states per transition key — confirm intended type.
    trans_key_to_states: Dict[int, List[int]]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Guarantee the alphabet supports the anything_else fallback.
        if not isinstance(self.alphabet, BetterAlphabet):
            self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)

        # Flatten map[from][key] -> to into one dict keyed by (from, key)
        # pairs, and invert it so each transition key lists its sources.
        flat_transition_map = {}
        trans_key_to_states = {}
        for from_state, trans_map in self.map.items():
            for trans_key, to_state in trans_map.items():
                flat_transition_map[(from_state, trans_key)] = to_state
                trans_key_to_states.setdefault(trans_key, set()).add(from_state)

        # interegular's FSM blocks attribute assignment; write through
        # __dict__ directly.
        self.__dict__["trans_key_to_states"] = trans_key_to_states
        self.__dict__["flat_transition_map"] = flat_transition_map
        self.__dict__["_fsm_info"] = None

    def copy(self):
        return BetterFSM(
            alphabet=self.alphabet.copy(),
            states=self.states.copy(),
            initial=self.initial,
            finals=self.finals.copy(),
            map=self.map.copy(),
            __no_validation__=True,
        )

    @property
    def fsm_info(self):
        """Lazily-built FSMInfo record with fields in numba typed containers.

        The Python dicts are serialized into structured numpy arrays first so
        the jitted ``create_fsm_info`` can consume them.
        """
        if self._fsm_info is None:
            flat_transition_map_items = np.fromiter(
                ((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
                dtype=np.dtype("i8, i8, i8"),
            )
            trans_key_to_states_items = np.fromiter(
                ((k, z) for k, v in self.trans_key_to_states.items() for z in v),
                dtype=np.dtype("i8, i8"),
            )
            # The wildcard entry is excluded here; it travels separately as
            # ``anything_value``.
            alphabet_symbol_mapping_items = np.fromiter(
                (
                    it
                    for it in self.alphabet._symbol_mapping.items()
                    if it[0] != anything_else
                ),
                dtype=np.dtype("U1, i8"),
            )
            nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
            self.__dict__["_fsm_info"] = create_fsm_info(
                self.initial,
                nb_finals,
                flat_transition_map_items,
                trans_key_to_states_items,
                self.alphabet.anything_value,
                alphabet_symbol_mapping_items,
            )

        return self._fsm_info
||||
# numba type specifications used to build the typed containers below.
nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)
||||
@numba.njit(cache=True)
def create_fsm_info(
    py_initial,
    py_finals,
    flat_transition_map_items,
    trans_key_to_states_items,
    py_anything_value,
    alphabet_symbol_mapping_items,
):
    """Assemble an FSMInfo record from flat (structured-array) inputs.

    Runs under numba: the Python-side dicts are passed in as item arrays and
    rebuilt here as numba typed containers so jitted walkers can use them.
    """
    # trans_key -> list of source states.
    trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
    for trans_key_and_state in trans_key_to_states_items:
        trans_key_to_states.setdefault(
            trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
        ).append(trans_key_and_state[1])

    # (from_state, trans_key) -> to_state.
    flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
    for trans_key_and_state in flat_transition_map_items:
        flat_transition_map[
            (trans_key_and_state[0], trans_key_and_state[1])
        ] = trans_key_and_state[2]

    # single-character symbol -> trans_key (wildcard handled separately).
    alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
    for symbol_and_trans_key in alphabet_symbol_mapping_items:
        alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

    initial = numba.int64(py_initial)

    finals = set()
    for final in py_finals:
        finals.add(final)

    anything_value = numba.int64(py_anything_value)

    return FSMInfo(
        initial,
        finals,
        flat_transition_map,
        trans_key_to_states,
        anything_value,
        alphabet_symbol_map,
    )
||||
# Plain-record view of a BetterFSM with all fields in numba-compatible
# containers; produced by create_fsm_info and consumed by the jitted walkers.
FSMInfo = namedtuple(
    "FSMInfo",
    [
        "initial",
        "finals",
        "transitions",
        "trans_key_to_states",
        "alphabet_anything_value",
        "alphabet_symbol_mapping",
    ],
)
||||
def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
    """Construct an equivalent FSM with deterministic state labels.

    Transition keys are renumbered by the sorted symbols they cover, and
    states are renumbered 0..N in depth-first visit order from the initial
    state, so structurally identical FSMs always get identical labels
    (important for caching).

    Returns the relabeled ``BetterFSM`` and the old-state -> new-state map.
    """
    old_to_new_trans_keys = {
        trans_key: i
        for i, (trans_key, _) in enumerate(
            sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
        )
    }

    new_symbol_mapping = {
        symbol: old_to_new_trans_keys[trans_key]
        for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
    }

    new_alphabet = BetterAlphabet(new_symbol_mapping)

    new_map = {
        from_state: {
            old_to_new_trans_keys[trans_key]: to_state
            for trans_key, to_state in trans_map.items()
        }
        for from_state, trans_map in fsm.map.items()
    }

    old_to_new_states = {}
    old_to_new_states[fsm.initial] = 0

    i = 0
    seen = {fsm.initial}
    old_state_queue = [fsm.initial]
    while old_state_queue:
        old_state = old_state_queue.pop(-1)
        transitions = new_map[old_state]
        sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
        # NOTE: ``old_state`` is deliberately rebound by this inner loop;
        # the popped value is no longer needed once its transitions are read.
        for _, old_state in sorted_transitions:
            if old_state not in seen:
                old_state_queue.append(old_state)
                seen.add(old_state)
            if old_state not in old_to_new_states:
                i += 1
                old_to_new_states[old_state] = i

    # Rebuild the map with new state labels, sorted for determinism.
    new_map = dict(
        sorted(
            (
                (
                    old_to_new_states[from_state],
                    dict(
                        sorted(
                            (
                                (trans_key, old_to_new_states[to_state])
                                for trans_key, to_state in trans_map.items()
                            ),
                            key=lambda v: v[0],
                        )
                    ),
                )
                for from_state, trans_map in new_map.items()
            ),
            key=lambda v: v[0],
        )
    )

    new_initial = 0
    new_finals = frozenset(
        sorted(old_to_new_states[old_state] for old_state in fsm.finals)
    )
    new_states = frozenset(sorted(new_map.keys()))

    new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)

    return new_fsm, old_to_new_states
||||
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
    fsm_transitions: Dict[Tuple[int, int], int],
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    fsm_initial: int,
    fsm_finals: Set[int],
    input_string: str,
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    """Jitted walk of an FSM over ``input_string`` from ``start_state``.

    Returns the states visited after each accepted symbol.  With
    ``full_match`` the whole string must end in a final state; otherwise the
    longest prefix ending in a final state is returned.  Failure yields an
    empty typed list.

    NOTE(review): ``i`` is read after the loop, so an empty ``input_string``
    appears unhandled — confirm callers never pass one (reduced_vocabulary
    filters out empty token strings).
    """
    state = start_state
    accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
    last_final_idx: int = numba.uint64(0)

    for i, symbol in enumerate(input_string):
        # Unknown symbols fall back to the wildcard transition key.
        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            # Dead end: salvage the longest final-state prefix if allowed.
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return numba.typed.List.empty_list(numba.int64)

        state = new_state

        if state in fsm_finals:
            last_final_idx = numba.uint64(i + 1)

        accepted_states.append(_nonoptional(state))

    if full_match and last_final_idx - 1 != i:
        return numba.typed.List.empty_list(numba.int64)

    return accepted_states
||||
def walk_fsm(
    fsm: "BetterFSM",
    input_string: str,
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    """Walk ``fsm`` over ``input_string`` starting at ``start_state``.

    Pure-Python counterpart of the jitted ``_walk_fsm``.

    Parameters
    ----------
    fsm
        FSM exposing ``alphabet`` (with ``_symbol_mapping`` and
        ``anything_value``), ``flat_transition_map`` and ``finals``.
    input_string
        The string to feed through the FSM one symbol at a time.
    start_state
        State to start walking from.
    full_match
        When True, the whole string must end in a final state; when False,
        the longest prefix ending in a final state is returned.

    Returns
    -------
    The list of states visited after each accepted symbol, or an empty list
    when the walk fails (or when ``input_string`` is empty).
    """
    fsm_finals = fsm.finals

    state = start_state
    accepted_states: List[int] = []
    last_final_idx: int = 0

    # Hoist attribute lookups out of the loop.
    alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
    alphabet_anything_value = fsm.alphabet.anything_value
    fsm_transitions = fsm.flat_transition_map

    # Bug fix: ``i`` was previously referenced after the loop without being
    # initialized, so an empty ``input_string`` raised NameError.  The -1
    # sentinel makes the empty case return [] consistently.
    i = -1
    for i, symbol in enumerate(input_string):
        # Unknown symbols fall back to the wildcard transition key.
        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            # Dead end: salvage the longest final-state prefix if allowed.
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return []

        state = new_state

        if state in fsm_finals:
            last_final_idx = i + 1

        accepted_states.append(state)

    if full_match and last_final_idx - 1 != i:
        return []

    return accepted_states
||||
def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Construct an FSM representing the union of the FSMs in `fsms`.

    This is an updated version of `interegular.fsm.FSM.union` made to return
    an extra map of component FSMs to the sets of state transitions that
    correspond to them in the new FSM.

    Returns
    -------
    The deterministic union FSM, and a map from component-FSM index to
    ``(transitions, finals, old_to_new_states)`` expressed in the union
    FSM's (relabeled) state ids.
    """

    alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    # The union's start "superset" pairs each component with its own initial.
    initial = {i: fsm.initial for (i, fsm) in indexed_fsms}

    # Dedicated function accepting a "superset" and returning the next
    # "superset" obtained by following this transition in the new FSM
    def follow(current_state, new_transition: int):
        next = {}
        for i, f in indexed_fsms:
            old_transition = new_to_old[i][new_transition]
            if (
                i in current_state
                and current_state[i] in f.map
                and old_transition in f.map[current_state[i]]
            ):
                next[i] = f.map[current_state[i]][old_transition]
        if not next:
            # No component can follow this transition: dead end.
            raise OblivionError
        return next

    states = [initial]
    finals: Set[int] = set()
    map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    # Breadth-first construction: ``states`` grows as new supersets appear.
    i = 0
    while i < len(states):
        state = states[i]

        # Add to the finals of the aggregate FSM whenever we hit a final in a
        # component FSM
        if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
            finals.add(i)

        # Compute the map for this state
        map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue
            else:
                try:
                    # TODO: Seems like this could--and should--be avoided
                    j = states.index(next)
                except ValueError:
                    j = len(states)
                    states.append(next)

                map[i][transition] = j

                # Record, per component, which aggregate transitions/finals
                # it contributed to.
                for fsm_id, fsm_state in next.items():
                    (
                        fsm_transitions,
                        fsm_finals,
                        fsm_old_to_new,
                    ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                    old_from = state[fsm_id]
                    old_to = fsm_state
                    fsm_old_to_new.setdefault(old_from, set()).add(i)
                    fsm_old_to_new.setdefault(old_to, set()).add(j)
                    fsm_transitions.add((i, j))
                    if fsm_state in fsms[fsm_id].finals:
                        fsm_finals.add(j)

        i += 1

    fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=map,
        __no_validation__=True,
    )

    # Relabel deterministically and translate the bookkeeping to new labels.
    fsm, old_to_new_states = make_deterministic_fsm(fsm)
    _fsms_to_trans_finals = {
        fsm_id: (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in finals},
            {
                old_state: {old_to_new_states[new_state] for new_state in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )
        for fsm_id, (transitions, finals, old_to_new) in sorted(
            fsms_to_trans_finals.items(), key=lambda x: x[0]
        )
    }

    return (
        fsm,
        _fsms_to_trans_finals,
    )
||||
def get_sub_fsms_from_seq(
    state_seq: Sequence[int],
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ],
) -> Generator[Tuple[int, bool, bool], None, None]:
    """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.

    Parameters
    ----------
    state_seq
        A state sequence.
    fsms_to_trans_finals
        A map from FSM indices to tuples containing sets of their state
        transitions and sets of the final/accept states.

    Returns
    -------
    A generator returning tuples containing each sub-FSM index (in the order
    they were union-ed to construct `fsm`) and booleans indicating whether or
    not there is another valid transition from the last state in the sequence
    for the associated sub-FSM (i.e. if the FSM can continue
    accepting/matching) and whether or not the sequence ends in a final state
    of the sub-FSM.
    """
    # Consecutive state pairs actually taken by the sequence.
    seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
    last_state = state_seq[-1]

    for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items():
        # Only sub-FSMs whose transition set covers the observed pairs
        # could have produced this sequence.
        if not seq_transitions.issubset(transitions):
            continue
        can_continue = any(src == last_state for src, _dst in transitions)
        is_final = last_state in finals
        yield fsm_idx, can_continue, is_final
||||
@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
    fsm_transitions: Dict[Tuple[int, int], int],
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    fsm_initial: int,
    fsm_finals: Set[int],
    vocabulary: Dict[str, List[int]],
    start_state: int,
) -> Set[Tuple[int, int]]:
    """Scan every vocabulary token from ``start_state``.

    Returns the set of ``(token_id, end_state)`` pairs for each token whose
    full string is consumable by the FSM starting at ``start_state``.
    """
    res = set()

    for token, token_ids in vocabulary.items():
        state_seq = _walk_fsm(
            fsm_transitions,
            alphabet_symbol_mapping,
            alphabet_anything_value,
            fsm_initial,
            fsm_finals,
            token,
            start_state,
            False,
        )

        # Skip tokens whose string was not fully consumed.
        # NOTE(review): _walk_fsm never returns None (it returns an empty
        # typed list on failure), so the ``is not None`` guard looks
        # vestigial — confirm before removing.
        if state_seq is not None and len(state_seq) < len(token):
            continue

        # All ids that decode to this token string reach the same end state.
        for token_id in token_ids:
            res.add((token_id, state_seq[-1]))

    return res
||||
def create_fsm_index_end_to_end(
    fsm_info: FSMInfo,
    vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
    """Create an FSM state-to-vocabulary map/index through end-to-end token parsing.

    Breadth-first over reachable states: for each state, scan the whole
    vocabulary with ``state_scan_tokens`` and record which tokens are
    consumable and where they lead; newly discovered end states are queued.
    """

    # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
    # code, too.
    states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
    seen: Set[int] = set()
    next_states = {fsm_info.initial}

    while next_states:
        start_state = next_states.pop()

        token_ids_end_states = state_scan_tokens(
            fsm_info.transitions,
            fsm_info.alphabet_symbol_mapping,
            fsm_info.alphabet_anything_value,
            fsm_info.initial,
            fsm_info.finals,
            vocabulary,
            start_state,
        )

        for token_id_and_end_state in token_ids_end_states:
            states_to_token_subsets.setdefault(start_state, set()).add(
                token_id_and_end_state
            )
            end_state = token_id_and_end_state[1]
            if end_state not in seen:
                next_states.add(end_state)

        seen.add(start_state)

    return states_to_token_subsets
||||
# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
# NOTE(review): `lru_cache` without maxsize grows unboundedly and keys on the
# tokenizer instance (keeping it alive) — acceptable while tokenizers are
# few and long-lived; confirm.
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
    """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
    vocabulary = numba.typed.Dict.empty(
        numba.types.string, numba.types.ListType(numba.int64)
    )
    empty_token_ids = set()
    for token, token_idx in tokenizer.vocabulary.items():
        # Special tokens never participate in regex matching.
        if token in tokenizer.special_tokens:
            continue

        token_str = tokenizer.convert_token_to_string(token)

        if token_str:
            # Several token ids may decode to the same string; group them.
            vocabulary.setdefault(
                token_str,
                numba.typed.List.empty_list(numba.int64),
            ).append(numba.int64(token_idx))
        else:
            # Tokens that decode to the empty string are tracked separately.
            empty_token_ids.add(numba.int64(token_idx))

    return vocabulary, empty_token_ids
||||
def create_fsm_index_tokenizer(
    fsm: BetterFSM,
    tokenizer: "Tokenizer",
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
    """Construct an FSM index from a tokenizer.

    This uses the end-to-end approach of `create_fsm_index_end_to_end`.

    Returns a map ``state -> {token_id: end_state}`` together with the set of
    token ids that decode to the empty string.

    .. warning::

        `fsm` needs to be deterministically ordered so that future caching makes sense.

    """
    vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)

    # state -> set of (token_id, end_state) pairs reachable from that state.
    states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)

    # Allow transitions to EOS from all terminals FSM states that are
    # reachable
    # TODO: Do we really need this anymore?
    for state in fsm.fsm_info.finals:
        subset = states_to_token_subsets.get(state)
        if subset is not None:
            # EOS loops back to the same final state.
            subset.add((tokenizer.eos_token_id, state))

    # Convert to token-to-end-state maps
    states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}

    return states_to_token_subsets, empty_token_ids
|
||||
266
python/sglang/srt/constrained/tokenizer.py
Normal file
266
python/sglang/srt/constrained/tokenizer.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
|
||||
from abc import abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Hashable,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
class Tokenizer(Protocol, Hashable):
    """Structural interface every tokenizer used by the constrained decoder must satisfy.

    Implementations must be hashable (e.g. for `lru_cache` keying in
    `reduced_vocabulary`).
    """

    eos_token: str
    eos_token_id: int
    pad_token_id: int
    # Raw token string -> token id.
    vocabulary: Dict[str, int]
    # Special token *strings* (see `TransformerTokenizer`, which assigns
    # `set(tokenizer.all_special_tokens)`); annotation corrected from Set[int].
    special_tokens: Set[str]

    @abstractmethod
    def encode(
        self, prompt: Union[str, List[str]]
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        """Translate the input prompts into NumPy arrays of token ids and attention mask."""
        ...

    @abstractmethod
    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        """Translate an array of token ids to a string or list of strings."""
        ...

    @abstractmethod
    def convert_token_to_string(self, token: str) -> str:
        """Convert a token to its equivalent string.

        This is for instance useful for BPE tokenizers where whitespaces are
        represented by the special character `Ġ`. This prevents matching a raw
        token that includes `Ġ` with a string.

        """
        ...
|
||||
|
||||
|
||||
# Heavy `transformers` types are only needed for annotations; import lazily.
if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

# Public entry point of this module: the `transformers(...)` factory below.
__all__ = ["transformers"]
|
||||
|
||||
|
||||
# Shape of a `transformers` past_key_values cache: one (key, value) tensor
# pair per attention layer.
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
|
||||
|
||||
|
||||
def get_llama_tokenizer_types():
    """Get all the Llama tokenizer types/classes that need work-arounds.

    When they can't be imported, a dummy class is created.
    """

    def _class_or_dummy(module_name, class_name):
        # One level of indirection keeps the four import attempts uniform:
        # return the real class when importable, otherwise a stand-in dummy
        # class with the same name (so `isinstance` checks simply fail).
        try:
            module = __import__(module_name, fromlist=[class_name])
            return getattr(module, class_name)
        except (ImportError, AttributeError):
            return type(class_name, (), {})

    return (
        _class_or_dummy("transformers.models.llama", "LlamaTokenizer"),
        _class_or_dummy("transformers.models.llama", "LlamaTokenizerFast"),
        _class_or_dummy("transformers.models.code_llama", "CodeLlamaTokenizer"),
        _class_or_dummy("transformers.models.code_llama", "CodeLlamaTokenizerFast"),
    )
|
||||
|
||||
|
||||
class Transformer:
    """Represents a `transformers` model."""

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
    ):
        # Keep the model's device so callers can place inputs accordingly.
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer

    # NOTE(review): decorator applied without parentheses — requires a torch
    # version where `inference_mode` supports direct decoration; confirm.
    @torch.inference_mode
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
        """Compute a forward pass through the transformer model.

        Parameters
        ----------
        input_ids
            The input token ids. Must be one or two dimensional.
        attention_mask
            The attention mask. Must be one or two dimensional.
        past_key_values
            A tuple of tuples containing the cached key and value tensors for each
            attention head.

        Returns
        -------
        The computed logits and the new cached key and value tensors.

        """
        assert 0 < input_ids.ndim < 3

        # With a KV cache, only the newest token needs to be fed to the model.
        if past_key_values:
            input_ids = input_ids[..., -1].unsqueeze(-1)

        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
            past_key_values=past_key_values,
        )

        return output.logits, output.past_key_values

    def __call__(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
        """Return the next-token logits (last position) and the updated KV cache."""
        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
        # Keep only the logits for the final position of each sequence.
        next_token_logits = logits[..., -1, :]

        return next_token_logits, kv_cache
|
||||
|
||||
|
||||
class TransformerTokenizer(Tokenizer):
    """Represents a tokenizer for models in the `transformers` library."""

    def __init__(self, tokenizer):
        # Wraps an already-constructed `transformers` tokenizer object.
        # TODO: Do something to make this hashable?
        self.tokenizer = tokenizer
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = self.tokenizer.eos_token

        # Fall back to EOS as the pad token when none is configured
        # (common for causal LMs like Llama).
        if not self.tokenizer.pad_token_id:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.pad_token_id = self.eos_token_id
        else:
            self.pad_token_id = self.tokenizer.pad_token_id
            self.pad_token = self.tokenizer.pad_token

        # Set of special token *strings* (e.g. "</s>"), used to skip them
        # when building the reduced vocabulary.
        self.special_tokens = set(self.tokenizer.all_special_tokens)

        self.vocabulary = self.tokenizer.get_vocab()
        # Llama-family tokenizers need the whitespace work-around below.
        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())

    def encode(
        self, prompt: Union[str, List[str]], **kwargs
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """Tokenize `prompt` into padded `input_ids` and `attention_mask` tensors."""
        kwargs["padding"] = True
        kwargs["return_tensors"] = "pt"
        output = self.tokenizer(prompt, **kwargs)
        return output["input_ids"], output["attention_mask"]

    def decode(self, token_ids: torch.LongTensor) -> List[str]:
        """Decode a batch of token ids to strings, dropping special tokens."""
        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return text

    def convert_token_to_string(self, token: str) -> str:
        """Convert one raw vocabulary token to its surface string form."""
        from transformers.file_utils import SPIECE_UNDERLINE

        string = self.tokenizer.convert_tokens_to_string([token])

        if self.is_llama:
            # A hack to handle missing spaces to HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

        return string

    def __eq__(self, other):
        # NOTE(review): returning False for *same-type* instances makes every
        # pair of TransformerTokenizers unequal, which defeats any
        # `lru_cache` keyed on this object — apparently acknowledged by the
        # TODO below; confirm before changing.
        if isinstance(other, type(self)):
            return False
        # TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ?
        # return other.model_name == self.model_name and other.kwargs == self.kwargs
        return NotImplemented

    def __hash__(self):
        # Content-based hash of the underlying tokenizer (requires `datasets`).
        from datasets.fingerprint import Hasher

        return hash(Hasher.hash(self.tokenizer))
|
||||
|
||||
|
||||
def transformers(
    model_name: str,
    device: Optional[str] = None,
    model_kwargs: Optional[dict] = None,
    tokenizer_kwargs: Optional[dict] = None,
):
    """Instantiate a model from the `transformers` library and its tokenizer.

    Parameters
    ----------
    model_name
        The name of the model as listed on Hugging Face's model page.
    device
        The device(s) on which the model should be loaded. This overrides
        the `device_map` entry in `model_kwargs` when provided.
    model_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the model.
    tokenizer_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the tokenizer.

    Returns
    -------
    A `Transformer` model instance.

    Raises
    ------
    ImportError
        If the `transformers` library is not installed.

    """
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        raise ImportError(
            "The `transformers` library needs to be installed in order to use `transformers` models."
        )

    # Copy the kwargs so the caller's dict is never mutated. (The previous
    # version used mutable `{}` defaults and wrote `device_map` into them,
    # leaking state across calls.)
    model_kwargs = dict(model_kwargs) if model_kwargs else {}
    tokenizer_kwargs = dict(tokenizer_kwargs) if tokenizer_kwargs else {}

    if device is not None:
        model_kwargs["device_map"] = device

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    # Bug fix: `TransformerTokenizer.__init__` expects an already-constructed
    # tokenizer object (it immediately reads `.eos_token_id`), not a model
    # name string — build the tokenizer first, then wrap it.
    tokenizer = TransformerTokenizer(
        AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
    )

    return Transformer(model, tokenizer)
|
||||
164
python/sglang/srt/hf_transformers_utils.py
Normal file
164
python/sglang/srt/hf_transformers_utils.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Utilities for Huggingface Transformers."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
|
||||
|
||||
def download_from_hf(model_path: str):
    """Return a local path for `model_path`, downloading from the HF Hub if needed.

    A path that already exists on disk is returned unchanged; otherwise the
    repo is snapshot-downloaded (config, weight, and tokenizer files only).
    """
    if not os.path.exists(model_path):
        model_path = snapshot_download(
            model_path, allow_patterns=["*.json", "*.bin", "*.model"]
        )
    return model_path
|
||||
|
||||
|
||||
def get_config_json(model_path: str):
    """Load and return the parsed `config.json` found under `model_path`."""
    config_path = os.path.join(model_path, "config.json")
    with open(config_path) as f:
        return json.load(f)
|
||||
|
||||
|
||||
def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
    """Fetch a model's huggingface `AutoConfig`, honoring remote-code trust and revision."""
    return AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision
    )
|
||||
|
||||
|
||||
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]


def get_context_length(config):
    """Get the context length of a model from a huggingface model config.

    The first matching key in `CONTEXT_LENGTH_KEYS` wins; its value is scaled
    by the rope-scaling factor when the config declares one. Falls back to
    2048 when no known key is present.
    """
    rope_scaling = getattr(config, "rope_scaling", None)
    scaling_factor = rope_scaling["factor"] if rope_scaling else 1

    for key in CONTEXT_LENGTH_KEYS:
        length = getattr(config, key, None)
        if length is not None:
            return int(scaling_factor * length)
    return 2048
|
||||
|
||||
|
||||
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
# Suggested as a drop-in replacement when loading slow LLaMA V1 tokenizers
# fails (see the error path in `get_tokenizer` below).
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
|
||||
|
||||
|
||||
def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface.

    For multimodal models the tokenizer is pulled out of the model's
    processor; otherwise `AutoTokenizer` is used, with extra diagnostics for
    the common LLaMA failure modes.
    """
    # Multimodal models bundle the tokenizer inside an AutoProcessor.
    if is_multimodal_model(tokenizer_name):
        processor = get_processor(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            **kwargs,
        )
        tokenizer = processor.tokenizer
        return tokenizer

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # Deliberately disabled advisory about slow LLaMA V1 fast-tokenizer init;
    # kept for reference.
    if (
        "llama" in tokenizer_name.lower()
        and kwargs.get("use_fast", True)
        and tokenizer_name != _FAST_LLAMA_TOKENIZER
    ):
        pass
        # warnings.warn(
        #     "For some LLaMA V1 models, initializing the fast tokenizer may "
        #     "take a long time. To reduce the initialization time, consider "
        #     f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
        #     "tokenizer."
        # )
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            **kwargs,
        )
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer."
        )
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if not trust_remote_code and (
            "does not exist or is not currently imported." in str(e)
            or "requires you to execute the tokenizer file" in str(e)
        ):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        warnings.warn(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )
    return tokenizer
|
||||
|
||||
|
||||
def get_processor(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
):
    """Load a huggingface `AutoProcessor` (tokenizer + feature processors) for `tokenizer_name`.

    `tokenizer_mode` is accepted for signature parity with `get_tokenizer`
    but is not used by `AutoProcessor`.
    """
    return AutoProcessor.from_pretrained(
        tokenizer_name,
        *args,
        trust_remote_code=trust_remote_code,
        tokenizer_revision=tokenizer_revision,
        **kwargs,
    )
|
||||
181
python/sglang/srt/layers/context_flashattention_nopad.py
Normal file
181
python/sglang/srt/layers/context_flashattention_nopad.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal flash-attention forward over variable-length packed sequences.

    Grid: (batch, q_head, q_block). Q/K/V are packed token-major without
    padding; `B_Start_Loc`/`B_Seqlen` give each sequence's start offset and
    length. `kv_group_num` maps query heads onto shared KV heads (GQA/MQA).
    Uses the online-softmax running max/sum (m_i, l_i) recurrence.
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Query heads share KV heads in groups of `kv_group_num`.
    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    # K is addressed transposed (d x n) so tl.dot(q, k) yields (m x n) scores.
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # 0/1 flag: skip the whole loop when this query block is past the sequence.
    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    # Causality: only iterate key blocks up to (and including) this q block.
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Per-element causal mask within the block.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
|
||||
|
||||
# Compiled-kernel launcher cached after the first launch to skip triton's
# dispatch overhead on subsequent calls.
cached_kernel = None


def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
    """Launch the causal prefill attention kernel over packed sequences.

    q/k/v/o: packed (total_tokens, heads, head_dim) tensors; results are
    written into `o` in place. `b_start_loc`/`b_seq_len` describe each
    sequence's slice of the packed dimension.
    """
    BLOCK = 128
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # Head dim must be a supported power of two (compile-time BLOCK_DMODEL).
    assert Lk in {16, 32, 64, 128}

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    # GQA/MQA ratio of query heads to KV heads.
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    global cached_kernel
    if cached_kernel:
        # Fast path: relaunch the previously compiled kernel.
        # NOTE(review): this path omits kv_group_num and the constexpr block
        # sizes — presumably baked into the compiled kernel by
        # `wrap_kernel_launcher`; confirm its semantics (and that shapes
        # requiring a different specialization can't reach here).
        cached_kernel(
            grid,
            num_warps,
            q,
            k,
            v,
            sm_scale,
            b_start_loc,
            b_seq_len,
            o,
            q.stride(0),
            q.stride(1),
            k.stride(0),
            k.stride(1),
            v.stride(0),
            v.stride(1),
            o.stride(0),
            o.stride(1),
        )
        return

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
||||
371
python/sglang/srt/layers/extend_attention.py
Normal file
371
python/sglang/srt/layers/extend_attention.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel(
    Q_Extend,
    K_Extend,
    V_Extend,
    O_Extend,
    K_Buffer,
    V_Buffer,
    Req_to_tokens,
    B_req_idx,
    B_Seq_Len,
    B_Start_Loc_Extend,
    B_Seq_Len_Extend,
    sm_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_req_to_tokens_b,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Two-stage attention for "extend" (chunked prefill) requests.

    Stage 1 attends the new (extend) queries to the cached prefix KV, found
    through the `Req_to_tokens` indirection table into K_Buffer/V_Buffer.
    Stage 2 attends causally within the contiguous extend tokens themselves.
    Both stages share one online-softmax state (acc, deno, e_max).
    Grid: (seq, q_head, q_block).
    """
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_block_m = tl.program_id(2)
    # Query heads share KV heads in groups of `kv_group_num`.
    cur_kv_head = cur_head // kv_group_num

    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend

    # Prefix tokens start at slot 0 of this request's Req_to_tokens row.
    cur_seq_prefix_start_in_loc = 0
    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = tl.arange(0, BLOCK_M)
    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
    offs_q = (
        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)

    # stage1: compute scores with prefix
    offs_n = tl.arange(0, BLOCK_N)

    # Online-softmax state: weighted value accumulator, denominator, running max.
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_len_prefix
        # Gather the prefix token locations from the indirection table.
        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
            cur_seq_prefix_start_in_loc + start_n + offs_n
        )
        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)

        # load k in transposed way
        offs_buf_k = (
            offs_kv_loc[None, :] * stride_buf_kbs
            + cur_kv_head * stride_buf_kh
            + offs_d[:, None]
        )
        k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # All prefix positions are visible; only mask out-of-range lanes.
        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

        offs_buf_v = (
            offs_kv_loc[:, None] * stride_buf_vbs
            + cur_kv_head * stride_buf_vh
            + offs_d[None, :]
        )
        v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)

        e_max = n_e_max

    # stage2: compute the triangle part

    # Causality: only iterate extend-key blocks up to this query block's end.
    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
    for start_n in range(0, cur_block_m_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_block_m_end

        # load k in transposed way
        offs_k = (
            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
            + cur_kv_head * stride_kh
            + offs_d[:, None]
        )
        k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Per-element causal mask within the extend tokens.
        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
            start_n + offs_n[None, :]
        )
        mask_causual &= mask_m[:, None] & mask_n[None, :]
        qk = tl.where(mask_causual, qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

        offs_v = (
            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
            + cur_kv_head * stride_vh
            + offs_d[None, :]
        )
        v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)

        e_max = n_e_max

    # Final normalization by the softmax denominator happens at the store.
    offs_o = (
        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
||||
|
||||
|
||||
def extend_attention_fwd(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    req_to_tokens,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    b_seq_len_prefix,
    b_start_loc_extend,
    b_seq_len_extend,
    max_len_in_batch,
    max_len_extend,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors

    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager

    Results are written into `o_extend` in place. `b_start_loc` /
    `b_seq_len_prefix` / `max_len_in_batch` are accepted for interface parity
    with `redundant_attention` but are not passed to the kernel.
    """
    BLOCK_M, BLOCK_N = 128, 128
    Lq, Lk, Lv, Lo = (
        q_extend.shape[-1],
        k_extend.shape[-1],
        v_extend.shape[-1],
        o_extend.shape[-1],
    )
    assert Lq == Lk and Lk == Lv and Lv == Lo
    # Head dim must be a supported power of two (compile-time BLOCK_DMODEL).
    assert Lq in {16, 32, 64, 128}

    sm_scale = 1.0 / (Lq**0.5)
    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
    # GQA/MQA ratio of query heads to KV heads.
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]

    # One program per (sequence, query head, block of extend tokens).
    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_warps = 4 if Lk <= 64 else 8
    num_stages = 1

    _fwd_kernel[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_seq_len,
        b_start_loc_extend,
        b_seq_len_extend,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),
        q_extend.stride(1),
        k_extend.stride(0),
        k_extend.stride(1),
        v_extend.stride(0),
        v_extend.stride(1),
        o_extend.stride(0),
        o_extend.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
        req_to_tokens.stride(0),
        BLOCK_DMODEL=Lq,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )
|
||||
|
||||
|
||||
def redundant_attention(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    req_to_tokens,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    b_seq_len_prefix,
    max_len_in_batch,
):
    """Reference path for extend attention: run full prefill attention and
    copy out only the extend portion of the output.

    Writes into `o_extend` in place; used to validate `extend_attention_fwd`.
    """
    total_tokens = k_buffer.shape[0]
    batch_size = b_req_idx.shape[0]
    num_heads, head_dim = q_extend.shape[-2], q_extend.shape[-1]

    q_buffer = torch.empty(
        (total_tokens, num_heads, head_dim),
        dtype=q_extend.dtype,
        device=q_extend.device,
    )

    # Scatter the packed extend queries into their buffer-layout slots
    # (each sequence's extend tokens sit right after its prefix).
    offset = 0
    for i in range(batch_size):
        extend_len = b_seq_len[i] - b_seq_len_prefix[i]
        left = b_start_loc[i] + b_seq_len_prefix[i]
        right = b_start_loc[i] + b_seq_len[i]
        q_buffer[left:right] = q_extend[offset : offset + extend_len]
        offset += extend_len

    o_buffer = torch.empty_like(q_buffer)
    context_attention_fwd(
        q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch
    )

    # Gather the extend slice of the output back into packed layout.
    offset = 0
    for i in range(batch_size):
        extend_len = b_seq_len[i] - b_seq_len_prefix[i]
        left = b_start_loc[i] + b_seq_len_prefix[i]
        right = b_start_loc[i] + b_seq_len[i]
        o_extend[offset : offset + extend_len] = o_buffer[left:right]
        offset += extend_len
|
||||
|
||||
|
||||
def test():
    """Compare the fused extend-attention kernel against the redundant
    reference implementation on random data (requires CUDA)."""
    torch.manual_seed(0)

    # Batch size, context length, query heads, KV heads (GQA), head dim.
    B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
    dtype = torch.float16

    # Random per-request prefix (cached) and extend (new) lengths.
    b_seq_len_prefix = torch.randint(
        1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
    )
    b_seq_len_extend = torch.randint(
        1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
    )
    b_seq_len = b_seq_len_prefix + b_seq_len_extend
    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

    # Identity request-to-token mapping: sequences are laid out contiguously.
    b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
    req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda")
    b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
    for i in range(B):
        req_to_tokens[i, : b_seq_len[i]] = torch.arange(
            b_start_loc[i], b_start_loc[i] + b_seq_len[i]
        )

    total_token_num = torch.sum(b_seq_len).item()
    extend_token_num = torch.sum(b_seq_len_extend).item()
    k_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device="cuda"
    ).normal_(mean=0.1, std=0.2)
    v_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device="cuda"
    ).normal_(mean=0.1, std=0.2)

    # Slice out the extend-phase K/V from the full buffers; queries are fresh.
    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
    for i in range(B):
        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
        extend_start = b_start_loc_extend[i]
        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
        k_extend[extend_start:extend_end] = k_buffer[
            extend_start_in_buffer:extend_end_in_buffer
        ]
        v_extend[extend_start:extend_end] = v_buffer[
            extend_start_in_buffer:extend_end_in_buffer
        ]
        q_extend[extend_start:extend_end] = torch.empty(
            (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)

    o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
    o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")

    # Recompute extend bookkeeping (same values as above) for the fused kernel.
    b_seq_len_extend = b_seq_len - b_seq_len_prefix
    b_start_loc_extend = torch.zeros_like(b_seq_len)
    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    redundant_attention(
        q_extend,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )

    print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant)))
    print("Max: ", torch.max(torch.abs(o_extend - o_redundant)))

    # fp16 tolerance: the two paths accumulate in different orders.
    assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the fused-vs-reference extend attention check (requires CUDA).
    test()
|
||||
79
python/sglang/srt/layers/get_selected_logprob.py
Normal file
79
python/sglang/srt/layers/get_selected_logprob.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
|
||||
@triton.jit
def _fwd_segmented_gather(
    all_logits,
    len_add_1,
    cum_len,
    input_ids,
    logprobs,
    max_seq_len,
    voc_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-request gather of selected token log-probabilities.

    One program per request. For request r with length cur_l = len_add_1[r]
    and exclusive cumulative end cum_l = cum_len[r], this reads, for each of
    the cur_l - 1 predicted positions, the logit of the NEXT input token
    (note the ``+ 1`` shift on input_ids) and stores it into the flat
    ``logprobs`` output. The output index subtracts ``cur_req`` because each
    request contributes one fewer entry than its token count.
    """
    cur_req = tl.program_id(0)
    cur_l = tl.load(len_add_1 + cur_req)
    cum_l = tl.load(cum_len + cur_req)

    # Walk the request's positions in BLOCK_SIZE-wide chunks; max_seq_len
    # bounds the loop uniformly across programs, the mask trims the tail.
    for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
        off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = off < cur_l - 1

        # Token predicted at position (cum_l - cur_l + off) is input_ids[... + 1].
        idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
        data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
        tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)
|
||||
|
||||
|
||||
cached_kernel = None
|
||||
|
||||
|
||||
def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
    """Fill ``logprobs`` with the per-position logits of the actual next
    tokens, segmented by request.

    Args:
        all_logits: (total_tokens, vocab) tensor of per-position logits.
        len_add_1: (num_reqs,) int tensor, each request's token count + 1.
        input_ids: flat (total_tokens,) int tensor of token ids.
        logprobs: preallocated flat output, one entry per predicted token
            (total_tokens - num_reqs entries). Filled in place.
    """
    cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
    voc_size = all_logits.shape[1]
    grid = (len_add_1.shape[0], 1, 1)
    max_seq_len = len_add_1.max().item()

    # Fast path: reuse the compiled kernel launcher to skip triton's
    # dispatch overhead on subsequent calls.
    global cached_kernel
    if cached_kernel:
        cached_kernel(
            grid,
            4,
            all_logits,
            len_add_1,
            cum_len,
            input_ids,
            logprobs,
            max_seq_len,
        )
        return

    # First call: JIT-compile, then cache the raw launcher.
    _fwd_segmented_gather[grid](
        all_logits,
        len_add_1,
        cum_len,
        input_ids,
        logprobs,
        max_seq_len,
        voc_size,
        BLOCK_SIZE=128,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test (requires CUDA): two requests of lengths 2 and 3; the
    # expected gathered values are noted in the trailing comment.
    all_logits = torch.tensor(
        # s s s
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
    logprobs = torch.empty((3), dtype=torch.float32, device="cuda")
    # BUG FIX: was `get_selected_logprobs` (undefined name — NameError at
    # runtime); the function defined above is `get_selected_logprob`.
    get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
    print(logprobs)
    # assert logprobs == [2, 2, 4]
|
||||
77
python/sglang/srt/layers/logits_processor.py
Normal file
77
python/sglang/srt/layers/logits_processor.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
|
||||
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
||||
from torch import nn
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
|
||||
|
||||
class LogitsProcessor(nn.Module):
    """Projects hidden states to vocabulary logits and optionally computes
    per-request length-normalized log-probabilities of the prompt tokens."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()

    @staticmethod
    def _last_token_index(input_metadata):
        """Index of each request's last token inside the flattened
        extend-token sequence (cumulative extend lengths, shifted by one)."""
        return (
            torch.cumsum(
                input_metadata.seq_lens - input_metadata.prefix_lens,
                dim=0,
                dtype=torch.long,
            )
            - 1
        )

    def forward(self, input_ids, hidden_states, weight, input_metadata):
        """Return ``(last_logits, normalized_logprobs)``.

        ``normalized_logprobs`` is None unless the batch requested it; in
        decode mode every hidden state is already a "last" state.
        """
        if not input_metadata.return_normalized_logprob:
            if input_metadata.forward_mode == ForwardMode.DECODE:
                last_hidden = hidden_states
            else:
                last_hidden = hidden_states[self._last_token_index(input_metadata)]
                hidden_states = None  # drop the reference early to free memory

            last_logits = torch.matmul(last_hidden, weight.T)
            if self.tp_size > 1:
                # Each TP rank holds a vocab shard; gather the full vocab.
                last_logits = tensor_model_parallel_all_gather(last_logits)
            # The gathered dim may be padded beyond the real vocab size.
            last_logits = last_logits[:, : self.config.vocab_size]
            return last_logits, None
        else:
            # Normalized logprobs need logits for every position, which only
            # exist in prefill/extend mode.
            assert input_metadata.forward_mode != ForwardMode.DECODE
            last_index = self._last_token_index(input_metadata)

            logits = torch.matmul(hidden_states, weight.T)
            if self.tp_size > 1:
                logits = tensor_model_parallel_all_gather(logits)
            logits = logits[:, : self.config.vocab_size]
            # epsilon avoids log(0) for fully-suppressed tokens
            all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)

            normalized_logprobs = compute_normalized_logprobs(
                all_logprobs,
                input_metadata.seq_lens - input_metadata.prefix_lens,
                input_ids,
            )

            last_logits = logits[last_index]
            return last_logits, normalized_logprobs
|
||||
|
||||
|
||||
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
    """Return each request's mean log-probability over its predicted tokens.

    Args:
        all_logprobs: (total_tokens, vocab) log-probabilities.
        len_add_1: (num_reqs,) int tensor, token count + 1 per request.
        input_ids: flat (total_tokens,) token ids.

    Returns:
        (num_reqs,) float tensor of per-request mean token log-probs.
    """
    # assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
    device = all_logprobs.device
    # BUG FIX: the original called `len_add_1.sub_(1)`, mutating the CALLER's
    # tensor in place; compute the predicted-token counts out of place.
    lens = len_add_1 - 1
    logprobs = torch.zeros(
        (all_logprobs.shape[0] - len_add_1.shape[0]),
        dtype=torch.float32,
        device=device,  # was hard-coded "cuda"; follow the input's device
    )
    get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
    # Segmented sums via prefix sums: sum over [start, end] per request.
    cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
    end = torch.cumsum(lens, dim=0)
    start = torch.cat((torch.zeros(1, dtype=end.dtype, device=device), end[:-1]), 0)
    end = end - 1
    sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
    return sum_logp / lens
|
||||
158
python/sglang/srt/layers/radix_attention.py
Normal file
158
python/sglang/srt/layers/radix_attention.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
||||
from torch import nn
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
|
||||
|
||||
class RadixAttention(nn.Module):
    """Attention layer that reads/writes the paged radix KV cache and
    dispatches to triton or flashinfer kernels per forward mode."""

    def __init__(
        self,
        num_heads,
        head_dim,
        scaling,
        num_kv_heads,
        layer_id,
    ):
        # NOTE(review): `scaling` is accepted but never stored or used here;
        # the triton kernels derive sm_scale from head_dim internally.
        super().__init__()

        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
        self.layer_id = layer_id

        # Imported lazily to avoid a circular import with model_runner.
        from sglang.srt.managers.router.model_runner import global_model_mode

        self.use_flashinfer = "flashinfer" in global_model_mode

        # Bind the per-mode implementations once at construction time.
        if self.use_flashinfer:
            self.prefill_forward = self.prefill_forward_flashinfer
            self.extend_forward = self.prefill_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.prefill_forward = self.prefill_forward_triton
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton

    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        """Prefill (no cached prefix): dense attention, then store K/V."""
        o = torch.empty_like(q)

        context_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            k,
            v,
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.max_seq_len,
        )
        self.store_kv_cache(k, v, input_metadata)

        return o

    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        """Extend (cached prefix present): K/V must be stored BEFORE the
        kernel runs, since it reads the full buffers via req_to_token."""
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

        extend_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.prefix_lens,
            input_metadata.extend_start_loc,
            input_metadata.extend_seq_lens,
            input_metadata.max_seq_len,
            input_metadata.max_extend_len,
        )

        return o

    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        """Decode: one new token per request; attend over the cached K/V."""
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

        token_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.max_seq_len,
            input_metadata.other_kv_index,
            input_metadata.total_num_tokens,
        )

        return o

    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        """Prefill/extend via the flashinfer ragged-batch prefill wrapper."""
        self.store_kv_cache(k, v, input_metadata)

        o = input_metadata.prefill_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.qo_indptr,
            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
            input_metadata.kv_indptr,
            input_metadata.kv_indices,
            input_metadata.kv_last_page_len,
            allow_fp16_qk_reduction=True,
        )

        return o.view(-1, self.tp_q_head_num * self.head_dim)

    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        """Decode via the flashinfer batched decode wrapper."""
        self.store_kv_cache(k, v, input_metadata)

        o = input_metadata.decode_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
            input_metadata.kv_indptr,
            input_metadata.kv_indices,
            input_metadata.kv_last_page_len,
        )

        return o.view(-1, self.tp_q_head_num * self.head_dim)

    def forward(self, q, k, v, input_metadata: InputMetadata):
        """Dispatch to the mode-specific implementation bound in __init__."""
        k = k.view(-1, self.tp_k_head_num, self.head_dim)
        v = v.view(-1, self.tp_v_head_num, self.head_dim)

        if input_metadata.forward_mode == ForwardMode.PREFILL:
            return self.prefill_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode == ForwardMode.EXTEND:
            return self.extend_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.decode_forward(q, k, v, input_metadata)

    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
        """Write this layer's new K/V into the pooled cache, either at
        scattered slot indices or into one contiguous range."""
        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
        if input_metadata.out_cache_loc is not None:
            key_buffer[input_metadata.out_cache_loc] = cache_k
            value_buffer[input_metadata.out_cache_loc] = cache_v
        elif input_metadata.out_cache_cont_start is not None:
            key_buffer[
                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
            ] = cache_k
            value_buffer[
                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
            ] = cache_v
        else:
            # Exactly one cache-destination form must be provided.
            raise RuntimeError()
|
||||
324
python/sglang/srt/layers/token_attention.py
Normal file
324
python/sglang/srt/layers/token_attention.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# Adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    sm_scale,
    Req_to_tokens,
    B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    att_stride_h,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Decode attention, stage 1: scaled q·k scores.

    Grid: (batch, q_heads, ceil(max_seq_len / BLOCK_N)). Each program scores
    one query head of one request against one BLOCK_N-wide chunk of its
    cached keys, writing raw scaled dot products into Att_Out.
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_n = tl.program_id(2)

    # GQA: kv_group_num query heads share one KV head.
    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)

    cur_batch_start_index = 0
    cur_batch_end_index = cur_batch_seq_len

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

    block_stard_index = start_n * BLOCK_N
    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)

    # Trick: the loop runs once if this chunk is within the sequence and
    # zero times otherwise (block_mask is 0 or 1), avoiding a divergent if.
    for start_mark in range(0, block_mask, 1):
        q = tl.load(Q + off_q + start_mark)
        offs_n_new = cur_batch_start_index + offs_n
        # Translate logical positions to physical KV slots via the req map.
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        offs_buf_k = (
            k_loc[:, None] * stride_buf_kbs
            + cur_kv_head * stride_buf_kh
            + offs_d[None, :]
        )
        k = tl.load(
            K_Buffer + offs_buf_k,
            mask=offs_n_new[:, None] < cur_batch_end_index,
            other=0.0,
        )
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale
        off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
        tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel_stage2(
    Logics,
    V_Buffer,
    Out,
    Req_to_tokens,
    B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    stride_logic_h,
    stride_buf_vbs,
    stride_buf_vh,
    stride_obs,
    stride_oh,
    stride_req_to_token_b,
    other_kv_index,  # To fix a NAN issue
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Decode attention, stage 2: online softmax over the stage-1 scores and
    the weighted sum with V.

    Grid: (batch, q_heads). Uses the streaming (running-max) softmax so the
    whole sequence never has to fit in registers at once.
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
    v_ptrs = V_Buffer + offs_buf_v

    # Running softmax state: max, normalizer, and unnormalized accumulator.
    e_max = float("-inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Out-of-range lanes read a known-valid slot (other_kv_index) so the
        # V load never touches garbage memory; their weight is exp(-inf)=0.
        v_index = tl.load(
            Req_to_tokens
            + cur_batch_req_idx * stride_req_to_token_b
            + (start_n + offs_n),
            mask=(start_n + offs_n) < cur_batch_seq_len,
            other=other_kv_index,
        )

        qk = tl.load(
            Logics
            + cur_head * stride_logic_h
            + (cur_batch_start_loc + start_n + offs_n),
            mask=start_n + offs_n < cur_batch_seq_len,
            other=float("-inf"),
        )

        # Online softmax update: rescale previous state by exp(old_max - new_max).
        n_e_max = tl.maximum(tl.max(qk, 0), e_max)
        old_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max)
        e_sum = e_sum * old_scale + tl.sum(p, 0)
        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
        acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
        e_max = n_e_max

    acc = acc / e_sum
    off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
cached_kernel_stage1 = None
|
||||
cached_kernel_stage2 = None
|
||||
|
||||
|
||||
def _token_att_m_fwd(
    q,
    k_buffer,
    att_out,
    Req_to_tokens,
    B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    max_len_in_batch,
):
    """Launch stage 1 (q·k scores) of decode attention, caching the compiled
    kernel launcher after the first call."""
    BLOCK = 32
    # shape constraints
    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
    assert Lq == Lk
    assert Lk in {16, 32, 64, 128}
    sm_scale = 1.0 / (Lk**0.5)

    batch, head_num = B_req_idx.shape[0], q.shape[1]

    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
    # GQA group size: query heads per KV head.
    kv_group_num = q.shape[1] // k_buffer.shape[1]

    if kv_group_num == 1:
        num_warps = 4
    else:
        num_warps = 2

    # Fast path: reuse the cached launcher to skip triton dispatch overhead.
    global cached_kernel_stage1
    if cached_kernel_stage1:
        cached_kernel_stage1(
            grid,
            num_warps,
            q,
            k_buffer,
            sm_scale,
            Req_to_tokens,
            B_req_idx,
            B_Start_Loc,
            B_Seqlen,
            att_out,
            Req_to_tokens.stride(0),
            q.stride(0),
            q.stride(1),
            k_buffer.stride(0),
            k_buffer.stride(1),
            att_out.stride(0),
        )
        return

    _fwd_kernel_stage1[grid](
        q,
        k_buffer,
        sm_scale,
        Req_to_tokens,
        B_req_idx,
        B_Start_Loc,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        att_out.stride(0),
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
|
||||
|
||||
|
||||
def _token_softmax_reducev_fwd(
    logics,
    v_buffer,
    o,
    req_to_tokens,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    other_kv_index,
):
    """Launch stage 2 (softmax + V reduction) of decode attention, caching
    the compiled kernel launcher after the first call."""
    BLOCK = 64
    batch, head = b_seq_len.shape[0], logics.shape[0]
    grid = (batch, head, 1)
    kv_group_num = logics.shape[0] // v_buffer.shape[1]

    num_warps = 1

    # Fast path: reuse the cached launcher to skip triton dispatch overhead.
    global cached_kernel_stage2
    if cached_kernel_stage2:
        cached_kernel_stage2(
            grid,
            num_warps,
            logics,
            v_buffer,
            o,
            req_to_tokens,
            b_req_idx,
            b_start_loc,
            b_seq_len,
            logics.stride(0),
            v_buffer.stride(0),
            v_buffer.stride(1),
            o.stride(0),
            o.stride(1),
            req_to_tokens.stride(0),
            other_kv_index,
        )
        return

    _fwd_kernel_stage2[grid](
        logics,
        v_buffer,
        o,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        logics.stride(0),
        v_buffer.stride(0),
        v_buffer.stride(1),
        o.stride(0),
        o.stride(1),
        req_to_tokens.stride(0),
        other_kv_index,
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=v_buffer.shape[-1],
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=3,
    )
    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
|
||||
|
||||
|
||||
def token_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    max_len_in_batch,
    other_kv_index,
    total_num_tokens,
    att_m=None,
):
    """Decode-phase attention over the paged KV cache, in two stages:
    stage 1 computes scaled q·k scores into ``att_m``; stage 2 applies an
    online softmax and reduces against V into ``o``.

    ``att_m`` may be passed in to reuse scratch space across layers;
    otherwise a (num_heads, total_num_tokens) buffer is allocated here.
    """
    if att_m is None:
        att_m = torch.empty(
            (q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda"
        )

    _token_att_m_fwd(
        q,
        k_buffer,
        att_m,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
    )
    _token_softmax_reducev_fwd(
        att_m,
        v_buffer,
        o,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        other_kv_index,
    )
|
||||
85
python/sglang/srt/managers/detokenizer_manager.py
Normal file
85
python/sglang/srt/managers/detokenizer_manager.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_exception_traceback
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
|
||||
class DetokenizerManager:
    """Standalone process that turns token-id batches from the router into
    strings and forwards them to the tokenizer manager over ZeroMQ."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # PULL from the router, PUSH to the tokenizer manager.
        context = zmq.asyncio.Context(2)
        self.recv_from_router = context.socket(zmq.PULL)
        self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")

        self.send_to_tokenizer = context.socket(zmq.PUSH)
        self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.tokenizer = get_tokenizer(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )

    async def handle_loop(self):
        """Receive BatchTokenIDOut messages forever, decode, trim stop
        strings, and emit BatchStrOut messages."""
        while True:
            recv_obj = await self.recv_from_router.recv_pyobj()

            if isinstance(recv_obj, BatchTokenIDOut):
                output_tokens = recv_obj.output_tokens

                # TODO(lmzheng): handle skip_special_tokens per request
                output_strs = self.tokenizer.batch_decode(
                    output_tokens,
                    skip_special_tokens=recv_obj.skip_special_tokens[0],
                )

                # Trim stop str
                # TODO(lmzheng): handle the case where multiple stop strs are hit
                for i in range(len(output_strs)):
                    if recv_obj.hit_stop_str[i] is not None:
                        pos = output_strs[i].find(recv_obj.hit_stop_str[i])
                        if pos != -1:
                            output_strs[i] = output_strs[i][:pos]

                    # SentencePiece marks a leading space with "▁" on the
                    # token; batch_decode drops it, so restore the space.
                    if len(output_tokens[i]) > 0:
                        first_token = self.tokenizer.convert_ids_to_tokens(
                            int(output_tokens[i][0])
                        )
                        if first_token.startswith("▁"):
                            output_strs[i] = " " + output_strs[i]

                self.send_to_tokenizer.send_pyobj(
                    BatchStrOut(
                        recv_obj.rids,
                        output_strs,
                        recv_obj.meta_info,
                        recv_obj.finished,
                    )
                )
            else:
                raise ValueError(f"Invalid object: {recv_obj}")
|
||||
|
||||
|
||||
def start_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    """Subprocess entry point: build a DetokenizerManager and run its
    receive loop forever.

    Reports "init ok" (or a traceback string on failure) to the parent
    through ``pipe_writer`` before entering the loop.
    """
    try:
        manager = DetokenizerManager(server_args, port_args)
    except Exception:
        # Send the traceback to the parent first, then let the process die.
        # (Was `except Exception as e:` with `e` unused.)
        pipe_writer.send(get_exception_traceback())
        raise
    pipe_writer.send("init ok")
    loop = asyncio.get_event_loop()
    loop.run_until_complete(manager.handle_loop())
|
||||
88
python/sglang/srt/managers/io_struct.py
Normal file
88
python/sglang/srt/managers/io_struct.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@dataclass
class GenerateReqInput:
    """A generation request as received by the HTTP frontend; may hold a
    single prompt or a batch of prompts."""

    # Single prompt or a batch of prompts.
    text: Union[List[str], str]
    # Optional image(s) for multi-modal models.
    image_data: Optional[Union[List[str], str]] = None
    # Sampling parameter dict(s); None means defaults.
    sampling_params: Union[List[Dict], Dict] = None
    # Request id(s); generated if not supplied.
    rid: Optional[Union[List[str], str]] = None
    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
    stream: bool = False

    def post_init(self):
        """Fill in defaults; for batched input, broadcast scalar fields so
        every per-request field becomes a list of length len(text)."""
        if isinstance(self.text, str):
            # Single request: substitute scalar defaults for missing fields.
            scalar_defaults = {
                "sampling_params": {},
                "rid": uuid.uuid4().hex,
                "return_normalized_logprob": False,
                "normalized_logprob_start_len": 0,
            }
            for field_name, default in scalar_defaults.items():
                if getattr(self, field_name) is None:
                    setattr(self, field_name, default)
        else:
            num = len(self.text)

            def broadcast(field_name, default):
                # None -> list of defaults; scalar -> repeated scalar.
                val = getattr(self, field_name)
                if val is None:
                    setattr(self, field_name, [default] * num)
                elif not isinstance(val, list):
                    setattr(self, field_name, [val] * num)

            broadcast("image_data", None)
            broadcast("sampling_params", {})
            broadcast("return_normalized_logprob", False)
            broadcast("normalized_logprob_start_len", 0)

            # rid is special: each request needs a DISTINCT fresh id.
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list)
|
||||
|
||||
|
||||
@dataclass
class TokenizedGenerateReqInput:
    """A single tokenized request, sent from the tokenizer manager to the
    router (one object per request, unlike the batched GenerateReqInput)."""

    # Unique request id.
    rid: str
    # Prompt token ids.
    input_ids: List[int]
    # Preprocessed image tensor data (empty/None semantics set by the sender).
    pixel_values: List[float]
    # Hash of the raw image, used as a cache key.
    image_hash: int
    sampling_params: SamplingParams
    return_normalized_logprob: bool
    normalized_logprob_start_len: int
    # Whether the client requested streaming output.
    stream: bool
|
||||
|
||||
|
||||
@dataclass
class BatchTokenIDOut:
    """A batch of finished/in-progress outputs as token ids, sent from the
    router to the detokenizer. All lists are parallel, indexed by request."""

    rids: List[str]
    output_tokens: List[List[int]]
    # Stop string that terminated each request, or None.
    hit_stop_str: List[Optional[str]]
    skip_special_tokens: List[bool]
    meta_info: List[Dict]
    finished: List[bool]
|
||||
|
||||
|
||||
@dataclass
class BatchStrOut:
    """A batch of detokenized outputs, sent from the detokenizer to the
    tokenizer manager. All lists are parallel, indexed by request."""

    rids: List[str]
    output_str: List[str]
    meta_info: List[Dict]
    finished: List[bool]
|
||||
12
python/sglang/srt/managers/openai_protocol.py
Normal file
12
python/sglang/srt/managers/openai_protocol.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
|
||||
@dataclass
class CompletionRequest:
    """Subset of the OpenAI /v1/completions request body supported here."""

    # Prompt text, or a list (tokens / multiple prompts per OpenAI schema).
    prompt: Union[str, List[Any]]
    model: str = "default"
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 16
    # Number of completions to generate per prompt.
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
|
||||
326
python/sglang/srt/managers/router/infer_batch.py
Normal file
326
python/sglang/srt/managers/router/infer_batch.py
Normal file
@@ -0,0 +1,326 @@
|
||||
from enum import Enum, auto
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sglang.srt.managers.router.radix_cache import RadixCache
|
||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
|
||||
|
||||
class ForwardMode(Enum):
    """Which phase of generation a model forward pass serves."""

    # Full prompt, no cached prefix.
    PREFILL = auto()
    # Prompt with a radix-cache prefix hit; only new tokens are computed.
    EXTEND = auto()
    # One new token per request.
    DECODE = auto()
|
||||
|
||||
|
||||
class FinishReason(Enum):
    """Why a request stopped generating."""

    # Hit max_new_tokens.
    LENGTH = auto()
    # Generated the EOS token (and ignore_eos was off).
    EOS_TOKEN = auto()
    # A user-provided stop string appeared in the output.
    STOP_STR = auto()
|
||||
|
||||
class Req:
    """Router-side state of one in-flight generation request."""

    def __init__(self, rid):
        self.rid = rid
        self.input_ids = []
        self.output_ids = []
        # Multi-modal inputs (None for text-only requests).
        self.pixel_values = None
        self.image_offset = 0
        self.sampling_params = None
        self.return_normalized_logprob = False
        self.normalized_logprob_start_len = 0
        self.stream = False

        self.tokenizer = None
        self.finished = False
        self.finish_reason = None
        # The stop string that terminated this request, if any.
        self.hit_stop_str = None

        self.adjust_input_len = 0
        # KV-cache slots covered by a radix-cache prefix hit.
        self.prefix_indices = []

        self.normalized_logprob = None

        # for constrained decoding
        self.regex_fsm = None
        self.regex_fsm_state = None

    def max_new_tokens(self):
        """Maximum number of new tokens this request may generate."""
        return self.sampling_params.max_new_tokens

    def check_finished(self):
        """Mark the request finished if a length / EOS / stop-string
        condition holds; record the reason. No-op if already finished."""
        if self.finished:
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished = True
            self.finish_reason = FinishReason.LENGTH
            return

        if (
            self.output_ids[-1] == self.tokenizer.eos_token_id
            # was `self.sampling_params.ignore_eos == False`
            and not self.sampling_params.ignore_eos
        ):
            self.finished = True
            self.finish_reason = FinishReason.EOS_TOKEN
            return

        if len(self.sampling_params.stop_strs) > 0:
            # Only decode the tail long enough to contain any stop string.
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str:
                    self.finished = True
                    self.finish_reason = FinishReason.STOP_STR
                    self.hit_stop_str = stop_str
                    return

    def __repr__(self):
        return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
|
||||
|
||||
|
||||
class Batch:
    """A group of requests that run forward passes together.

    Owns the batched GPU tensors (input ids, sequence lengths, sampling
    parameters) and the bookkeeping against the request-to-token and
    token-to-KV memory pools plus the radix prefix cache.
    """

    def __init__(
        self,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: TokenToKVPool,
        tree_cache: RadixCache,
    ):
        self.reqs = reqs
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool
        self.tree_cache = tree_cache

        # True if any request in the batch wants the prompt's
        # normalized logprob returned.
        self.return_normalized_logprob = any(
            req.return_normalized_logprob for req in reqs
        )

    def is_empty(self):
        """True when all requests have been filtered out."""
        return len(self.reqs) == 0

    def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
        """Build all batched tensors for an EXTEND (prefill-with-prefix) pass.

        Allocates request slots and KV-cache space (evicting from the radix
        cache if needed), writes cached-prefix and new-token locations into
        the req_to_token table, and materializes per-request sampling
        parameter tensors on the GPU.
        """
        device = "cuda"  # the router assumes CUDA throughout
        bs = len(self.reqs)
        reqs = self.reqs
        # Only the non-cached suffix of each prompt is fed forward.
        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
        prefix_indices = [r.prefix_indices for r in reqs]

        # Handle prefix
        flatten_input_ids = []
        extend_lens = []
        prefix_lens = []
        seq_lens = []

        req_pool_indices = self.req_to_token_pool.alloc(bs)
        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        for i in range(bs):
            flatten_input_ids.extend(input_ids[i])
            extend_lens.append(len(input_ids[i]))

            if len(prefix_indices[i]) == 0:
                prefix_lens.append(0)
            else:
                prefix_lens.append(len(prefix_indices[i]))
                # Reuse the cached prefix: point the request's token table
                # at the existing KV locations.
                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                    : len(prefix_indices[i])
                ] = prefix_indices[i]

            seq_lens.append(prefix_lens[-1] + extend_lens[-1])

        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)

        # Alloc mem
        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
        if out_cache_loc is None:
            # Pool exhausted: evict unreferenced prefix-cache entries and retry.
            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

            if out_cache_loc is None:
                # Unrecoverable — the scheduler should have prevented this.
                print("Prefill out of memory.")
                self.tree_cache.pretty_print()
                exit()

        # Record where each request's newly-computed tokens will live.
        pt = 0
        for i in range(bs):
            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
            ] = out_cache_loc[pt : pt + extend_lens[i]]
            pt += extend_lens[i]

        # Handle logit bias
        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
        for i in range(bs):
            # dtype == "int" restricts sampling to integer-like tokens.
            if reqs[i].sampling_params.dtype == "int":
                logit_bias[i] = int_token_logit_bias

        # Set fields
        self.input_ids = torch.tensor(
            flatten_input_ids, dtype=torch.int32, device=device
        )
        self.pixel_values = [r.pixel_values for r in reqs]
        # Image offsets are relative to the non-cached part of the prompt.
        self.image_offsets = [
            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
        ]
        self.req_pool_indices = req_pool_indices
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
        self.position_ids_offsets = position_ids_offsets
        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc

        # Per-request sampling parameters as column vectors / 1-D tensors.
        self.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        self.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        ).view(-1, 1)
        self.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        ).view(-1, 1)
        self.frequency_penalties = torch.tensor(
            [r.sampling_params.frequency_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.presence_penalties = torch.tensor(
            [r.sampling_params.presence_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.logit_bias = logit_bias

    def update_for_decode(self, input_ids=None):
        """Advance the batch by one DECODE step.

        Feeds each request's last token, bumps sequence lengths, and
        allocates one KV slot per request (contiguously when possible).
        """
        if input_ids is None:
            # Last generated token, or last prompt token on the first decode.
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
            ]
        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
        self.seq_lens.add_(1)
        self.prefix_lens = None

        # Alloc mem
        bs = len(self.reqs)
        # Try a contiguous allocation first (enables a faster KV write path).
        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
        if alloc_res is None:
            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

            if self.out_cache_loc is None:
                self.tree_cache.evict(bs, self.token_to_kv_pool.free)
                self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

                if self.out_cache_loc is None:
                    print("Decode out of memory.")
                    self.tree_cache.pretty_print()
                    exit()

            self.out_cache_cont_start = None
            self.out_cache_cont_end = None
        else:
            self.out_cache_loc = alloc_res[0]
            self.out_cache_cont_start = alloc_res[1]
            self.out_cache_cont_end = alloc_res[2]

        # Record the new token's KV location for every request.
        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

    def filter_batch(self, unfinished_indices: List[int]):
        """Keep only the requests at *unfinished_indices*; drop the rest's
        batched tensors (their pool slots are freed by the caller)."""
        self.reqs = [self.reqs[i] for i in unfinished_indices]
        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
        self.seq_lens = self.seq_lens[new_indices]
        self.input_ids = None
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.prefix_lens = None
        self.position_ids_offsets = self.position_ids_offsets[new_indices]
        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
            "logit_bias",
        ]:
            setattr(self, item, getattr(self, item)[new_indices])

    def merge(self, other):
        """Concatenate *other*'s requests and batched tensors onto this batch."""
        self.reqs.extend(other.reqs)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.prefix_lens = None
        self.position_ids_offsets = torch.concat(
            [self.position_ids_offsets, other.position_ids_offsets]
        )
        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
            "logit_bias",
        ]:
            setattr(
                self, item, torch.concat([getattr(self, item), getattr(other, item)])
            )

    def sample(self, logits: torch.Tensor):
        """Sample one next token per request from *logits*.

        Applies temperature, logit bias, optional regex-FSM masking, and
        top-p/top-k filtering; advances each request's FSM state after
        sampling. Returns (token_ids, token_probs), both shape (bs,).
        NOTE: mutates *logits* in place (div_/add_/masked_fill_).
        """
        # Post process logits
        logits = logits.contiguous()
        logits.div_(self.temperatures)
        logits.add_(self.logit_bias)

        has_regex = any(req.regex_fsm is not None for req in self.reqs)
        if has_regex:
            # One reusable mask buffer; tokens outside the FSM's allowed set
            # are forced to -inf.
            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    allowed_mask.zero_()
                    allowed_mask[
                        req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
                    ] = 1
                    logits[i].masked_fill_(~allowed_mask, float("-inf"))

        # TODO(lmzheng): apply penalty
        probs = torch.softmax(logits, dim=-1)
        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
        # Map the sampled position back to the original vocabulary index.
        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
            -1
        )
        batch_next_token_probs = torch.gather(
            probs_sort, dim=1, index=sampled_index
        ).view(-1)

        if has_regex:
            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.next_state(
                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
                    )

        return batch_next_token_ids, batch_next_token_probs
|
||||
|
||||
|
||||
def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
|
||||
probs_sort[
|
||||
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
|
||||
] = 0.0
|
||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||
return probs_sort, probs_idx
|
||||
71
python/sglang/srt/managers/router/manager.py
Normal file
71
python/sglang/srt/managers/router/manager.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from sglang.srt.managers.router.model_rpc import ModelRpcClient
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_exception_traceback
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
|
||||
class RouterManager:
    """Pumps requests from the tokenizer into the model and results out
    to the detokenizer over ZeroMQ sockets."""

    def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
        # Init communication
        context = zmq.asyncio.Context(2)
        # PULL: tokenized requests arriving from the tokenizer manager.
        self.recv_from_tokenizer = context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")

        # PUSH: batched token-id outputs toward the detokenizer.
        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(
            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
        )

        # Init status
        self.model_client = model_client
        # Requests buffered by loop_for_recv_requests until the next step.
        self.recv_reqs = []

    async def loop_for_forward(self):
        """Forever: hand buffered requests to the model, forward outputs."""
        while True:
            # Snapshot-and-clear so new arrivals go into the next step.
            next_step_input = list(self.recv_reqs)
            self.recv_reqs = []
            out_pyobjs = await self.model_client.step(next_step_input)

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

            # await for a while to accept input requests
            await asyncio.sleep(0.001)

    async def loop_for_recv_requests(self):
        """Forever: buffer incoming tokenized requests."""
        while True:
            recv_req = await self.recv_from_tokenizer.recv_pyobj()
            self.recv_reqs.append(recv_req)
|
||||
|
||||
|
||||
def start_router_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    """Entry point of the router subprocess.

    Initializes the model client and router, reports success/failure to the
    parent over *pipe_writer*, then runs both event loops forever.
    """
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        model_client = ModelRpcClient(server_args, port_args)
        router = RouterManager(model_client, port_args)
    except Exception:
        # Ship the traceback to the parent process before dying.
        pipe_writer.send(get_exception_traceback())
        raise

    pipe_writer.send("init ok")

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Receiver runs as a background task; the forward loop blocks forever.
    loop.create_task(router.loop_for_recv_requests())
    loop.run_until_complete(router.loop_for_forward())
|
||||
497
python/sglang/srt/managers/router/model_rpc.py
Normal file
497
python/sglang/srt/managers/router/model_rpc.py
Normal file
@@ -0,0 +1,497 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import rpyc
|
||||
import torch
|
||||
from rpyc.utils.classic import obtain
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
|
||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
from sglang.srt.managers.router.radix_cache import RadixCache
|
||||
from sglang.srt.managers.router.scheduler import Scheduler
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_exception_traceback,
|
||||
get_int_token_logit_bias,
|
||||
is_multimodal_model,
|
||||
set_random_seed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("model_rpc")
|
||||
|
||||
|
||||
class ModelRpcServer(rpyc.Service):
    """RPC-exposed model worker: owns the model runner, KV memory pools,
    radix prefix cache, scheduler, and the continuous-batching loop."""

    def exposed_init_model(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Load the model/tokenizer and initialize caches and run state.

        Called once per tensor-parallel rank before any step() calls.
        """
        # Deep-copy netrefs into local objects when called over rpyc.
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.model_mode = server_args.model_mode
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path, server_args.trust_remote_code
        )
        self.model_runner = ModelRunner(
            self.model_config,
            server_args.mem_fraction_static,
            tp_rank,
            server_args.tp_size,
            port_args.nccl_port,
            server_args.load_format,
            server_args.trust_remote_code,
            server_args.model_mode,
        )
        if is_multimodal_model(server_args.model_path):
            # Multimodal models need a processor; its tokenizer is reused.
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.eos_token_id = self.tokenizer.eos_token_id
        self.max_total_num_token = self.model_runner.max_total_num_token
        # Heuristic capacity limits derived from the KV pool size.
        self.max_num_running_seq = self.max_total_num_token // 2
        self.max_prefill_num_token = max(
            self.model_config.context_len, self.max_total_num_token // 6
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)
        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
            f"model_mode={self.model_mode}"
        )

        # Init cache
        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        # Stream partial outputs every N decode steps.
        self.stream_interval = 2

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(self.tokenizer)

    def exposed_step(self, recv_reqs):
        """Accept new requests, run one scheduling step, return outputs."""
        if self.tp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            # Keep the server alive; surface the traceback in the log.
            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        """Run one step: prefer a new fill batch, else decode the running batch."""
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run new fill batch
            self.forward_fill_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.forward_decode_batch(self.running_batch)

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                # Periodic utilization log (rank 0 only).
                if self.running_batch is not None and self.tp_rank == 0:
                    if self.decode_forward_ct >= 20:
                        self.decode_forward_ct = 0
                        num_used = self.max_total_num_token - (
                            self.token_to_kv_pool.available_size()
                            + self.tree_cache.evictable_size()
                        )
                        logger.info(
                            f"#running-req: {len(self.running_batch.reqs)}, "
                            f"#token: {num_used}, "
                            f"token usage: {num_used / self.max_total_num_token:.2f}, "
                            f"#queue-req: {len(self.forward_queue)}"
                        )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Convert an incoming request into a Req and queue it for prefill."""
        req = Req(recv_req.rid)
        req.input_ids = recv_req.input_ids
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            # Derive deterministic pad token ids from the image hash so the
            # padded region is content-addressable in the prefix cache.
            pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                req.input_ids, pad_value
            )
        req.sampling_params = recv_req.sampling_params
        req.return_normalized_logprob = recv_req.return_normalized_logprob
        req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # init the regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm_state = 0
            req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)

        # Truncate long prompts
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        """Select queued requests that fit in memory and build a fill Batch.

        Returns None when nothing can be admitted (capacity or memory).
        """
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_num_running_seq
        ):
            return None

        # Match each queued prompt against the radix prefix cache.
        for req in self.forward_queue:
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_normalized_logprob:
                # Keep enough uncached tokens to compute the logprob.
                prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
            req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0
        new_batch_prefix_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        new_ratio = self.scheduler.new_token_estimation_ratio()
        if self.running_batch:
            # Reserve estimated decode headroom for already-running requests.
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * new_ratio
                    for r in self.running_batch.reqs
                ]
            )

        for req in self.forward_queue:
            if req.return_normalized_logprob:
                # Need at least two tokens to compute normalized logprob
                if req.adjust_input_len < 2:
                    delta = 2 - req.adjust_input_len
                    req.adjust_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.adjust_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.adjust_input_len + new_batch_input_tokens
                < self.max_prefill_num_token
            ):
                # Pin the matched prefix; this may reduce evictable space.
                delta = self.tree_cache.inc_ref_counter(req.last_node)
                available_size += delta

                if not (
                    req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo: pinning made the request no longer fit.
                    delta = self.tree_cache.dec_ref_counter(req.last_node)
                    available_size += delta
                else:
                    # Admit the request.
                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.adjust_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.adjust_input_len

        if len(can_run_list) == 0:
            return None

        if self.tp_rank == 0:
            logger.info(
                f"new fill batch. #seq: {len(can_run_list)}. "
                f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

        new_batch = Batch(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        """Run the EXTEND pass for *batch* and record first sampled tokens."""
        # Build batch tensors
        batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
        if batch.extend_num_tokens != 0:
            # Forward
            logits, normalized_logprobs = self.model_runner.forward(
                batch, ForwardMode.EXTEND, batch.return_normalized_logprob
            )
            # print("extend logits", logits)
            if normalized_logprobs is not None:
                normalized_logprobs = normalized_logprobs.cpu().tolist()

            next_token_ids, next_token_probs = batch.sample(logits)
            next_token_ids = next_token_ids.cpu().tolist()
        else:
            # Fully cached prompts: nothing to run; emit EOS directly.
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
            normalized_logprobs = None

        # Check finish condition
        reqs = batch.reqs
        for i in range(len(reqs)):
            reqs[i].output_ids = [next_token_ids[i]]
            reqs[i].check_finished()

            if normalized_logprobs is not None:
                reqs[i].normalized_logprob = normalized_logprobs[i]

        self.handle_finished_requests(batch)

    def forward_decode_batch(self, batch: Batch):
        """Run one DECODE step for *batch* and append sampled tokens."""
        # Update batch tensors
        self.decode_forward_ct += 1
        batch.update_for_decode()

        # Forward
        logits = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)
        next_token_ids = next_token_ids.cpu().tolist()

        # Check finish condition
        reqs = batch.reqs
        for i in range(len(reqs)):
            reqs[i].output_ids.append(next_token_ids[i])
            reqs[i].check_finished()

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        """Emit outputs for finished/streaming requests and reclaim memory.

        Finished requests are inserted into the radix cache (so their KV can
        be reused), their pool slots freed, and the batch filtered in place.
        """
        output_rids = []
        output_tokens = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            # Output on finish, or periodically for streaming requests.
            if req.finished or (
                req.stream and self.decode_forward_ct % self.stream_interval == 0
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                meta_info = {
                    "prompt_tokens": len(req.input_ids),
                    "completion_tokens": len(req.output_ids),
                }
                if req.return_normalized_logprob:
                    meta_info["normalized_logprob"] = req.normalized_logprob
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_tokens,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                req_pool_idx = req_pool_indices_cpu[i]
                token_ids = tuple(req.input_ids + req.output_ids)
                seq_len = len(token_ids) - 1
                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
                prefix_len = self.tree_cache.insert(token_ids, indices.clone())

                # Free the part now owned by the radix cache, then the slot.
                self.token_to_kv_pool.free(indices[:prefix_len])
                self.req_to_token_pool.free(req_pool_idx)
                self.tree_cache.dec_ref_counter(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []
|
||||
|
||||
|
||||
class ModelRpcClient:
    """Client-side facade over one or more ModelRpcServer workers.

    With tp_size == 1 the server runs in-process and calls are wrapped in
    trivial coroutines; otherwise one rpyc server process is spawned per
    tensor-parallel rank and calls fan out asynchronously to all of them.
    """

    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcServer()
            self.model_server.exposed_init_model(0, server_args, port_args)

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes.
                # Bug fix: executor.map returns a one-shot iterator; the
                # original iterated it twice, leaving self.procs empty.
                # Materialize it once before unpacking.
                rets = list(executor.map(start_model_process, port_args.model_rpc_ports))
                self.model_servers = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.model_servers[i].init_model(i, server_args, port_args)

                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    # Fan out to all ranks; they return identical results,
                    # so only rank 0's value is materialized.
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")
|
||||
|
||||
|
||||
def start_model_process(port):
    """Spawn a ModelRpcServer process listening on *port* and connect to it.

    Retries the connection for up to ~20 seconds while the server starts.
    Returns (remote_root, process); raises RuntimeError if the server never
    comes up.
    """
    def _init_service(port):
        t = ThreadedServer(
            ModelRpcServer(),
            port=port,
            # allow_pickle: request payloads are Python objects; long
            # sync_request_timeout covers slow model loading.
            protocol_config={"allow_pickle": True, "sync_request_timeout": 600},
        )
        t.start()

    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)

    # Poll until the rpyc server accepts connections (max 20 attempts).
    repeat_count = 0
    while repeat_count < 20:
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={"allow_pickle": True, "sync_request_timeout": 600},
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
        repeat_count += 1
    if repeat_count == 20:
        raise RuntimeError("init rpc env error!")

    assert proc.is_alive()
    return con.root, proc
|
||||
458
python/sglang/srt/managers/router/model_runner.py
Normal file
458
python/sglang/srt/managers/router/model_runner.py
Normal file
@@ -0,0 +1,458 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
from sglang.utils import get_available_gpu_memory
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||
|
||||
# for model_mode
|
||||
global_model_mode: List[str] = []
|
||||
|
||||
|
||||
@dataclass
class InputMetadata:
    """Per-forward-pass metadata shared by all model layers.

    Built once per batch via `InputMetadata.create` and threaded through the
    model's `forward` so attention layers can locate KV-cache slots, sequence
    boundaries, and (optionally) flashinfer wrappers.
    """

    # Back-reference to the owning runner (for model_config and memory pools).
    model_runner: "ModelRunner"
    # PREFILL / EXTEND / DECODE (see ForwardMode).
    forward_mode: ForwardMode
    batch_size: int
    # Sum of seq_lens over the batch.
    total_num_tokens: int
    # Max of seq_lens over the batch.
    max_seq_len: int
    # Row indices into req_to_token_pool.req_to_token, one per request.
    req_pool_indices: torch.Tensor
    # Per-request start offset of its tokens within the flattened batch.
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    # Length of the cached (already-computed) prefix per request.
    prefix_lens: torch.Tensor
    # Position ids for rotary embeddings, flattened over the batch.
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    # Per-request number of new (non-prefix) tokens: seq_lens - prefix_lens.
    extend_seq_lens: torch.Tensor = None
    # Cumulative start offsets of the new tokens per request.
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    # KV-cache slots where this step's new tokens are written.
    out_cache_loc: torch.Tensor = None
    # Contiguous range fast-path for decode allocation (may be unset).
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    # A valid KV index used as filler for padded slots during decode.
    other_kv_index: torch.Tensor = None
    return_normalized_logprob: bool = False

    # for flashinfer
    use_flashinfer: bool = False
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    # Untyped on purpose: annotating them would turn these class attributes
    # into dataclass fields and change __init__.
    prefill_wrapper = None
    decode_wrapper = None

    def init_flashinfer_args(self, tp_size):
        """Build flashinfer paged-KV index arrays and begin_forward wrappers.

        Uses page size 1: every token is its own page, so kv_indices is the
        flat list of KV-cache slots and kv_last_page_len is all ones.
        """
        # CSR-style pointer array: kv_indptr[i] is the start of request i's
        # KV indices inside kv_indices.
        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )

        # Imported lazily so the module loads without flashinfer installed.
        from flashinfer.ops import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            # Query-offset pointers over the *new* tokens only.
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
            self.prefill_wrapper.begin_forward(
                self.qo_indptr,
                self.batch_size,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
            )
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_last_page_len,
                self.batch_size,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
                1,
                "NONE",
                "float16",
            )

    def init_extend_args(self):
        """Derive per-request new-token counts and their start offsets."""
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        return_normalized_logprob=False,
    ):
        """Assemble an InputMetadata for one forward pass.

        Computes flattened position ids (decode: one per request; otherwise a
        range per request covering prefix..seq_len), then fills in extend and
        flashinfer arguments when applicable.
        """
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            # One new token per request: its position is the current length - 1
            # (plus any user-supplied position offset).
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            # Any valid slot works as padding filler; use the last token of
            # the first request.
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            # Prefill/extend: positions of all *new* tokens of each request.
            seq_lens_np = seq_lens.cpu().numpy()
            prefix_lens_np = prefix_lens.cpu().numpy()
            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_np[i] + position_ids_offsets_np[i],
                            seq_lens_np[i] + position_ids_offsets_np[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_normalized_logprob=return_normalized_logprob,
            other_kv_index=other_kv_index,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
        if ret.use_flashinfer:
            ret.init_flashinfer_args(tp_size)

        return ret
|
||||
|
||||
|
||||
class ModelRunner:
    """Owns one tensor-parallel shard of the model and its KV-cache pools.

    Responsibilities: initialize torch.distributed/NCCL, load the model
    weights, size and allocate the memory pools, and dispatch forward passes
    (prefill / extend / decode / multimodal extend).
    """

    def __init__(
        self,
        model_config,
        mem_fraction_static,
        tp_rank,
        tp_size,
        nccl_port,
        load_format="auto",
        trust_remote_code=True,
        model_mode: List[str] = (),
    ):
        """Set up distributed state, load weights, and allocate KV pools.

        Args:
            model_config: Parsed model configuration (hf_config, sizes, ...).
            mem_fraction_static: Fraction of GPU memory reserved for weights
                plus KV cache; the remainder is left for activations.
            tp_rank / tp_size: This shard's rank and the tensor-parallel size.
            nccl_port: Local TCP port used for the NCCL rendezvous.
            load_format: Weight-loading format passed to `load_weights`.
            trust_remote_code: Forwarded to HF config/tokenizer loading.
            model_mode: Feature flags (e.g. "flashinfer").
        """
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.load_format = load_format
        self.trust_remote_code = trust_remote_code
        self.model_mode = model_mode

        # Mirror the flags into the module-level global so code without a
        # runner reference can consult them.
        global global_model_mode
        global_model_mode = model_mode

        # Init torch distributed
        torch.cuda.set_device(self.tp_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )

        # A small all_reduce for warmup.
        if self.tp_size > 1:
            torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

        # Measure free memory *before* loading weights so the KV pool can be
        # sized from what the weights leave behind.
        total_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        self.load_model()
        self.init_memory_pool(total_gpu_memory)

        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        """See also vllm/model_executor/model_loader.py::get_model"""
        from sglang.srt.models.llama2 import LlamaForCausalLM
        from sglang.srt.models.llava import LlavaLlamaForCausalLM
        from sglang.srt.models.mixtral import MixtralForCausalLM

        # Select model class
        architectures = getattr(self.model_config.hf_config, "architectures", [])

        model_class = None
        for arch in architectures:
            if arch == "LlamaForCausalLM":
                model_class = LlamaForCausalLM
                break
            if arch == "MistralForCausalLM":
                # Mistral shares the Llama architecture implementation.
                model_class = LlamaForCausalLM
                break
            if arch == "LlavaLlamaForCausalLM":
                model_class = LlavaLlamaForCausalLM
                break
            if arch == "MixtralForCausalLM":
                model_class = MixtralForCausalLM
                break
        if model_class is None:
            raise ValueError(f"Unsupported architectures: {architectures}")

        # Load weights
        linear_method = None
        with _set_default_torch_dtype(torch.float16):
            with torch.device("cuda"):
                hf_quant_config = getattr(
                    self.model_config.hf_config, "quantization_config", None
                )
                if hf_quant_config is not None:
                    # TODO: config quantization awq etc
                    quant_config = AWQConfig.from_config(hf_quant_config)
                    print(f"quant_config: {quant_config}")
                    linear_method = quant_config.get_linear_method()
                model = model_class(
                    config=self.model_config.hf_config, linear_method=linear_method
                )
            model.load_weights(
                self.model_config.path,
                cache_dir=None,
                load_format=self.load_format,
                revision=None,
            )
        self.model = model

    def profile_max_num_token(self, total_gpu_memory):
        """Estimate how many KV-cache token slots fit in the static budget.

        cell_size is the bytes one token occupies across all layers:
        heads * head_dim * layers * 2 (K and V) * 2 (fp16 bytes).
        """
        available_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        head_dim = (
            self.model_config.hidden_size // self.model_config.num_attention_heads
        )
        head_num = self.model_config.num_key_value_heads // self.tp_size
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
        """Allocate the request-slot pool and the token KV-cache pool."""
        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
        self.req_to_token_pool = ReqToTokenPool(
            # Heuristic request-slot count scaled from the token budget.
            int(self.max_total_num_token / self.model_config.context_len * 256),
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_token,
            dtype=torch.float16,
            head_num=self.model_config.num_key_value_heads // self.tp_size,
            head_dim=self.model_config.hidden_size
            // self.model_config.num_attention_heads,
            layer_num=self.model_config.num_hidden_layers,
        )

    @torch.inference_mode()
    def forward_prefill(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_normalized_logprob,
    ):
        """Run a PREFILL pass (no cached prefix reuse)."""
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.PREFILL,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_normalized_logprob=return_normalized_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_extend(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_normalized_logprob,
    ):
        """Run an EXTEND pass (compute only the tokens past the cached prefix)."""
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_normalized_logprob=return_normalized_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_decode(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start,
        out_cache_cont_end,
    ):
        """Run a DECODE pass (one new token per request); returns logits only."""
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
        )
        # The model returns a tuple; decode only needs the logits.
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
            0
        ]

    @torch.inference_mode()
    def forward_extend_multi_modal(
        self,
        input_ids,
        pixel_values,
        image_offsets,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_normalized_logprob,
    ):
        """EXTEND pass for multimodal models, passing image inputs through."""
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_normalized_logprob=return_normalized_logprob,
        )
        return self.model.forward(
            input_ids,
            input_metadata.positions,
            input_metadata,
            pixel_values,
            image_offsets,
        )

    def forward(
        self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
    ):
        """Dispatch `batch` to the right forward_* method by mode.

        Raises:
            ValueError: if `forward_mode` is not PREFILL/EXTEND/DECODE.
        """
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            kwargs = {
                "input_ids": batch.input_ids,
                "pixel_values": batch.pixel_values,
                "image_offsets": batch.image_offsets,
                "req_pool_indices": batch.req_pool_indices,
                "seq_lens": batch.seq_lens,
                "prefix_lens": batch.prefix_lens,
                "position_ids_offsets": batch.position_ids_offsets,
                "out_cache_loc": batch.out_cache_loc,
            }
            kwargs["return_normalized_logprob"] = return_normalized_logprob
            return self.forward_extend_multi_modal(**kwargs)
        else:
            kwargs = {
                "input_ids": batch.input_ids,
                "req_pool_indices": batch.req_pool_indices,
                "seq_lens": batch.seq_lens,
                "prefix_lens": batch.prefix_lens,
                "position_ids_offsets": batch.position_ids_offsets,
                "out_cache_loc": batch.out_cache_loc,
            }

            if forward_mode == ForwardMode.DECODE:
                kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
                kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
                return self.forward_decode(**kwargs)
            elif forward_mode == ForwardMode.EXTEND:
                kwargs["return_normalized_logprob"] = return_normalized_logprob
                return self.forward_extend(**kwargs)
            elif forward_mode == ForwardMode.PREFILL:
                kwargs["return_normalized_logprob"] = return_normalized_logprob
                return self.forward_prefill(**kwargs)
            else:
                # Fixed typo in the original message ("Invaid").
                raise ValueError(f"Invalid forward mode: {forward_mode}")
|
||||
220
python/sglang/srt/managers/router/radix_cache.py
Normal file
220
python/sglang/srt/managers/router/radix_cache.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TreeNode:
    """Node of the radix tree used for shared-prefix KV caching."""

    def __init__(self):
        """Create a detached node: no parent, no value, zero references."""
        self.parent = None
        self.value = None
        self.ref_counter = 0
        # Children are keyed by the token-sequence edge label; missing keys
        # auto-create empty child nodes.
        self.children = defaultdict(TreeNode)
        self.last_access_time = time.time()

    def __lt__(self, other):
        """Order by recency so a min-heap pops the least recently used node."""
        return other.last_access_time > self.last_access_time
|
||||
|
||||
|
||||
def match(key, seq):
    """Return the length of the longest common prefix of *key* and *seq*."""
    limit = min(len(key), len(seq))
    length = 0
    while length < limit and key[length] == seq[length]:
        length += 1
    return length
|
||||
|
||||
|
||||
class RadixCache:
    """Radix tree over token sequences for KV-cache prefix sharing.

    Each edge is labeled with a token sequence (the dict key) and each node
    stores the corresponding KV-cache indices in `value`. `ref_counter`
    protects nodes in use; `evictable_size_` tracks tokens with zero refs.
    """

    def __init__(self, disable=False):
        # The root holds an empty value and a permanent reference so it can
        # never be evicted.
        self.root_node = TreeNode()
        self.root_node.value = []
        self.root_node.ref_counter = 1
        self.evictable_size_ = 0

        # When disabled, every public method degenerates to a no-op.
        self.disable = disable

    ##### Public API #####
    def match_prefix(self, key):
        """Return (matched KV indices, deepest matched node) for `key`.

        The matched indices are concatenated into one tensor when non-empty;
        otherwise an empty list is returned.
        """
        if self.disable:
            return [], self.root_node

        value = []
        # Single-element list so the recursive helper can write the result.
        last_node = [self.root_node]
        self._match_prefix_helper(self.root_node, key, value, last_node)
        if value:
            value = torch.concat(value)
        return value, last_node[0]

    def insert(self, key, value=None):
        """Insert `key` (with per-token `value`s) and return the length of
        the prefix that was already present."""
        if self.disable:
            return len(key)

        if value is None:
            # Default: each token maps to itself.
            value = [x for x in key]
        return self._insert_helper(self.root_node, key, value)

    def pretty_print(self):
        """Dump the tree structure and total token count to stdout."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Total number of tokens stored anywhere in the tree."""
        return self._total_size_helper(self.root_node)

    def evict(self, num_tokens, evict_callback):
        """Evict at least `num_tokens` tokens from unreferenced leaves (LRU
        first), calling `evict_callback(node.value)` for each evicted node.

        The callback must return the number of tokens it released.
        """
        if self.disable:
            raise RuntimeError()

        leaves = self._collect_leaves()
        # TreeNode.__lt__ orders by last access time -> LRU pops first.
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x == self.root_node:
                break
            if x.ref_counter > 0:
                # In use; skip (it is not re-pushed, so it stays this round).
                continue

            num_evicted += evict_callback(x.value)
            self._delete_leaf(x)

            # The parent may have become a leaf and thus evictable.
            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_ref_counter(self, node):
        """Pin `node` and all ancestors; return the (negative) change in
        evictable token count."""
        delta = 0
        while node != self.root_node:
            if node.ref_counter == 0:
                # Transition evictable -> pinned.
                self.evictable_size_ -= len(node.value)
                delta -= len(node.value)
            node.ref_counter += 1
            node = node.parent
        return delta

    def dec_ref_counter(self, node):
        """Unpin `node` and all ancestors; return the (positive) change in
        evictable token count."""
        delta = 0
        while node != self.root_node:
            if node.ref_counter == 1:
                # Transition pinned -> evictable.
                self.evictable_size_ += len(node.value)
                delta += len(node.value)
            node.ref_counter -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        """Number of tokens currently held by zero-reference nodes."""
        return self.evictable_size_

    ##### Internal Helper Functions #####
    def _match_prefix_helper(self, node, key, value, last_node):
        """Walk down the tree collecting matched values; update recency."""
        node.last_access_time = time.time()

        for c_key, child in node.children.items():
            prefix_len = match(c_key, key)
            if prefix_len != 0:
                if prefix_len == len(key) and prefix_len != len(c_key):
                    # Key ends mid-edge: split so the match lands on a node.
                    new_node = self._split_node(c_key, child, prefix_len)
                    value.append(new_node.value)
                    last_node[0] = new_node
                else:
                    value.append(child.value[:prefix_len])
                    last_node[0] = child
                    self._match_prefix_helper(child, key[prefix_len:], value, last_node)
                # At most one child can share a non-empty prefix.
                break

    def _split_node(self, key, child, split_len):
        # new_node -> child
        # Insert new_node between child and its parent, cutting the edge
        # (and child's value) at split_len.
        new_node = TreeNode()
        new_node.children = {key[split_len:]: child}
        new_node.parent = child.parent
        new_node.ref_counter = child.ref_counter
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.value = child.value[split_len:]
        new_node.parent.children[key[:split_len]] = new_node
        del new_node.parent.children[key]
        return new_node

    def _insert_helper(self, node, key, value):
        """Recursive insert; returns the length of the already-cached prefix."""
        node.last_access_time = time.time()

        for c_key, child in node.children.items():
            prefix_len = match(c_key, key)

            if prefix_len == len(c_key):
                if prefix_len == len(key):
                    # Exact edge match: everything already present.
                    return prefix_len
                else:
                    # Consume the edge and recurse into the child.
                    key = key[prefix_len:]
                    value = value[prefix_len:]
                    return prefix_len + self._insert_helper(child, key, value)

            if prefix_len:
                # Partial edge match: split, then insert below the split.
                # NOTE: _split_node mutates node.children mid-iteration, but
                # we return immediately, so the iterator is never advanced.
                new_node = self._split_node(c_key, child, prefix_len)
                return prefix_len + self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )

        if len(key):
            # No child shares a prefix: attach a fresh leaf.
            new_node = TreeNode()
            new_node.parent = node
            new_node.value = value
            node.children[key] = new_node
            self.evictable_size_ += len(value)
        return 0

    def _print_helper(self, node, indent):
        for key, child in node.children.items():
            print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
            self._print_helper(child, indent=indent + 2)

    def _delete_leaf(self, node):
        # Find the edge label pointing at `node`, then remove it.
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        # Edge labels and node values have equal length by construction.
        self.evictable_size_ -= len(k)

    def _total_size_helper(self, node):
        x = len(node.value)
        for child in node.children.values():
            x += self._total_size_helper(child)
        return x

    def _collect_leaves(self):
        """Return every node with no children (candidates for eviction)."""
        ret_list = []

        def dfs_(cur_node):
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)

            for x in cur_node.children.values():
                dfs_(x)

        dfs_(self.root_node)
        return ret_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc manual smoke test: exercise insert / pretty_print on strings
    # (character sequences stand in for token-id sequences).
    tree = RadixCache(disable=False)

    tree.insert("Hello")
    tree.insert("Hello")
    tree.insert("Hello_L.A.!")
    # tree.insert("Hello_world! Happy")
    # tree.insert("I love you!")
    tree.pretty_print()

    # print(tree.match_prefix("I love you! aha"))

    # def evict_callback(x):
    #     print("evict", x)
    #     return len(x)

    # tree.evict(5, evict_callback)
    # tree.evict(10, evict_callback)
    # tree.pretty_print()
|
||||
73
python/sglang/srt/managers/router/scheduler.py
Normal file
73
python/sglang/srt/managers/router/scheduler.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class Scheduler:
    """Order the waiting request queue according to a scheduling heuristic.

    Supported heuristics:
      - "lpm":    longest prefix match first (maximize radix-cache reuse)
      - "random": shuffle the queue
      - "fcfs":   first come, first served (queue order unchanged)
      - "weight": DFS over the radix tree, visiting heavier subtrees first
    """

    def __init__(
        self,
        schedule_heuristic,
        max_running_seq,
        max_prefill_num_token,
        max_total_num_token,
        tree_cache,
    ):
        self.schedule_heuristic = schedule_heuristic
        self.max_running_seq = max_running_seq
        self.max_prefill_num_token = max_prefill_num_token
        self.max_total_num_token = max_total_num_token
        self.tree_cache = tree_cache

    def new_token_estimation_ratio(self):
        """Assumed fraction of newly generated tokens in the budget.

        Non-FCFS heuristics reuse more cached prefixes, so a smaller share
        of the token budget is expected to be new tokens.
        """
        return 0.4 if self.schedule_heuristic != "fcfs" else 0.5

    def get_priority_queue(self, forward_queue):
        """Return `forward_queue` reordered according to the heuristic.

        May sort or shuffle the input list in place; the result contains
        exactly the same requests.

        Raises:
            ValueError: if the heuristic name is unknown.
        """
        if self.schedule_heuristic == "lpm":
            # longest prefix match
            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
            return forward_queue
        elif self.schedule_heuristic == "random":
            random.shuffle(forward_queue)
            return forward_queue
        elif self.schedule_heuristic == "fcfs":
            return forward_queue
        elif self.schedule_heuristic == "weight":
            # Group requests by the deepest cached tree node they matched,
            # longest prefix first within each group.
            last_node_to_reqs = defaultdict(list)
            for req in forward_queue:
                last_node_to_reqs[req.last_node].append(req)
            for node in last_node_to_reqs:
                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))

            # Weight of a node = 1 + its request count + children's weights.
            # (Renamed from the original's misspelled `node_to_wight`.)
            node_to_weight = defaultdict(int)
            self._calc_weight_recursive(
                self.tree_cache.root_node, last_node_to_reqs, node_to_weight
            )

            tmp_queue = []
            self._get_weight_priority_recursive(
                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
            )
            assert len(tmp_queue) == len(forward_queue)
            return tmp_queue
        else:
            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")

    def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
        """Post-order walk accumulating subtree weights into `node_to_weight`."""
        node_to_weight[cur_node] = 1
        if cur_node in last_node_to_reqs:
            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
        for child in cur_node.children.values():
            self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child]

    def _get_weight_priority_recursive(
        self, cur_node, node_to_weight, last_node_to_reqs, tmp_queue
    ):
        """DFS children by descending subtree weight; a node's requests are
        appended to `tmp_queue` after its whole subtree is visited."""
        visit_list = [child for child in cur_node.children.values()]
        visit_list.sort(key=lambda x: -node_to_weight[x])
        for child in visit_list:
            self._get_weight_priority_recursive(
                child, node_to_weight, last_node_to_reqs, tmp_queue
            )
        tmp_queue.extend(last_node_to_reqs[cur_node])
|
||||
219
python/sglang/srt/managers/tokenizer_manager.py
Normal file
219
python/sglang/srt/managers/tokenizer_manager.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import transformers
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_config,
|
||||
get_context_length,
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchStrOut,
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ReqState:
    """Per-request state tracked while awaiting detokenized outputs."""

    # Accumulated output dicts not yet yielded to the caller.
    out_list: List
    # Set True when the detokenizer reports the request finished.
    finished: bool
    # Signaled whenever new output is appended to out_list.
    event: asyncio.Event
    # NOTE(review): never acquired in the visible code — possibly reserved
    # for future use; confirm before removing.
    lock: asyncio.Lock
|
||||
|
||||
|
||||
# NOTE(review): `global` at module scope is a no-op; this line only documents
# that `global_processor` is assigned lazily by `init_global_processor` inside
# each process-pool worker.
global global_processor
|
||||
|
||||
|
||||
def init_global_processor(server_args: ServerArgs):
    """Process-pool initializer: build the shared HF processor once per worker."""
    global global_processor

    # Silence HF warning spam in each worker process.
    transformers.logging.set_verbosity_error()

    global_processor = get_processor(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
|
||||
|
||||
|
||||
def get_pixel_values(image_data, processor=None):
    """Decode an image payload and run it through the HF image processor.

    Args:
        image_data: Raw payload accepted by `load_image` — presumably a URL,
            path, or encoded bytes; confirm against callers.
        processor: Optional processor override; defaults to the per-process
            `global_processor` set by `init_global_processor`.

    Returns:
        (pixel_values, image_hash) on success; on failure the traceback is
        printed and the function implicitly returns None.
    """
    try:
        processor = processor or global_processor
        image = load_image(image_data)
        # Hash the raw payload (not the decoded image) for dedup/cache keying.
        image_hash = hash(image_data)
        pixel_values = processor.image_processor(image)["pixel_values"][0]
        # fp16 halves the payload shipped across the process boundary.
        pixel_values = pixel_values.astype(np.float16)
        return pixel_values, image_hash
    except Exception:
        # NOTE(review): this falls through to an implicit `return None`;
        # callers that unpack a 2-tuple will crash — consider returning
        # (None, None) or re-raising.
        print("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
|
||||
class TokenizerManager:
    """Front-end manager: tokenizes requests, forwards them to the router over
    ZeroMQ, and streams detokenized results back to awaiting callers.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Open the ZMQ sockets and load the tokenizer (and image processor
        plus a process pool for multimodal models)."""
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.send_to_router = context.socket(zmq.PUSH)
        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")

        self.model_path = server_args.model_path
        self.hf_config = get_config(
            self.model_path, trust_remote_code=server_args.trust_remote_code
        )
        self.context_len = get_context_length(self.hf_config)

        # Bug fix: these were previously assigned only in the multimodal
        # branch, so `get_pixel_values` raised AttributeError for
        # non-multimodal models that received image data.
        self.processor = None
        self.executor = None

        if is_multimodal_model(self.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
            # Avoid tokenizer-thread / process-fork interaction warnings.
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            # Image preprocessing is CPU-heavy; run it in worker processes.
            self.executor = concurrent.futures.ProcessPoolExecutor(
                initializer=init_global_processor, initargs=(server_args,)
            )
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )

        # The receive loop is created lazily on the first request so it runs
        # on the caller's event loop.
        self.to_create_loop = True
        self.rid_to_state = {}  # Dict[str -> ReqState]

    async def get_pixel_values(self, image_data):
        """Preprocess an image, offloading to the process pool when present."""
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor, get_pixel_values, image_data
            )
        else:
            return get_pixel_values(image_data, self.processor)

    async def generate_request(self, obj: GenerateReqInput):
        """Async generator: tokenize `obj`, forward it to the router, and
        yield outputs as they arrive.

        Single request (`obj.text` is a str): yields incremental outputs
        until finished. Batch request: yields one list of final outputs
        (streaming is not supported for batches).
        """
        if self.to_create_loop:
            await self.create_handle_loop()

        is_single = isinstance(obj.text, str)

        if is_single:
            rid = obj.rid
            input_ids = self.tokenizer.encode(obj.text)
            sampling_params = SamplingParams(**obj.sampling_params)
            if sampling_params.max_new_tokens != 0:
                sampling_params.normalize(self.tokenizer)
                sampling_params.verify()
            if obj.image_data is None:
                pixel_values, image_hash = None, None
            else:
                pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
            tokenized_obj = TokenizedGenerateReqInput(
                rid=rid,
                input_ids=input_ids,
                pixel_values=pixel_values,
                image_hash=image_hash,
                sampling_params=sampling_params,
                return_normalized_logprob=obj.return_normalized_logprob,
                normalized_logprob_start_len=obj.normalized_logprob_start_len,
                stream=obj.stream,
            )
            self.send_to_router.send_pyobj(tokenized_obj)

            lock = asyncio.Lock()
            event = asyncio.Event()
            state = ReqState([], False, event, lock)
            self.rid_to_state[rid] = state

            # Yield each batch of new output as the handle loop signals it.
            while True:
                await event.wait()
                yield state.out_list[-1]
                state.out_list = []
                if state.finished:
                    del self.rid_to_state[rid]
                    break
                event.clear()
        else:
            assert obj.stream is False
            bs = len(obj.text)
            # Phase 1: tokenize and dispatch every request in the batch.
            for i in range(bs):
                rid = obj.rid[i]
                input_ids = self.tokenizer.encode(obj.text[i])
                sampling_params = SamplingParams(**obj.sampling_params[i])
                if sampling_params.max_new_tokens != 0:
                    sampling_params.normalize(self.tokenizer)
                    sampling_params.verify()
                if obj.image_data[i] is None:
                    pixel_values, image_hash = None, None
                else:
                    pixel_values, image_hash = await self.get_pixel_values(
                        obj.image_data[i]
                    )
                tokenized_obj = TokenizedGenerateReqInput(
                    rid=rid,
                    input_ids=input_ids,
                    pixel_values=pixel_values,
                    image_hash=image_hash,
                    sampling_params=sampling_params,
                    return_normalized_logprob=obj.return_normalized_logprob[i],
                    normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
                    stream=obj.stream,
                )
                self.send_to_router.send_pyobj(tokenized_obj)

                lock = asyncio.Lock()
                event = asyncio.Event()
                state = ReqState([], False, event, lock)
                self.rid_to_state[rid] = state

            # Phase 2: await every request's completion, in order.
            output_list = []
            for i in range(bs):
                rid = obj.rid[i]
                state = self.rid_to_state[rid]
                await state.event.wait()
                output_list.append(state.out_list[-1])
                assert state.finished
                del self.rid_to_state[rid]

            yield output_list

    async def create_handle_loop(self):
        """Start the background receive loop exactly once."""
        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
        """Forever: receive detokenizer output and wake the matching waiter.

        Raises:
            ValueError: on an unexpected message type.
        """
        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, BatchStrOut):
                for i, rid in enumerate(recv_obj.rids):
                    recv_obj.meta_info[i]["id"] = rid
                    out_dict = {
                        "text": recv_obj.output_str[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                    state = self.rid_to_state[rid]
                    state.out_list.append(out_dict)
                    state.finished = recv_obj.finished[i]
                    state.event.set()
            else:
                raise ValueError(f"Invalid object: {recv_obj}")
|
||||
111
python/sglang/srt/memory_pool.py
Normal file
111
python/sglang/srt/memory_pool.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Memory pool."""
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReqToTokenPool:
    """Fixed-size pool mapping request slots to rows of per-token KV indices."""

    def __init__(self, size, max_context_len):
        # True marks a free slot.
        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
        self.can_use_mem_size = size
        # One row of token->KV-slot indices per request slot.
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device="cuda"
        )

    def alloc(self, need_size):
        """Reserve `need_size` slots; return their indices, or None if the
        pool cannot satisfy the request."""
        if self.can_use_mem_size < need_size:
            return None

        free_slots = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
        self.mem_state[free_slots] = 0
        self.can_use_mem_size -= need_size
        return free_slots.to(torch.int32)

    def free(self, free_index):
        """Return a single slot (int) or a batch of slots (tensor) to the pool."""
        if isinstance(free_index, int):
            self.can_use_mem_size += 1
        else:
            self.can_use_mem_size += free_index.shape[0]
        self.mem_state[free_index] = 1

    def clear(self):
        """Mark every slot free again."""
        self.mem_state.fill_(1)
        self.can_use_mem_size = len(self.mem_state)
|
||||
|
||||
|
||||
class TokenToKVPool:
    """Reference-counted pool of token slots backing the per-layer KV cache.

    `mem_state` holds an int16 reference count per token slot (0 == free).
    `kv_data` holds one tensor per layer of shape
    [size, 2 (key/value), head_num, head_dim].
    """

    def __init__(self, size, dtype, head_num, head_dim, layer_num, device="cuda"):
        # `device` is parameterized (previously hard-coded to "cuda") so the
        # pool can also run on CPU, e.g. in tests; the default is unchanged.
        self.device = device
        self.mem_state = torch.zeros((size,), dtype=torch.int16, device=device)
        self.alloc_ct = 0

        # [size, key/value, head_num, head_dim] for each layer
        self.kv_data = [
            torch.empty((size, 2, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id):
        """View of the key half of layer `layer_id`: [size, head_num, head_dim]."""
        return self.kv_data[layer_id][:, 0]

    def get_value_buffer(self, layer_id):
        """View of the value half of layer `layer_id`: [size, head_num, head_dim]."""
        return self.kv_data[layer_id][:, 1]

    def alloc(self, need_size):
        """Reserve `need_size` free slots (not necessarily contiguous).

        Returns an int32 index tensor, or None if not enough slots are free.
        """
        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
        if select_index.shape[0] < need_size:
            return None

        self.add_refs(select_index)
        return select_index.to(torch.int32)

    def alloc_contiguous(self, need_size):
        """Reserve `need_size` *contiguous* slots.

        Returns (int32 index tensor, start, end) or None if no contiguous run
        of free slots exists.
        """
        empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
        if empty_index.shape[0] < need_size:
            return None
        empty_size = len(empty_index)
        # A window of `need_size` free slots is contiguous iff the distance
        # between its first and last free index equals need_size - 1.
        loc_sum = (
            empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)]
        )
        can_used_loc = empty_index[: empty_size - (need_size - 1)][
            loc_sum == need_size - 1
        ]
        if can_used_loc.shape[0] == 0:
            return None

        start_loc = can_used_loc[0].item()
        select_index = torch.arange(
            start_loc, start_loc + need_size, device=self.device
        )
        self.add_refs(select_index)
        return select_index.to(torch.int32), start_loc, start_loc + need_size

    def free(self, free_index):
        """Drop one reference on each slot; returns the number fully freed."""
        return self.decrease_refs(free_index)

    def used_size(self):
        """Number of slots with at least one reference."""
        return len(torch.nonzero(self.mem_state).squeeze(1))

    def available_size(self):
        """Number of slots with zero references."""
        return torch.sum(self.mem_state == 0).item()

    def add_refs(self, token_index: torch.Tensor):
        self.alloc_ct += len(token_index)
        self.mem_state[token_index] += 1

    def decrease_refs(self, token_index: torch.Tensor):
        self.alloc_ct -= len(token_index)
        self.mem_state[token_index] -= 1

        num_freed = torch.sum(self.mem_state[token_index] == 0)

        # if self.alloc_ct == 0:
        #     print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")

        return num_freed

    def clear(self):
        """Reset every reference count to zero."""
        self.mem_state.fill_(0)
        self.alloc_ct = 0
||||
27
python/sglang/srt/model_config.py
Normal file
27
python/sglang/srt/model_config.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||
|
||||
|
||||
class ModelConfig:
    """Normalized view over a HuggingFace model configuration.

    Loads the raw `hf_config` for `path` and mirrors the commonly used
    fields (context length, head/layer counts, sizes) as flat attributes
    with unified names.
    """

    def __init__(
        self,
        path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
    ) -> None:
        self.path = path
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.hf_config = get_config(self.path, trust_remote_code, revision)

        # Unify the config keys for hf_config.
        hf = self.hf_config
        self.context_len = get_context_length(hf)
        self.head_dim = hf.hidden_size // hf.num_attention_heads
        self.num_key_value_heads = hf.num_key_value_heads
        self.num_attention_heads = hf.num_attention_heads
        self.hidden_size = hf.hidden_size
        self.num_hidden_layers = hf.num_hidden_layers
        self.vocab_size = hf.vocab_size
||||
316
python/sglang/srt/models/llama2.py
Normal file
316
python/sglang/srt/models/llama2.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
    """Gated feed-forward block (SwiGLU) with tensor-parallel linears."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        # Fused gate + up projection: a single matmul produces both halves.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False, linear_method=linear_method
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        # SiluAndMul splits its input in half, applies SiLU to the gate half
        # and multiplies it with the up half.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated)
        return out
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
    """Multi-head (optionally grouped-query) attention with rotary embeddings,
    backed by RadixAttention for prefix-cached KV access."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # More KV heads than TP ranks: partition the KV heads evenly
            # across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than TP ranks: each rank keeps a replica.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Fused Q/K/V projection.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_out = self.attn(q, k, v, input_metadata)
        out, _ = self.o_proj(attn_out)
        return out
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
    """One transformer block: pre-norm attention + pre-norm MLP, with the
    fused residual handling that vLLM's RMSNorm provides."""

    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Fall back to LLaMA defaults when the config omits RoPE settings.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self attention (first layer has no incoming residual yet).
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully connected.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
||||
|
||||
|
||||
class LlamaModel(nn.Module):
    """Stack of LlamaDecoderLayers with token embedding and final norm."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(config, i, linear_method)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        skip_embed: bool = False,
    ) -> torch.Tensor:
        # With skip_embed=True the caller passes pre-computed embeddings
        # (e.g. multimodal inputs) directly in place of token ids.
        hidden_states = input_ids if skip_embed else self.embed_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
    """LLaMA with an LM head; forward returns processed logits."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        skip_embed: bool = False,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load checkpoint weights, folding per-shard q/k/v and gate/up
        tensors into the fused qkv_proj / gate_up_proj parameters."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # RoPE tables are recomputed locally; projector weights belong
            # to the multimodal wrapper, not this model.
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
||||
213
python/sglang/srt/models/llava.py
Normal file
213
python/sglang/srt/models/llava.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sglang.srt.managers.router.infer_batch import ForwardMode
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from torch import nn
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel, LlavaConfig
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
|
||||
|
||||
class LlavaLlamaForCausalLM(nn.Module):
    """LLaVA: a CLIP vision tower + projector feeding a LLaMA language model.

    The vision tower is created lazily in `load_weights` (not in `__init__`),
    so `forward` must only be called after weights are loaded.
    """

    def __init__(
        self,
        config: LlavaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.vision_tower = None
        # Propagate the multimodal sizes into the nested HF sub-configs.
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = LlamaForCausalLM(config, linear_method)

    def pad_input_ids(self, input_ids, pad_value):
        """Replace the single image token with `image_feature_len` pad ids.

        Returns (new_input_ids, offset) where `offset` is the position of the
        removed image token.
        """
        pad_ids = pad_value * (
            (self.image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because we need to remove image_token_id
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[: self.image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input: only requests whose prefill still covers
            # the image placeholder region need the vision tower.
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to substract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = torch.tensor(
                    np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
                    device=self.vision_tower.device,
                )

                image_outputs = self.vision_tower(
                    pixel_values, output_hidden_states=True
                )
                # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.

                selected_image_feature = image_outputs.hidden_states[
                    self.vision_feature_layer
                ]
                if self.vision_feature_select_strategy in ["default", "patch"]:
                    # Drop the CLS token; keep only patch features.
                    selected_image_feature = selected_image_feature[:, 1:]
                elif self.vision_feature_select_strategy == "full":
                    selected_image_feature = selected_image_feature
                else:
                    # BUGFIX: previously read self.config.vision_feature_select_strategy,
                    # which this config path never sets (it uses mm_vision_select_feature),
                    # so the error path itself raised AttributeError.
                    raise ValueError(
                        f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
                    )
                image_features = self.multi_modal_projector(selected_image_feature)

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    try:
                        input_embeds[
                            start_idx
                            + image_offsets[i] : start_idx
                            + image_offsets[i]
                            + pad_len
                        ] = image_features[pt]
                    except RuntimeError as e:
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_embeds, positions, input_metadata, skip_embed=True
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(
                input_ids, positions, input_metadata, skip_embed=False
            )

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load the CLIP vision tower, projector weights and language model."""
        # load clip vision model by cfg['mm_vision_tower']:
        # huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size
        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            # The CLS token is kept, so the feature length grows by one.
            self.image_feature_len += 1
        else:
            # BUGFIX: previously formatted self.select_feature, an attribute that
            # does not exist, turning the intended ValueError into AttributeError.
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # load mm_projector
        # TODO: support TP?
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # FIXME: why projector weights read two times?
            if "projector" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()
||||
|
||||
|
||||
# One-shot flag: the patch-embedding conv is moved to CPU only on the first
# call of the patched forward below.
first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Patched CLIPVisionEmbeddings.forward.

    Runs the patch-embedding convolution on CPU in float32 (see the note
    below), then resumes on GPU in half precision.
    """
    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    global first_call
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False
    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values).cuda().half()

    # [B, C, H', W'] -> [B, H'*W', C]
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    # Prepend the class embedding and add positional embeddings.
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings
||||
|
||||
|
||||
def monkey_path_clip_vision_embed_forward():
    """Monkey-patch CLIPVisionEmbeddings.forward with the CPU-conv variant.

    Applied globally to the transformers class; see clip_vision_embed_forward
    for why the replacement exists.
    """
    import transformers

    setattr(
        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
        "forward",
        clip_vision_embed_forward,
    )
||||
378
python/sglang/srt/models/mixtral.py
Normal file
378
python/sglang/srt/models/mixtral.py
Normal file
@@ -0,0 +1,378 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
    """Single Mixtral expert: SwiGLU-style MLP built from replicated linears
    (each rank owning an expert holds its full weights)."""

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        # w1: gate projection, w3: up projection, w2: down projection.
        self.w1 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
        )
        self.w2 = ReplicatedLinear(
            self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
        )
        self.w3 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
        )

        # TODO: Use vllm's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate, _ = self.w1(hidden_states)
        gate = self.act_fn(gate)
        up, _ = self.w3(hidden_states)
        fused = gate * up
        out, _ = self.w2(fused)
        return out
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
    """Sparse MoE layer with experts partitioned across tensor-parallel ranks.

    Each rank instantiates only its share of the experts (others are None);
    per-token outputs are combined with an all-reduce.
    """

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}."
            )
        # Split experts equally between ranks
        self.expert_indicies = np.array_split(
            range(self.num_total_experts), self.tp_size
        )[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")

        # Experts not owned by this rank stay None to save memory.
        self.experts = nn.ModuleList(
            [
                MixtralMLP(
                    self.num_total_experts,
                    config.hidden_size,
                    config.intermediate_size,
                    linear_method=linear_method,
                )
                if idx in self.expert_indicies
                else None
                for idx in range(self.num_total_experts)
            ]
        )
        # The router (gate) is replicated on every rank.
        self.gate = ReplicatedLinear(
            config.hidden_size, self.num_total_experts, bias=False, linear_method=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)

        # Top-k routing, renormalized over the selected experts.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        accumulated = None
        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]
            # Weight is zero for tokens that did not select this expert.
            expert_mask = selected_experts == expert_idx
            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)

            weighted_out = expert_layer(hidden_states).mul_(expert_weights)
            if accumulated is None:
                accumulated = weighted_out
            else:
                accumulated.add_(weighted_out)

        # Combine partial sums from all ranks.
        return tensor_model_parallel_all_reduce(accumulated)
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
    """Grouped-query attention for Mixtral with rotary embeddings and an
    (unused-here beyond storage) sliding-window setting."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        linear_method: Optional[LinearMethodBase] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # More KV heads than TP ranks: partition them across GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than TP ranks: replicate them.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_out = self.attn(q, k, v, input_metadata)
        out, _ = self.o_proj(attn_out)
        return out
||||
|
||||
|
||||
class MixtralDecoderLayer(nn.Module):
    """One Mixtral block: pre-norm attention followed by the sparse MoE,
    with fused residual handling via vLLM's RMSNorm."""

    def __init__(
        self,
        config: MixtralConfig,
        layer_id: int = 0,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method,
        )
        self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self attention (first layer has no incoming residual yet).
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Sparse MoE feed-forward.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
||||
|
||||
|
||||
class MixtralModel(nn.Module):
    """Mixtral transformer stack: token embeddings, a list of decoder
    layers, and a final RMSNorm."""

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                MixtralDecoderLayer(config, i, linear_method=linear_method)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        skip_embed: bool = False,
    ) -> torch.Tensor:
        """Run the full decoder stack and return the final hidden states.

        When `skip_embed` is True, `input_ids` is assumed to already hold
        embedded hidden states and the embedding lookup is skipped.
        """
        if not skip_embed:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_ids
        residual = None
        # Iterate layers directly instead of indexing with range(len(...)).
        for layer in self.layers:
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        # Final fused norm folds in the last residual; second value unused.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
||||
|
||||
|
||||
class MixtralForCausalLM(nn.Module):
    """Mixtral causal-LM head: wraps MixtralModel with an LM head and a
    logits processor, plus HF-checkpoint weight loading."""

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        skip_embed: bool = False,
    ) -> torch.Tensor:
        """Compute hidden states and delegate logit computation to the
        logits processor (which uses the LM head weight directly)."""
        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load HF checkpoint weights, fusing q/k/v projections into the
        stacked qkv_proj parameter via per-shard weight loaders."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=False
        ):
            # Rotary inverse-frequency buffers are recomputed, not loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Map the checkpoint's per-projection name onto the fused
                # parameter and load it as the matching shard.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # for/else: runs only when no stacked mapping matched.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if "block_sparse_moe.experts." in name and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
||||
81
python/sglang/srt/sampling_params.py
Normal file
81
python/sglang/srt/sampling_params.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
from typing import List, Optional, Union
|
||||
|
||||
# Temperatures below this threshold are treated as zero (greedy decoding).
_SAMPLING_EPS = 1e-6


class SamplingParams:
    """Sampling parameters controlling text generation."""

    def __init__(
        self,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stop: Optional[Union[str, List[str]]] = None,
        max_new_tokens: int = 16,
        ignore_eos: bool = False,
        skip_special_tokens: bool = True,
        dtype: Optional[str] = None,
        regex: Optional[str] = None,
    ) -> None:
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.stop_strs = stop
        self.max_new_tokens = max_new_tokens
        self.ignore_eos = ignore_eos
        self.skip_special_tokens = skip_special_tokens
        self.dtype = dtype
        self.regex = regex

        # Special cases.
        if self.temperature < _SAMPLING_EPS:
            # Near-zero temperature means greedy decoding: force top-1.
            self.temperature = 1.0
            self.top_k = 1
        if self.top_k == -1:
            self.top_k = 1 << 30  # whole vocabulary
        if self.dtype == "int":
            # Integer decoding stops at the first whitespace boundary.
            self.stop_strs = [" ", "\n"]

    def verify(self):
        """Raise ValueError for any out-of-range parameter."""
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(
                f"top_k must be -1 (disable), or at least 1, got {self.top_k}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if self.max_new_tokens < 0:
            raise ValueError(
                f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
            )

    def normalize(self, tokenizer):
        """Canonicalize stop strings into a list and precompute the max
        tokenized length of any stop string."""
        if self.stop_strs is None:
            self.stop_strs = []
            self.stop_str_max_len = 0
            return
        if isinstance(self.stop_strs, str):
            self.stop_strs = [self.stop_strs]
        self.stop_str_max_len = max(
            (
                len(tokenizer.encode(s, add_special_tokens=False))
                for s in self.stop_strs
            ),
            default=0,
        )
|
||||
222
python/sglang/srt/server.py
Normal file
222
python/sglang/srt/server.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""SRT: SGLang Runtime"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
# Fix a Python bug: replace threading's internal atexit registration with a
# no-op so threads started here don't block interpreter shutdown.
# NOTE(review): the exact upstream CPython issue being worked around is not
# documented here — confirm before removing this hack.
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
||||
|
||||
import psutil
|
||||
import requests
|
||||
import uvicorn
|
||||
import uvloop
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.openai_protocol import CompletionRequest
|
||||
from sglang.srt.managers.router.manager import start_router_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import alloc_usable_network_port
|
||||
|
||||
# Use uvloop as the asyncio event loop implementation.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
# Set by launch_server(); the endpoint handlers below read this handle.
tokenizer_manager = None
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
result = {
|
||||
"model_path": tokenizer_manager.model_path,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate_request(obj: GenerateReqInput):
|
||||
obj.post_init()
|
||||
result_generator = tokenizer_manager.generate_request(obj)
|
||||
|
||||
if obj.stream:
|
||||
|
||||
async def stream_results():
|
||||
async for out in result_generator:
|
||||
yield (json.dumps(out) + "\0").encode("utf-8")
|
||||
|
||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||
else:
|
||||
ret = await result_generator.__anext__()
|
||||
return ret
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def v1_completions(obj: CompletionRequest):
|
||||
assert obj.n == 1
|
||||
obj = GenerateReqInput(
|
||||
text=obj.prompt,
|
||||
sampling_params={
|
||||
"temperature": obj.temperature,
|
||||
"max_new_tokens": obj.max_tokens,
|
||||
"stop": obj.stop,
|
||||
},
|
||||
)
|
||||
ret = await generate_request(obj)
|
||||
return {
|
||||
"choices": [{"text": ret["text"]}],
|
||||
}
|
||||
|
||||
|
||||
def launch_server(server_args, pipe_finish_writer):
    """Launch the SRT server: router and detokenizer subprocesses plus the
    FastAPI HTTP server in a background thread.

    If `pipe_finish_writer` is given, sends "init ok" once the HTTP server
    answers a health probe, or the last error string if it never comes up.
    """
    global tokenizer_manager

    # Allocate ports
    can_use_ports = alloc_usable_network_port(
        num=4 + server_args.tp_size, used_list=(server_args.port,)
    )
    # alloc_usable_network_port returns None on failure; fail loudly instead
    # of crashing with a TypeError on the subscript below.
    if can_use_ports is None:
        raise RuntimeError("Failed to allocate enough free network ports.")
    port_args = PortArgs(
        tokenizer_port=can_use_ports[0],
        router_port=can_use_ports[1],
        detokenizer_port=can_use_ports[2],
        nccl_port=can_use_ports[3],
        model_rpc_ports=can_use_ports[4:],
    )

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    proc_router = mp.Process(
        target=start_router_process,
        args=(
            server_args,
            port_args,
            pipe_router_writer,
        ),
    )
    proc_router.start()
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    # Wait for the model to finish loading
    router_init_state = pipe_router_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if router_init_state != "init ok" or detoken_init_state != "init ok":
        proc_router.kill()
        proc_detoken.kill()
        print("router init state:", router_init_state)
        print("detoken init state:", detoken_init_state)
        sys.exit(1)

    assert proc_router.is_alive() and proc_detoken.is_alive()

    # Renamed from `launch_server`: the inner function used to shadow this
    # function's own name inside its scope.
    def _run_uvicorn():
        # Launch api server
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )

    t = threading.Thread(target=_run_uvicorn)
    t.start()

    if pipe_finish_writer:
        url = server_args.url()

        # Probe the server until it answers (up to ~60s).
        success = False
        last_error = ""
        for _ in range(60):
            try:
                requests.get(url + "/get_model_info", timeout=5)
                success = True
                break
            except requests.exceptions.RequestException as e:
                # Bug fix: `e` is unbound after the except block in Python 3,
                # so it must be captured here for the failure report below.
                last_error = str(e)
                time.sleep(1)

        if success:
            pipe_finish_writer.send("init ok")
        else:
            pipe_finish_writer.send(last_error)
|
||||
|
||||
|
||||
class Runtime:
    """In-process handle for a local SRT server.

    Spawns the server in a child process, waits for it to report readiness,
    and exposes a RuntimeEndpoint for issuing requests. Call shutdown() (or
    rely on __del__) to kill the whole process tree.
    """

    def __init__(
        self,
        model_path: str,
        tokenizer_path: Optional[str] = None,
        load_format: str = "auto",
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        mem_fraction_static: float = 0.9,
        tp_size: int = 1,
        model_mode: List[str] = (),
        schedule_heuristic: str = "lpm",
        random_seed: int = 42,
        log_level: str = "warning",
    ):
        # Bind only on loopback with a freshly allocated port.
        host = "127.0.0.1"
        port = alloc_usable_network_port(1)[0]
        server_args = ServerArgs(
            model_path=model_path,
            tokenizer_path=tokenizer_path,
            host=host,
            port=port,
            load_format=load_format,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            mem_fraction_static=mem_fraction_static,
            tp_size=tp_size,
            model_mode=model_mode,
            schedule_heuristic=schedule_heuristic,
            random_seed=random_seed,
            log_level=log_level,
        )
        self.url = server_args.url()

        # Set pid before starting so shutdown() is safe even if start fails.
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=launch_server, args=(server_args, pipe_writer))
        proc.start()
        self.pid = proc.pid

        # Block until the server reports readiness over the pipe.
        init_state = pipe_reader.recv()
        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError("Launch failed")

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process and all of its children (idempotent)."""
        if self.pid is not None:
            # Kill grandchildren first so the parent can't respawn them.
            parent = psutil.Process(self.pid)
            children = parent.children(recursive=True)
            for child in children:
                child.kill()
            psutil.wait_procs(children, timeout=5)
            parent.kill()
            parent.wait(timeout=5)
            self.pid = None

    def __del__(self):
        # Best-effort cleanup when the handle is garbage collected.
        self.shutdown()
|
||||
138
python/sglang/srt/server_args.py
Normal file
138
python/sglang/srt/server_args.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ServerArgs:
    """Command-line / programmatic configuration for the SRT server.

    Field defaults double as argparse defaults via add_cli_args, and
    from_cli_args reconstructs an instance from a parsed namespace.
    """

    model_path: str
    tokenizer_path: Optional[str] = None
    host: str = "127.0.0.1"
    port: int = 30000
    load_format: str = "auto"
    tokenizer_mode: str = "auto"
    trust_remote_code: bool = True
    mem_fraction_static: float = 0.91
    tp_size: int = 1
    # Tuple default keeps the dataclass field immutable (a list default
    # would raise at class creation time).
    model_mode: List[str] = ()
    schedule_heuristic: str = "lpm"
    random_seed: int = 42
    disable_log_stats: bool = False
    log_stats_interval: int = 10
    log_level: str = "info"

    def __post_init__(self):
        # Default the tokenizer to the model checkpoint itself.
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all server options on `parser`, using the dataclass
        field defaults as argparse defaults."""
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument("--host", type=str, default=ServerArgs.host)
        parser.add_argument("--port", type=int, default=ServerArgs.port)
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)",
        )
        parser.add_argument(
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="Tensor parallelism degree.",
        )
        parser.add_argument(
            "--model-mode",
            type=str,
            default=[],
            nargs="+",
            help="Model mode: [flashinfer, no-cache, aggressive-new-fill]",
        )
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
            default=ServerArgs.schedule_heuristic,
            # Typo fix: was "Schudule mode".
            help="Schedule mode: [lpm, weight, random, fcfs]",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="Random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="Log level",
        )
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging throughput stats.",
        )
        parser.add_argument(
            "--log-stats-interval",
            type=int,
            default=ServerArgs.log_stats_interval,
            help="Log stats interval in second.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a namespace produced by add_cli_args."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the base HTTP URL of this server."""
        return f"http://{self.host}:{self.port}"
|
||||
|
||||
|
||||
@dataclasses.dataclass
class PortArgs:
    """Ports used for inter-process communication between server components."""

    # Port the tokenizer manager listens on.
    tokenizer_port: int
    # Port the router (scheduler) listens on.
    router_port: int
    # Port the detokenizer listens on.
    detokenizer_port: int
    # Port used to initialize the NCCL process group.
    nccl_port: int
    # One RPC port per tensor-parallel model worker.
    model_rpc_ports: List[int]
|
||||
217
python/sglang/srt/utils.py
Normal file
217
python/sglang/srt/utils.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
# Global switch: when False, mark_cost_time does not print timings (the
# CUDA synchronization around the call still happens either way).
is_show_cost_time = False


def mark_cost_time(func_name):
    """Decorator factory: time the wrapped function under `func_name`.

    Prints the elapsed milliseconds only on distributed ranks 0/1 and only
    when the global `is_show_cost_time` flag is set.
    """

    def inner_func(func):
        def time_func(*args, **kwargs):
            if dist.get_rank() in [0, 1] and is_show_cost_time:
                # Synchronize so pending GPU work is not misattributed.
                torch.cuda.synchronize()
                start_time = time.time()
                ans = func(*args, **kwargs)
                torch.cuda.synchronize()
                print(func_name, "cost time:", (time.time() - start_time) * 1000)
                return ans
            else:
                # NOTE: synchronizes even when not printing, so the wrapped
                # call has identical timing behavior on every rank.
                torch.cuda.synchronize()
                ans = func(*args, **kwargs)
                torch.cuda.synchronize()
                return ans

        return time_func

    return inner_func
|
||||
|
||||
|
||||
# Label -> wall-clock start time, shared by mark_start / mark_end below.
time_mark = {}


def mark_start(key):
    """Record the current time under `key` (after syncing CUDA work)."""
    torch.cuda.synchronize()
    global time_mark
    time_mark[key] = time.time()
    return


def mark_end(key, print_min_cost=0.0):
    """Print the elapsed milliseconds since mark_start(key), but only if it
    exceeds `print_min_cost` ms. Raises KeyError if mark_start was not
    called for `key`."""
    torch.cuda.synchronize()
    global time_mark
    cost_time = (time.time() - time_mark[key]) * 1000
    if cost_time > print_min_cost:
        print(f"cost {key}:", cost_time)
|
||||
|
||||
|
||||
def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory: print the wrapped function's duration when `show`
    is True and the call took longer than `min_cost_ms` milliseconds.

    Note: torch.cuda.synchronize() runs around every invocation even with
    show=False, so wrapping a function is not free.
    """

    def wrapper(func):
        def inner_func(*args, **kwargs):
            torch.cuda.synchronize()
            if show:
                start_time = time.time()
            result = func(*args, **kwargs)
            torch.cuda.synchronize()
            if show:
                cost_time = (time.time() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> None:
    """Seed every RNG the runtime uses, for reproducible generation."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Covers all visible CUDA devices, not just the current one.
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
|
||||
|
||||
|
||||
def alloc_usable_network_port(num, used_list=()):
    """Find `num` currently-bindable TCP ports in [10000, 65536).

    Returns a list of port numbers, or None if not enough free ports were
    found. A returned port can still be grabbed by another process before
    the caller binds it (inherent check-then-use race).
    """
    found = []
    for candidate in range(10000, 65536):
        if candidate in used_list:
            continue

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("", candidate))
            except socket.error:
                continue
        found.append(candidate)

        if len(found) == num:
            return found
    return None
|
||||
|
||||
|
||||
def get_exception_traceback():
    """Return the currently-handled exception's traceback as one string."""
    exc_type, exc_value, exc_tb = sys.exc_info()
    return "".join(traceback.format_exception(exc_type, exc_value, exc_tb))
|
||||
|
||||
|
||||
def get_int_token_logit_bias(tokenizer, vocab_size):
    """Build a logit-bias vector that restricts sampling to integer output.

    Every token whose decoded, stripped text is not a digit string, not
    empty/whitespace, and not the EOS token gets a large negative bias
    (-1e5), effectively masking it out.

    The unused `from transformers import LlamaTokenizer, ...` line was
    removed: it was never referenced and made the function fail outright
    when transformers is not installed.
    """
    logit_bias = np.zeros(vocab_size, dtype=np.float32)
    for t_id in range(vocab_size):
        ss = tokenizer.decode(t_id).strip()
        if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
            logit_bias[t_id] = -1e5
    return logit_bias
|
||||
|
||||
|
||||
def wrap_kernel_launcher(kernel):
    """A faster launcher for triton kernels.

    Bypasses triton's Python-level dispatch by grabbing the already-compiled
    specialization from the kernel cache and calling its low-level wrapper
    directly. Returns a callable ``ret_func(grid, num_warps, *args)``.
    """
    import torch.distributed as dist

    # Pick this rank's compiled-kernel cache (rank 0 when not distributed).
    if dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # Take the first compiled specialization from the cache.
    # NOTE(review): assumes the kernel has already been launched once so the
    # cache is populated — confirm callers guarantee this.
    kernels = kernel.cache[rank].values()
    kernel = next(iter(kernels))

    # Different triton versions use different low-level names.
    if hasattr(kernel, "cu_function"):
        kfunction = kernel.cu_function
    else:
        kfunction = kernel.function

    if hasattr(kernel, "c_wrapper"):
        run = kernel.c_wrapper
    else:
        run = kernel.run

    # Some triton versions take four extra cluster-dimension arguments in
    # the wrapper signature; discover the right arity at runtime by trial.
    add_cluster_dim = True

    def ret_func(grid, num_warps, *args):
        nonlocal add_cluster_dim

        try:
            if add_cluster_dim:
                run(
                    grid[0],
                    grid[1],
                    grid[2],
                    num_warps,
                    1,
                    1,
                    1,
                    1,
                    kernel.shared,
                    0,
                    kfunction,
                    None,
                    None,
                    kernel,
                    *args,
                )
            else:
                run(
                    grid[0],
                    grid[1],
                    grid[2],
                    num_warps,
                    kernel.shared,
                    0,
                    kfunction,
                    None,
                    None,
                    kernel,
                    *args,
                )
        except TypeError:
            # Wrong arity for this triton version: flip the flag and retry
            # with the other calling convention.
            add_cluster_dim = not add_cluster_dim
            ret_func(grid, num_warps, *args)

    return ret_func
|
||||
|
||||
|
||||
def is_multimodal_model(model):
    """Return True if `model` (a path string or a ModelConfig) refers to a
    multimodal (llava) model.

    Consistency fix: the string path now lowercases before matching, just
    like the ModelConfig path already did, so "LLaVA-..." is recognized.
    """
    if isinstance(model, str):
        return "llava" in model.lower()
    from sglang.srt.model_config import ModelConfig

    if isinstance(model, ModelConfig):
        return "llava" in model.path.lower()
    raise Exception("unrecognized type")
|
||||
|
||||
|
||||
def load_image(image_file):
    """Load a PIL image from a URL, a local file path, a `data:` URI, or a
    raw base64 string."""
    from PIL import Image

    image = None

    if image_file.startswith("http://") or image_file.startswith("https://"):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        image = Image.open(BytesIO(response.content))
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    elif image_file.startswith("data:"):
        # Bug fix: this branch referenced an undefined name `image_url`,
        # which raised NameError on any data: URI input.
        payload = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(payload)))
    else:
        # Fall back to treating the whole string as raw base64 data.
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image
|
||||
Reference in New Issue
Block a user