[Fix] Reduce memory usage for loading llava model & Remove EntryClassRemapping (#1308)

This commit is contained in:
Lianmin Zheng
2024-09-02 21:44:45 -07:00
committed by GitHub
parent a5a134f39f
commit f64eae3a29
17 changed files with 105 additions and 158 deletions

View File

@@ -4,7 +4,7 @@ from typing import List, Optional
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import (
@@ -23,6 +23,7 @@ class RuntimeEndpoint(BaseBackend):
base_url: str,
api_key: Optional[str] = None,
verify: Optional[str] = None,
chat_template_name: Optional[str] = None,
):
super().__init__()
self.support_concate_and_append = True
@@ -39,9 +40,12 @@ class RuntimeEndpoint(BaseBackend):
self._assert_success(res)
self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"]
)
if chat_template_name:
self.chat_template = get_chat_template(chat_template_name)
else:
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"]
)
def get_model_name(self):
return self.model_info["model_path"]

View File

@@ -86,8 +86,8 @@ class TokenizerManager:
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
self.send_to_controller = context.socket(zmq.PUSH)
self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
# Read model args
self.model_path = server_args.model_path
@@ -271,7 +271,7 @@ class TokenizerManager:
input_ids,
sampling_params,
)
self.send_to_router.send_pyobj(tokenized_obj)
self.send_to_controller.send_pyobj(tokenized_obj)
# Recv results
event = asyncio.Event()
@@ -367,7 +367,7 @@ class TokenizerManager:
input_ids,
sampling_params,
)
self.send_to_router.send_pyobj(tokenized_obj)
self.send_to_controller.send_pyobj(tokenized_obj)
event = asyncio.Event()
state = ReqState([], False, event)
@@ -500,14 +500,14 @@ class TokenizerManager:
def flush_cache(self):
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)
self.send_to_controller.send_pyobj(req)
def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)
self.send_to_controller.send_pyobj(req)
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -524,7 +524,7 @@ class TokenizerManager:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0)
self.send_to_router.send_pyobj(obj)
self.send_to_controller.send_pyobj(obj)
self.model_update_result = asyncio.Future()
result = await self.model_update_result
if result.success:

View File

@@ -606,16 +606,6 @@ def import_model_classes():
assert entry.__name__ not in model_arch_name_to_cls
model_arch_name_to_cls[entry.__name__] = entry
# compat: some models such as chatglm has incorrect class set in config.json
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
if hasattr(module, "EntryClassRemapping") and isinstance(
module.EntryClassRemapping, list
):
for remap in module.EntryClassRemapping:
if isinstance(remap, tuple) and len(remap) == 2:
assert remap[0] not in model_arch_name_to_cls
model_arch_name_to_cls[remap[0]] = remap[1]
return model_arch_name_to_cls

View File

@@ -402,6 +402,8 @@ class ChatGLMForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
EntryClass = ChatGLMForCausalLM
# compat: glm model.config class == ChatGLMModel
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
class ChatGLMModel(ChatGLMForCausalLM):
pass
EntryClass = [ChatGLMForCausalLM, ChatGLMModel]

View File

@@ -297,7 +297,6 @@ class ExaoneForCausalLM(nn.Module):
config,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
efficient_weight_load=False,
) -> None:
super().__init__()
self.config = config
@@ -345,9 +344,7 @@ class ExaoneForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
return len(params_dict)
def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -358,7 +355,7 @@ class ExaoneForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters())
def load_weights_per_param(name, loaded_weight):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
return
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -368,6 +365,7 @@ class ExaoneForCausalLM(nn.Module):
if name.startswith("model.vision_tower") and name not in params_dict:
return
name = name.replace("attn.attention", "self_attn")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -387,13 +385,5 @@ class ExaoneForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if name is None or loaded_weight is None:
for name, loaded_weight in weights:
name = name.replace("attn.attention", "self_attn")
load_weights_per_param(name, loaded_weight)
else:
name = name.replace("attn.attention", "self_attn")
load_weights_per_param(name, loaded_weight)
EntryClass = ExaoneForCausalLM

View File

