[Fix] Reduce memory usage for loading llava model & Remove EntryClassRemapping (#1308)
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import List, Optional
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import (
|
||||
@@ -23,6 +23,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
base_url: str,
|
||||
api_key: Optional[str] = None,
|
||||
verify: Optional[str] = None,
|
||||
chat_template_name: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.support_concate_and_append = True
|
||||
@@ -39,9 +40,12 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self._assert_success(res)
|
||||
self.model_info = res.json()
|
||||
|
||||
self.chat_template = get_chat_template_by_model_path(
|
||||
self.model_info["model_path"]
|
||||
)
|
||||
if chat_template_name:
|
||||
self.chat_template = get_chat_template(chat_template_name)
|
||||
else:
|
||||
self.chat_template = get_chat_template_by_model_path(
|
||||
self.model_info["model_path"]
|
||||
)
|
||||
|
||||
def get_model_name(self):
|
||||
return self.model_info["model_path"]
|
||||
|
||||
@@ -86,8 +86,8 @@ class TokenizerManager:
|
||||
self.recv_from_detokenizer = context.socket(zmq.PULL)
|
||||
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
||||
|
||||
self.send_to_router = context.socket(zmq.PUSH)
|
||||
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
|
||||
self.send_to_controller = context.socket(zmq.PUSH)
|
||||
self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
|
||||
|
||||
# Read model args
|
||||
self.model_path = server_args.model_path
|
||||
@@ -271,7 +271,7 @@ class TokenizerManager:
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
self.send_to_controller.send_pyobj(tokenized_obj)
|
||||
|
||||
# Recv results
|
||||
event = asyncio.Event()
|
||||
@@ -367,7 +367,7 @@ class TokenizerManager:
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
self.send_to_controller.send_pyobj(tokenized_obj)
|
||||
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event)
|
||||
@@ -500,14 +500,14 @@ class TokenizerManager:
|
||||
|
||||
def flush_cache(self):
|
||||
req = FlushCacheReq()
|
||||
self.send_to_router.send_pyobj(req)
|
||||
self.send_to_controller.send_pyobj(req)
|
||||
|
||||
def abort_request(self, rid: str):
|
||||
if rid not in self.rid_to_state:
|
||||
return
|
||||
del self.rid_to_state[rid]
|
||||
req = AbortReq(rid)
|
||||
self.send_to_router.send_pyobj(req)
|
||||
self.send_to_controller.send_pyobj(req)
|
||||
|
||||
async def update_weights(
|
||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||
@@ -524,7 +524,7 @@ class TokenizerManager:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0)
|
||||
self.send_to_router.send_pyobj(obj)
|
||||
self.send_to_controller.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
result = await self.model_update_result
|
||||
if result.success:
|
||||
|
||||
@@ -606,16 +606,6 @@ def import_model_classes():
|
||||
assert entry.__name__ not in model_arch_name_to_cls
|
||||
model_arch_name_to_cls[entry.__name__] = entry
|
||||
|
||||
# compat: some models such as chatglm has incorrect class set in config.json
|
||||
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
|
||||
if hasattr(module, "EntryClassRemapping") and isinstance(
|
||||
module.EntryClassRemapping, list
|
||||
):
|
||||
for remap in module.EntryClassRemapping:
|
||||
if isinstance(remap, tuple) and len(remap) == 2:
|
||||
assert remap[0] not in model_arch_name_to_cls
|
||||
model_arch_name_to_cls[remap[0]] = remap[1]
|
||||
|
||||
return model_arch_name_to_cls
|
||||
|
||||
|
||||
|
||||
@@ -402,6 +402,8 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = ChatGLMForCausalLM
|
||||
# compat: glm model.config class == ChatGLMModel
|
||||
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
|
||||
class ChatGLMModel(ChatGLMForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
EntryClass = [ChatGLMForCausalLM, ChatGLMModel]
|
||||
|
||||
@@ -297,7 +297,6 @@ class ExaoneForCausalLM(nn.Module):
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
efficient_weight_load=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -345,9 +344,7 @@ class ExaoneForCausalLM(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
return len(params_dict)
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -358,7 +355,7 @@ class ExaoneForCausalLM(nn.Module):
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
def load_weights_per_param(name, loaded_weight):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
return
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
@@ -368,6 +365,7 @@ class ExaoneForCausalLM(nn.Module):
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
return
|
||||
|
||||
name = name.replace("attn.attention", "self_attn")
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@@ -387,13 +385,5 @@ class ExaoneForCausalLM(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name is None or loaded_weight is None:
|
||||
for name, loaded_weight in weights:
|
||||
name = name.replace("attn.attention", "self_attn")
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
else:
|
||||
name = name.replace("attn.attention", "self_attn")
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = ExaoneForCausalLM
|
||||
|
||||
@@ -295,7 +295,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
efficient_weight_load=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -305,6 +304,8 @@ class LlamaForCausalLM(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.sampler = Sampler()
|
||||
|
||||
self.param_dict = dict(self.named_parameters())
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
@@ -320,30 +321,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||
return sample_output, logits_output
|
||||
|
||||
def get_module_name(self, name):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id, num_shard)
|
||||
("qkv_proj", "q_proj", "q", 3),
|
||||
("qkv_proj", "k_proj", "k", 3),
|
||||
("qkv_proj", "v_proj", "v", 3),
|
||||
("gate_up_proj", "gate_proj", 0, 2),
|
||||
("gate_up_proj", "up_proj", 1, 2),
|
||||
]
|
||||
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
|
||||
if weight_name in name:
|
||||
return (
|
||||
name.replace(weight_name, param_name)[: -len(".weight")],
|
||||
num_shard,
|
||||
)
|
||||
return name[: -len(".weight")], 1
|
||||
|
||||
def get_num_params(self):
|
||||
params_dict = dict(self.named_parameters())
|
||||
return len(params_dict)
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -352,9 +330,9 @@ class LlamaForCausalLM(nn.Module):
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
params_dict = self.param_dict
|
||||
|
||||
def load_weights_per_param(name, loaded_weight):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
return
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
@@ -383,11 +361,5 @@ class LlamaForCausalLM(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name is None or loaded_weight is None:
|
||||
for name, loaded_weight in weights:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
else:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = LlamaForCausalLM
|
||||
@@ -16,17 +16,16 @@ limitations under the License.
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import SampleOutput
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.models.llama2 import LlamaModel
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
|
||||
|
||||
class LlamaForClassification(nn.Module):
|
||||
@@ -42,10 +41,12 @@ class LlamaForClassification(nn.Module):
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
|
||||
self.classification_head = nn.Linear(
|
||||
config.hidden_size, config.classification_out_size
|
||||
config.hidden_size, config.classification_out_size, bias=False
|
||||
)
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
self.param_dict = dict(self.named_parameters())
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
@@ -65,7 +66,7 @@ class LlamaForClassification(nn.Module):
|
||||
(input_metadata.batch_size, self.config.classification_out_size)
|
||||
).to(input_ids.device)
|
||||
|
||||
return LogitsProcessorOutput(
|
||||
logits_output = LogitsProcessorOutput(
|
||||
next_token_logits=scores,
|
||||
next_token_logprobs=scores,
|
||||
normalized_prompt_logprobs=scores,
|
||||
@@ -74,46 +75,38 @@ class LlamaForClassification(nn.Module):
|
||||
output_top_logprobs=None,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
continue
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if "lm_head" in name:
|
||||
continue
|
||||
# A dummy to make this work
|
||||
sample_output = SampleOutput(
|
||||
success=torch.full(
|
||||
size=(scores.shape[0],),
|
||||
fill_value=True,
|
||||
dtype=torch.bool,
|
||||
),
|
||||
probs=torch.full(
|
||||
size=(scores.shape[0], 1),
|
||||
fill_value=1.0,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
batch_next_token_ids=torch.full(
|
||||
size=(scores.shape[0],),
|
||||
fill_value=0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
)
|
||||
return sample_output, logits_output
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = self.param_dict
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "classification_head" in name:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
elif "lm_head" in name:
|
||||
continue
|
||||
else:
|
||||
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
|
||||
|
||||
|
||||
EntryClass = LlamaForClassification
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Iterable, Optional, Tuple
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -7,7 +7,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
|
||||
from sglang.srt.models.llama import LlamaModel
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module):
|
||||
@@ -16,7 +16,6 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
config: LlamaConfig,
|
||||
quant_config=None,
|
||||
cache_config=None,
|
||||
efficient_weight_load=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
@@ -86,6 +85,8 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = LlamaEmbeddingModel
|
||||
# compat: e5-mistral model.config class == MistralModel
|
||||
EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
|
||||
class MistralModel(LlamaEmbeddingModel):
|
||||
pass
|
||||
|
||||
|
||||
EntryClass = [LlamaEmbeddingModel, MistralModel]
|
||||
|
||||
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
|
||||
unpad_image_shape,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
|
||||
@@ -395,21 +395,19 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
"model.mm_projector.0": "multi_modal_projector.linear_1",
|
||||
"model.mm_projector.2": "multi_modal_projector.linear_2",
|
||||
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
||||
"model.image_newline": "language_model.model.image_newline",
|
||||
}
|
||||
params_dict = dict(self.named_parameters())
|
||||
weights = list(weights)
|
||||
for name, loaded_weight in weights:
|
||||
# FIXME: why projector weights read two times?
|
||||
if "projector" in name or "vision_tower" in name:
|
||||
if "projector" in name or "vision_tower" in name or "image_newline" in name:
|
||||
for weight_name, param_name in projector_weights.items():
|
||||
if weight_name in name:
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load language model
|
||||
self.language_model.load_weights(weights)
|
||||
else:
|
||||
self.language_model.load_weights([(name, loaded_weight)])
|
||||
|
||||
@property
|
||||
def num_patches_per_side(self):
|
||||
@@ -429,6 +427,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
|
||||
self.vision_tower = None
|
||||
self.config.vision_config.hidden_size = config.mm_hidden_size
|
||||
self.config.text_config.hidden_size = config.hidden_size
|
||||
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
|
||||
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||
@@ -448,9 +447,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
|
||||
|
||||
self.config = config
|
||||
self.vision_tower = None
|
||||
|
||||
if getattr(self.config, "vision_config", None) is None:
|
||||
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
|
||||
|
||||
if getattr(self.config, "text_config", None) is None:
|
||||
self.config.text_config = Qwen2Config(self.config._name_or_path)
|
||||
|
||||
@@ -459,7 +458,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
|
||||
|
||||
if getattr(self.config, "projector_hidden_act", None) is None:
|
||||
self.config.projector_hidden_act = "gelu"
|
||||
|
||||
if getattr(self.config, "image_token_index", None) is None:
|
||||
self.config.image_token_index = 151646
|
||||
|
||||
@@ -482,9 +480,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
|
||||
|
||||
self.config = config
|
||||
self.vision_tower = None
|
||||
|
||||
if getattr(self.config, "vision_config", None) is None:
|
||||
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
|
||||
|
||||
if getattr(self.config, "text_config", None) is None:
|
||||
self.config.text_config = MistralConfig(self.config._name_or_path)
|
||||
|
||||
@@ -493,7 +491,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
|
||||
|
||||
if getattr(self.config, "projector_hidden_act", None) is None:
|
||||
self.config.projector_hidden_act = "gelu"
|
||||
|
||||
if getattr(self.config, "image_token_index", None) is None:
|
||||
self.config.image_token_index = 32000
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
|
||||
|
||||
class LlavaVidForCausalLM(nn.Module):
|
||||
@@ -239,12 +239,12 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
|
||||
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
|
||||
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
||||
"model.image_newline": "language_model.model.image_newline",
|
||||
}
|
||||
params_dict = dict(self.named_parameters())
|
||||
weights = list(weights)
|
||||
for name, loaded_weight in weights:
|
||||
# FIXME: why projector weights read two times?
|
||||
if "projector" in name or "vision_tower" in name:
|
||||
if "projector" in name or "vision_tower" in name or "image_newline" in name:
|
||||
for weight_name, param_name in projector_weights.items():
|
||||
if weight_name in name:
|
||||
name = name.replace(weight_name, param_name)
|
||||
@@ -255,9 +255,8 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
continue
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load language model
|
||||
self.language_model.load_weights(weights)
|
||||
else:
|
||||
self.language_model.load_weights([(name, loaded_weight)])
|
||||
|
||||
@property
|
||||
def num_patches_per_side(self):
|
||||
|
||||
@@ -15,12 +15,11 @@ limitations under the License.
|
||||
|
||||
"""Inference-only Mistral model."""
|
||||
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
|
||||
|
||||
class MistralForCausalLM(LlamaForCausalLM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
pass
|
||||
|
||||
|
||||
EntryClass = MistralForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user