[Fix] Reduce memory usage for loading llava model & Remove EntryClassRemapping (#1308)
.github/workflows/pr-test.yml
@@ -1,4 +1,4 @@
name: Pull Request Test
name: PR Test

on:
push:

@@ -205,7 +205,7 @@ It supports streaming, vision, and most features of the Chat/Completions/Models/
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2
```
- Add `--dp 2` to enable multi-GPU data parallelism. It can also be used together with tensor parallelism. Data parallelism is better for throughput if there is enough memory.
- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2
```

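For reference, the OpenAI-compatible endpoints mentioned in this hunk can be exercised against a server launched with either command above. A minimal sketch with the `openai` Python client, assuming the server is listening on port 30000 (the `model` value is illustrative):

```
from openai import OpenAI  # pip install openai

# Point the client at the local SGLang server instead of api.openai.com.
client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative; use the served model path
    messages=[{"role": "user", "content": "List three national capitals."}],
)
print(response.choices[0].message.content)
```
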
@@ -1,6 +1,9 @@
# Custom Chat Template in SGLang Runtime

By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3.
**NOTE**: There are two chat template systems in SGLang project. This document is about setting a custom chat template for the OpenAI-compatible API server (defined at [conversation.py](../../python/sglang/srt/conversation.py)). It is NOT related to the chat template used in the SGLang language frontend (defined at [chat_template.py](../../python/sglang/lang/chat_template.py)).

By default, the server uses the chat template specified in the model tokenizer from Hugging Face.
It should just work for most official models such as Llama-2/Llama-3.

If needed, you can also override the chat template when launching the server:

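The hunk ends before the override example itself. For orientation, such an override is typically done with a `--chat-template` flag on `launch_server`, taking either a built-in template name or a path to a JSON template definition; a sketch (flag and template name are assumptions, check `--help` on your version):

```
# Built-in template name (assumed; verify with python -m sglang.launch_server --help)
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chat-template llama-3

# Or a custom template definition in a JSON file
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chat-template ./my_template.json
```
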
@@ -2,13 +2,8 @@
Usage: python3 local_example_llava_next.py
"""

from PIL import ImageFile

import sglang as sgl
from sglang.lang.chat_template import get_chat_template
from sglang.srt.utils import load_image

ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images

@sgl.function

@@ -4,7 +4,7 @@ from typing import List, Optional

from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import (
@@ -23,6 +23,7 @@ class RuntimeEndpoint(BaseBackend):
base_url: str,
api_key: Optional[str] = None,
verify: Optional[str] = None,
chat_template_name: Optional[str] = None,
):
super().__init__()
self.support_concate_and_append = True
@@ -39,9 +40,12 @@ class RuntimeEndpoint(BaseBackend):
self._assert_success(res)
self.model_info = res.json()

self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"]
)
if chat_template_name:
self.chat_template = get_chat_template(chat_template_name)
else:
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"]
)

def get_model_name(self):
return self.model_info["model_path"]

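The new `chat_template_name` argument lets a frontend program force a specific template instead of inferring one from the model path. A minimal sketch of how it might be used from the SGLang frontend (the URL and template name are illustrative; `chatml` is one of the names registered in `chat_template.py`):

```
import sglang as sgl

# Pin the chat template explicitly; omit chat_template_name to keep the
# model-path-based lookup that was the only behavior before this change.
backend = sgl.RuntimeEndpoint("http://127.0.0.1:30000", chat_template_name="chatml")
sgl.set_default_backend(backend)
```
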
@@ -86,8 +86,8 @@ class TokenizerManager:
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
self.send_to_controller = context.socket(zmq.PUSH)
self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

# Read model args
self.model_path = server_args.model_path
@@ -271,7 +271,7 @@ class TokenizerManager:
input_ids,
sampling_params,
)
self.send_to_router.send_pyobj(tokenized_obj)
self.send_to_controller.send_pyobj(tokenized_obj)

# Recv results
event = asyncio.Event()
@@ -367,7 +367,7 @@ class TokenizerManager:
input_ids,
sampling_params,
)
self.send_to_router.send_pyobj(tokenized_obj)
self.send_to_controller.send_pyobj(tokenized_obj)

event = asyncio.Event()
state = ReqState([], False, event)
@@ -500,14 +500,14 @@ class TokenizerManager:

def flush_cache(self):
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)
self.send_to_controller.send_pyobj(req)

def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)
self.send_to_controller.send_pyobj(req)

async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -524,7 +524,7 @@ class TokenizerManager:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0)
self.send_to_router.send_pyobj(obj)
self.send_to_controller.send_pyobj(obj)
self.model_update_result = asyncio.Future()
result = await self.model_update_result
if result.success:

@@ -606,16 +606,6 @@ def import_model_classes():
assert entry.__name__ not in model_arch_name_to_cls
model_arch_name_to_cls[entry.__name__] = entry

# compat: some models such as chatglm has incorrect class set in config.json
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
if hasattr(module, "EntryClassRemapping") and isinstance(
module.EntryClassRemapping, list
):
for remap in module.EntryClassRemapping:
if isinstance(remap, tuple) and len(remap) == 2:
assert remap[0] not in model_arch_name_to_cls
model_arch_name_to_cls[remap[0]] = remap[1]

return model_arch_name_to_cls

@@ -402,6 +402,8 @@ class ChatGLMForCausalLM(nn.Module):
weight_loader(param, loaded_weight)

EntryClass = ChatGLMForCausalLM
# compat: glm model.config class == ChatGLMModel
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
class ChatGLMModel(ChatGLMForCausalLM):
pass

EntryClass = [ChatGLMForCausalLM, ChatGLMModel]

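With `EntryClassRemapping` removed, an architecture alias is now just a subclass, and `EntryClass` may be a single class or a list of classes. The loader side of that change is not shown in this diff; a hypothetical sketch of how `import_model_classes` could consume both forms (the helper name and exact handling are assumptions):

```
def register_entry_classes(module, model_arch_name_to_cls):
    # Sketch only: accept EntryClass as either a class or a list of classes,
    # so aliases like ChatGLMModel register alongside ChatGLMForCausalLM.
    entry = getattr(module, "EntryClass", None)
    if entry is None:
        return
    entries = entry if isinstance(entry, list) else [entry]
    for cls in entries:
        assert cls.__name__ not in model_arch_name_to_cls
        model_arch_name_to_cls[cls.__name__] = cls
```
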
@@ -297,7 +297,6 @@ class ExaoneForCausalLM(nn.Module):
config,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
efficient_weight_load=False,
) -> None:
super().__init__()
self.config = config
@@ -345,9 +344,7 @@ class ExaoneForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
return len(params_dict)

def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -358,7 +355,7 @@ class ExaoneForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters())

def load_weights_per_param(name, loaded_weight):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
return
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -368,6 +365,7 @@ class ExaoneForCausalLM(nn.Module):
if name.startswith("model.vision_tower") and name not in params_dict:
return

name = name.replace("attn.attention", "self_attn")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -387,13 +385,5 @@ class ExaoneForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

if name is None or loaded_weight is None:
for name, loaded_weight in weights:
name = name.replace("attn.attention", "self_attn")
load_weights_per_param(name, loaded_weight)
else:
name = name.replace("attn.attention", "self_attn")
load_weights_per_param(name, loaded_weight)

EntryClass = ExaoneForCausalLM

@@ -295,7 +295,6 @@ class LlamaForCausalLM(nn.Module):
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
efficient_weight_load=False,
) -> None:
super().__init__()
self.config = config
@@ -305,6 +304,8 @@ class LlamaForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()

self.param_dict = dict(self.named_parameters())

@torch.no_grad()
def forward(
self,
@@ -320,30 +321,7 @@ class LlamaForCausalLM(nn.Module):
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output

def get_module_name(self, name):
stacked_params_mapping = [
# (param_name, shard_name, shard_id, num_shard)
("qkv_proj", "q_proj", "q", 3),
("qkv_proj", "k_proj", "k", 3),
("qkv_proj", "v_proj", "v", 3),
("gate_up_proj", "gate_proj", 0, 2),
("gate_up_proj", "up_proj", 1, 2),
]
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
if weight_name in name:
return (
name.replace(weight_name, param_name)[: -len(".weight")],
num_shard,
)
return name[: -len(".weight")], 1

def get_num_params(self):
params_dict = dict(self.named_parameters())
return len(params_dict)

def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -352,9 +330,9 @@ class LlamaForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
params_dict = self.param_dict

def load_weights_per_param(name, loaded_weight):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
return
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -383,11 +361,5 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

if name is None or loaded_weight is None:
for name, loaded_weight in weights:
load_weights_per_param(name, loaded_weight)
else:
load_weights_per_param(name, loaded_weight)

EntryClass = LlamaForCausalLM

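Caching `self.param_dict` once in `__init__` pairs with the llava/llavavid change later in this commit: the multimodal wrappers now call `load_weights` once per checkpoint tensor, so rebuilding `dict(self.named_parameters())` inside every call would repeat O(#parameters) work. A small self-contained sketch of the pattern (not the actual SGLang class):

```
import torch
from torch import nn

class CausalLMSketch(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)
        # Build the name -> Parameter map once; load_weights may now be
        # called many times (once per tensor) without redoing this.
        self.param_dict = dict(self.named_parameters())

    def load_weights(self, weights):
        for name, tensor in weights:
            self.param_dict[name].data.copy_(tensor)
```
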
@@ -16,17 +16,16 @@ limitations under the License.
from typing import Iterable, Optional, Tuple

import torch
import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama2 import LlamaModel
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel

class LlamaForClassification(nn.Module):
@@ -42,10 +41,12 @@ class LlamaForClassification(nn.Module):
self.model = LlamaModel(config, quant_config=quant_config)

self.classification_head = nn.Linear(
config.hidden_size, config.classification_out_size
config.hidden_size, config.classification_out_size, bias=False
)
self.eos_token_id = config.eos_token_id

self.param_dict = dict(self.named_parameters())

@torch.no_grad()
def forward(
self,
@@ -65,7 +66,7 @@ class LlamaForClassification(nn.Module):
(input_metadata.batch_size, self.config.classification_out_size)
).to(input_ids.device)

return LogitsProcessorOutput(
logits_output = LogitsProcessorOutput(
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
@@ -74,46 +75,38 @@ class LlamaForClassification(nn.Module):
output_top_logprobs=None,
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if "lm_head" in name:
continue
# A dummy to make this work
sample_output = SampleOutput(
success=torch.full(
size=(scores.shape[0],),
fill_value=True,
dtype=torch.bool,
),
probs=torch.full(
size=(scores.shape[0], 1),
fill_value=1.0,
dtype=torch.float16,
),
batch_next_token_ids=torch.full(
size=(scores.shape[0],),
fill_value=0,
dtype=torch.long,
),
)
return sample_output, logits_output

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = self.param_dict

for name, loaded_weight in weights:
if "classification_head" in name:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
elif "lm_head" in name:
continue
else:
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])

EntryClass = LlamaForClassification

@@ -1,4 +1,4 @@
from typing import Iterable, Optional, Tuple
from typing import Iterable, Tuple

import torch
from torch import nn
@@ -7,7 +7,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
from sglang.srt.models.llama import LlamaModel

class LlamaEmbeddingModel(nn.Module):
@@ -16,7 +16,6 @@ class LlamaEmbeddingModel(nn.Module):
config: LlamaConfig,
quant_config=None,
cache_config=None,
efficient_weight_load=False,
) -> None:
super().__init__()
self.model = LlamaModel(config, quant_config=quant_config)
@@ -86,6 +85,8 @@ class LlamaEmbeddingModel(nn.Module):
load_weights_per_param(name, loaded_weight)

EntryClass = LlamaEmbeddingModel
# compat: e5-mistral model.config class == MistralModel
EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
class MistralModel(LlamaEmbeddingModel):
pass

EntryClass = [LlamaEmbeddingModel, MistralModel]

@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -395,21 +395,19 @@ class LlavaBaseForCausalLM(nn.Module):
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
weights = list(weights)
for name, loaded_weight in weights:
# FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name:
if "projector" in name or "vision_tower" in name or "image_newline" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

# load language model
self.language_model.load_weights(weights)
else:
self.language_model.load_weights([(name, loaded_weight)])

@property
def num_patches_per_side(self):

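This hunk is the memory-reduction part of the change: instead of materializing the entire checkpoint with `weights = list(weights)` and handing the full list to the language model in a second pass, each `(name, tensor)` pair is now routed immediately, either to the projector/vision-tower parameters or to `language_model.load_weights([(name, loaded_weight)])`, so the weights iterator stays lazy and only one tensor needs to be resident at a time on top of the model itself. A toy sketch of the single-pass idea (names are illustrative, not the SGLang API):

```
from typing import Callable, Dict, Iterable, List, Tuple

import torch

def load_single_pass(
    weights: Iterable[Tuple[str, torch.Tensor]],
    vision_params: Dict[str, torch.nn.Parameter],
    load_language_weights: Callable[[List[Tuple[str, torch.Tensor]]], None],
) -> None:
    # Dispatch every tensor as soon as the checkpoint reader yields it;
    # avoiding list(weights) keeps only one loaded tensor alive at a time.
    for name, tensor in weights:
        if "projector" in name or "vision_tower" in name or "image_newline" in name:
            vision_params[name].data.copy_(tensor)
        else:
            load_language_weights([(name, tensor)])
```
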
@@ -429,6 +427,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
self.vision_tower = None
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size

self.multi_modal_projector = LlavaMultiModalProjector(config)
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
@@ -448,9 +447,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):

self.config = config
self.vision_tower = None

if getattr(self.config, "vision_config", None) is None:
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)

if getattr(self.config, "text_config", None) is None:
self.config.text_config = Qwen2Config(self.config._name_or_path)

@@ -459,7 +458,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):

if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"

if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 151646

@@ -482,9 +480,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):

self.config = config
self.vision_tower = None

if getattr(self.config, "vision_config", None) is None:
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)

if getattr(self.config, "text_config", None) is None:
self.config.text_config = MistralConfig(self.config._name_or_path)

@@ -493,7 +491,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):

if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"

if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 32000

@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llama import LlamaForCausalLM

class LlavaVidForCausalLM(nn.Module):
@@ -239,12 +239,12 @@ class LlavaVidForCausalLM(nn.Module):
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
weights = list(weights)
for name, loaded_weight in weights:
# FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name:
if "projector" in name or "vision_tower" in name or "image_newline" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
@@ -255,9 +255,8 @@ class LlavaVidForCausalLM(nn.Module):
continue
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

# load language model
self.language_model.load_weights(weights)
else:
self.language_model.load_weights([(name, loaded_weight)])

@property
def num_patches_per_side(self):

@@ -15,12 +15,11 @@ limitations under the License.

"""Inference-only Mistral model."""

from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llama import LlamaForCausalLM

class MistralForCausalLM(LlamaForCausalLM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pass

EntryClass = MistralForCausalLM

@@ -1,6 +1,6 @@
"""
Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification
python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification

python3 test_httpserver_classify.py
"""

@@ -3,23 +3,24 @@ Usage:
python3 reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4

Reference output:
========== Prompt 0 ==========
prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
device='cuda:0')
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
prefill logits tensor([-8.3125, -7.1172, 3.3398, ..., -4.9570, -4.1328, -3.4141],

========== Prompt 1 ==========
prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the United Kingdom is London.
The capital of the United Kingdom is London.
prefill logits tensor([-8.9062, -9.0156, 4.1406, ..., -4.9922, -4.4961, -4.0742],
The capital of

========== Prompt 2 ==========
prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park to play in the grass and water.
Today is a very
prefill logits tensor([-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4609],
device='cuda:0')
I'm going to the
"""

import argparse
@@ -47,7 +48,7 @@ def normal_text(args):
]
max_new_tokens = 16

for p in prompts:
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = t.encode(p, return_tensors="pt").cuda()
else:
@@ -60,7 +61,8 @@ def normal_text(args):

prefill_logits = m.forward(input_ids).logits[0][-1]

print("prefill logits", prefill_logits)
print(f"\n========== Prompt {i} ==========")
print("prefill logits (final)", prefill_logits)
print(output_str)