@@ -295,7 +295,6 @@ class LlamaForCausalLM(nn.Module):
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
efficient_weight_load=False,
) -> None:
super().__init__()
self.config = config
@@ -305,6 +304,8 @@ class LlamaForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.param_dict = dict(self.named_parameters())
@torch.no_grad()
def forward(
self,
@@ -320,30 +321,7 @@ class LlamaForCausalLM(nn.Module):
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_module_name(self, name):
stacked_params_mapping = [
# (param_name, shard_name, shard_id, num_shard)
("qkv_proj", "q_proj", "q", 3),
("qkv_proj", "k_proj", "k", 3),
("qkv_proj", "v_proj", "v", 3),
("gate_up_proj", "gate_proj", 0, 2),
("gate_up_proj", "up_proj", 1, 2),
]
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
if weight_name in name:
return (
name.replace(weight_name, param_name)[: -len(".weight")],
num_shard,
)
return name[: -len(".weight")], 1
def get_num_params(self):
params_dict = dict(self.named_parameters())
return len(params_dict)
def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -352,9 +330,9 @@ class LlamaForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
params_dict = self.param_dict
def load_weights_per_param(name, loaded_weight):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
return
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -383,11 +361,5 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if name is None or loaded_weight is None:
for name, loaded_weight in weights:
load_weights_per_param(name, loaded_weight)
else:
load_weights_per_param(name, loaded_weight)
EntryClass = LlamaForCausalLM

View File

@@ -16,17 +16,16 @@ limitations under the License.
from typing import Iterable, Optional, Tuple
import torch
import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama2 import LlamaModel
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
class LlamaForClassification(nn.Module):
@@ -42,10 +41,12 @@ class LlamaForClassification(nn.Module):
self.model = LlamaModel(config, quant_config=quant_config)
self.classification_head = nn.Linear(
config.hidden_size, config.classification_out_size
config.hidden_size, config.classification_out_size, bias=False
)
self.eos_token_id = config.eos_token_id
self.param_dict = dict(self.named_parameters())
@torch.no_grad()
def forward(
self,
@@ -65,7 +66,7 @@ class LlamaForClassification(nn.Module):
(input_metadata.batch_size, self.config.classification_out_size)
).to(input_ids.device)
return LogitsProcessorOutput(
logits_output = LogitsProcessorOutput(
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
@@ -74,46 +75,38 @@ class LlamaForClassification(nn.Module):
output_top_logprobs=None,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if "lm_head" in name:
continue
# A dummy to make this work
sample_output = SampleOutput(
success=torch.full(
size=(scores.shape[0],),
fill_value=True,
dtype=torch.bool,
),
probs=torch.full(
size=(scores.shape[0], 1),
fill_value=1.0,
dtype=torch.float16,
),
batch_next_token_ids=torch.full(
size=(scores.shape[0],),
fill_value=0,
dtype=torch.long,
),
)
return sample_output, logits_output
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = self.param_dict
for name, loaded_weight in weights:
if "classification_head" in name:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
elif "lm_head" in name:
continue
else:
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
EntryClass = LlamaForClassification

View File

@@ -1,4 +1,4 @@
from typing import Iterable, Optional, Tuple
from typing import Iterable, Tuple
import torch
from torch import nn
@@ -7,7 +7,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
from sglang.srt.models.llama import LlamaModel
class LlamaEmbeddingModel(nn.Module):
@@ -16,7 +16,6 @@ class LlamaEmbeddingModel(nn.Module):
config: LlamaConfig,
quant_config=None,
cache_config=None,
efficient_weight_load=False,
) -> None:
super().__init__()
self.model = LlamaModel(config, quant_config=quant_config)
@@ -86,6 +85,8 @@ class LlamaEmbeddingModel(nn.Module):
load_weights_per_param(name, loaded_weight)
EntryClass = LlamaEmbeddingModel
# compat: e5-mistral model.config class == MistralModel
EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
class MistralModel(LlamaEmbeddingModel):
pass
EntryClass = [LlamaEmbeddingModel, MistralModel]

View File

@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -395,21 +395,19 @@ class LlavaBaseForCausalLM(nn.Module):
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
weights = list(weights)
for name, loaded_weight in weights:
# FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name:
if "projector" in name or "vision_tower" in name or "image_newline" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(weights)
else:
self.language_model.load_weights([(name, loaded_weight)])
@property
def num_patches_per_side(self):
@@ -429,6 +427,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
self.vision_tower = None
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
@@ -448,9 +447,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
self.config = config
self.vision_tower = None
if getattr(self.config, "vision_config", None) is None:
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
if getattr(self.config, "text_config", None) is None:
self.config.text_config = Qwen2Config(self.config._name_or_path)
@@ -459,7 +458,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 151646
@@ -482,9 +480,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
self.config = config
self.vision_tower = None
if getattr(self.config, "vision_config", None) is None:
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
if getattr(self.config, "text_config", None) is None:
self.config.text_config = MistralConfig(self.config._name_or_path)
@@ -493,7 +491,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 32000

View File

@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llama import LlamaForCausalLM
class LlavaVidForCausalLM(nn.Module):
@@ -239,12 +239,12 @@ class LlavaVidForCausalLM(nn.Module):
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
weights = list(weights)
for name, loaded_weight in weights:
# FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name:
if "projector" in name or "vision_tower" in name or "image_newline" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
@@ -255,9 +255,8 @@ class LlavaVidForCausalLM(nn.Module):
continue
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(weights)
else:
self.language_model.load_weights([(name, loaded_weight)])
@property
def num_patches_per_side(self):

View File

@@ -15,12 +15,11 @@ limitations under the License.
"""Inference-only Mistral model."""
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llama import LlamaForCausalLM
class MistralForCausalLM(LlamaForCausalLM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pass
EntryClass = MistralForCausalLM