Format code & move functions (#155)
@@ -193,6 +193,7 @@ def match_chat_ml(model_path: str):
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
 
+
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
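Note on this hunk: sglang picks a chat template by running registered matcher functions against the lowercased model path and returning the first template found. A minimal sketch of the registration pattern, assuming a module-level list as the registry (only the decorator name comes from the diff; the bookkeeping is an assumption):

    # Hypothetical registry; the list-based bookkeeping is an assumption.
    matching_functions = []

    def register_chat_template_matching_function(fn):
        # Collect each decorated matcher so callers can try them in order.
        matching_functions.append(fn)
        return fn

    def match_chat_template(model_path: str):
        for fn in matching_functions:
            template = fn(model_path.lower())
            if template is not None:
                return template
        return None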
@@ -64,13 +64,19 @@ class LogitsProcessor(nn.Module):
             torch.arange(all_logprobs.shape[0], device="cuda"),
             torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
         ]
-        logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
+        logprobs_cumsum = torch.cumsum(
+            prefill_logprobs, dim=0, dtype=torch.float32
+        )
 
         start = input_metadata.extend_start_loc.clone()
         end = start + input_metadata.extend_seq_lens - 2
         start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-        sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_logprobs[start]
+        )
         normalized_logprobs = sum_logp / (
             (input_metadata.extend_seq_lens - 1).clamp(min=1)
         )
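Note on this hunk: the code computes length-normalized prompt logprobs for a packed batch. Because requests are concatenated along dim 0, a cumulative sum turns each per-request summation into an O(1) difference of two cumsum entries; the `+ prefill_logprobs[start]` term adds back the start element the subtraction excludes. A runnable toy sketch of the same arithmetic (invented values, CPU tensors):

    import torch

    # Invented toy values: token logprobs for two packed requests of lengths
    # 3 and 4, concatenated along dim 0 the way the extend batch packs them.
    prefill_logprobs = torch.tensor([-0.5, -1.0, -0.25, -2.0, -0.1, -0.3, -0.4])
    extend_start_loc = torch.tensor([0, 3])  # where each request starts
    extend_seq_lens = torch.tensor([3, 4])   # tokens per request

    logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
    start = extend_start_loc.clone()
    end = start + extend_seq_lens - 2
    start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
    end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)

    # Per-request sum in O(1): cumsum difference plus the dropped start element.
    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
    normalized_logprobs = sum_logp / (extend_seq_lens - 1).clamp(min=1)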
@@ -13,14 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 
 
 class RadixAttention(nn.Module):
-    def __init__(
-        self,
-        num_heads,
-        head_dim,
-        scaling,
-        num_kv_heads,
-        layer_id
-    ):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
@@ -100,6 +100,7 @@ class BatchStrOut:
 class FlushCacheReq:
     pass
 
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
@@ -11,8 +11,8 @@ import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
-from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
+from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
@@ -391,8 +391,12 @@ class ModelRpcServer(rpyc.Service):
         logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
-                self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
-            )
+            logits, (
+                prefill_logprobs,
+                normalized_logprobs,
+                last_logprobs,
+            ) = self.model_runner.forward(
+                batch, ForwardMode.EXTEND, batch.return_logprob
+            )
             if prefill_logprobs is not None:
                 logprobs = prefill_logprobs.cpu().tolist()
@@ -407,7 +411,9 @@ class ModelRpcServer(rpyc.Service):
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+            last_logprobs = (
+                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+            )
 
         # Check finish condition
         pt = 0
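Note on this hunk: `last_logprobs[torch.arange(len(reqs)), next_token_ids]` uses advanced indexing to pick, per request, the logprob of the token that was actually sampled, so only a (batch,)-sized tensor crosses to the CPU instead of the full (batch, vocab) matrix. A toy sketch of the same gather (invented shapes, CPU tensors):

    import torch

    # Invented shapes: a (batch, vocab) matrix of next-token logprobs and the
    # token id actually sampled for each request.
    batch_size, vocab_size = 4, 10
    last_logprobs = torch.randn(batch_size, vocab_size)
    next_token_ids = torch.tensor([1, 0, 7, 3])

    # Advanced indexing pairs row i with column next_token_ids[i], so only a
    # (batch,)-sized result is transferred, not the full matrix.
    chosen = last_logprobs[torch.arange(batch_size), next_token_ids]
    chosen_list = chosen.cpu().tolist()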
@@ -482,7 +488,9 @@ class ModelRpcServer(rpyc.Service):
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
+            last_logprobs = last_logprobs[
+                torch.arange(len(reqs)), next_token_ids
+            ].tolist()
 
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
@@ -620,15 +628,16 @@ class ModelRpcClient:
         self.step = async_wrap("step")
 
 
-def start_model_process(port):
-    def _init_service(port):
-        t = ThreadedServer(
-            ModelRpcServer(),
-            port=port,
-            protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
-        )
-        t.start()
+def _init_service(port):
+    t = ThreadedServer(
+        ModelRpcServer(),
+        port=port,
+        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+    )
+    t.start()
 
+
+def start_model_process(port):
     proc = multiprocessing.Process(target=_init_service, args=(port,))
     proc.start()
     time.sleep(1)
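Note on this hunk: moving `_init_service` to module level lets `multiprocessing.Process` pickle a plain reference to it, and the server setup now reads top-down before the function that spawns it. How a caller might connect after spawning — a sketch under the assumption that the client simply retries `rpyc.connect` with the same protocol config (sglang's actual client wrapper is not shown in this diff; `connect_to_model_process` is a hypothetical helper built on the two functions above):

    import time

    import rpyc

    # Hypothetical helper; assumes start_model_process/_init_service from the
    # hunk above are importable in this scope.
    def connect_to_model_process(port, retries=10):
        start_model_process(port)  # spawns _init_service in a child process
        for _ in range(retries):
            try:
                return rpyc.connect(
                    "localhost",
                    port,
                    config={"allow_pickle": True, "sync_request_timeout": 1800},
                )
            except ConnectionRefusedError:
                time.sleep(1)
        raise RuntimeError(f"could not reach model process on port {port}")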
@@ -17,8 +17,8 @@ from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
 import sglang
-QUANTIONCONFIG_MAPPING = {'awq': AWQConfig,
-                          'gptq': GPTQConfig}
+
+QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
 
 logger = logging.getLogger("model_runner")
 
@@ -283,9 +283,13 @@ class ModelRunner:
             self.model_config.hf_config, "quantization_config", None
         )
         if hf_quant_config is not None:
-            quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config['quant_method'])
+            quant_config_class = QUANTIONCONFIG_MAPPING.get(
+                hf_quant_config["quant_method"]
+            )
             if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
+                raise ValueError(
+                    f"Unsupported quantization method: {hf_quant_config['quant_method']}"
+                )
             quant_config = quant_config_class.from_config(hf_quant_config)
             logger.info(f"quant_config: {quant_config}")
             linear_method = quant_config.get_linear_method()
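Note on this hunk: looking the method up with `.get` and raising an explicit ValueError gives a clearer failure than the bare KeyError that direct indexing would produce for an unknown method. The same dispatch pattern in miniature (StubAWQConfig is an invented stand-in; AWQConfig/GPTQConfig are the real vllm classes the mapping above points to):

    # StubAWQConfig stands in for vllm's AWQConfig purely for illustration.
    class StubAWQConfig:
        @classmethod
        def from_config(cls, cfg):
            return cls()

    QUANT_MAPPING = {"awq": StubAWQConfig}

    hf_quant_config = {"quant_method": "awq"}
    quant_config_class = QUANT_MAPPING.get(hf_quant_config["quant_method"])
    if quant_config_class is None:
        raise ValueError(
            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
        )
    quant_config = quant_config_class.from_config(hf_quant_config)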
@@ -42,14 +42,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +74,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-        linear_method: Optional[LinearMethodBase] = None
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -86,18 +86,18 @@ class QWenAttention(nn.Module):
 
         # pylint: disable=invalid-name
         self.c_attn = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            bias=True,
-            linear_method=linear_method
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -143,12 +143,16 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
+        self.mlp = QWenMLP(
+            config.hidden_size,
+            config.intermediate_size // 2,
+            linear_method=linear_method,
+        )
 
     def forward(
         self,
@@ -186,7 +190,10 @@ class QWenModel(nn.Module):
             config.hidden_size,
         )
         self.h = nn.ModuleList(
-            [QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
+            [
+                QWenBlock(config, i, linear_method=linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -4,14 +4,17 @@ from typing import List, Optional
 
 import torch
 import torch.nn as nn
+from sglang.srt.models.llava import (
+    LlavaLlamaForCausalLM,
+    clip_vision_embed_forward,
+    monkey_path_clip_vision_embed_forward,
+)
 from transformers import CLIPVisionModel, LlavaConfig
 from vllm.model_executor.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
-from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward
-
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
     def __init__(self, *args, **kwargs):
@@ -19,7 +22,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
         super().__init__(self.config)
 
         self.multi_modal_projector = YiVLMultiModalProjector(self.config)
-        self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "")  # Everything after "./"
+        self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
+            "./", ""
+        )  # Everything after "./"
 
     def load_weights(
         self,
@@ -30,7 +35,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
     ):
         # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
         self.vision_tower = CLIPVisionModel.from_pretrained(
-            model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
+            model_name_or_path,
+            torch_dtype=torch.float16,
+            subfolder=self.vision_tower_subfolder,
         ).cuda()
 
         self.vision_tower.eval()
@@ -80,14 +87,19 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
 
         monkey_path_clip_vision_embed_forward()
 
+
 class YiVLMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
         super().__init__()
 
-        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size, config.text_config.hidden_size
+        )
         self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
         self.act = nn.GELU()
-        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size, config.text_config.hidden_size
+        )
         self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
 
     def forward(self, image_features):
@@ -98,4 +110,5 @@ class YiVLMultiModalProjector(nn.Module):
         hidden_states = self.ln_2(hidden_states)
         return hidden_states
 
-EntryClass = YiVLForCausalLM
+
+EntryClass = YiVLForCausalLM
@@ -63,6 +63,7 @@ chat_template_name = None
 # FIXME: Remove this once we drop support for pydantic 1.x
 IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
 
+
 def jsonify_pydantic_model(obj: BaseModel):
     if IS_PYDANTIC_1:
         return obj.json(ensure_ascii=False)
@@ -165,7 +166,7 @@ async def v1_completions(raw_request: Request):
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]
 
-                if not stream_buffer: # The first chunk
+                if not stream_buffer:  # The first chunk
                     if request.echo:
                         # Prepend prompt in response text.
                         text = request.prompt + text
@@ -219,7 +220,9 @@ async def v1_completions(raw_request: Request):
             token_logprob_pos = prompt_tokens
 
         logprobs = (
-            await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+            await make_openai_style_logprobs(
+                ret["meta_info"]["token_logprob"][token_logprob_pos:]
+            )
             if request.logprobs is not None
             else None
         )
@@ -114,7 +114,7 @@ class ServerArgs:
             "--max-prefill-num-token",
             type=int,
             default=ServerArgs.max_prefill_num_token,
-            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
         parser.add_argument(
             "--tp-size",
@@ -259,4 +259,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
\ No newline at end of file
+    return image