Clean up model loader (#1440)
@@ -415,7 +415,7 @@ class ModelTpServer:
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warn(
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
@@ -936,6 +936,8 @@ class ModelTpServer:
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
         return success, message
@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )

 logger = logging.getLogger(__name__)
@@ -166,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory

     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
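Note: the added comment explains why Torch is pinned to a single CPU thread while weights load. If that setting ever needed to be confined to the loading phase only, a scoped helper could look like the following hypothetical sketch (not part of this PR):

from contextlib import contextmanager

import torch


@contextmanager
def single_threaded_weight_loading():
    # Hypothetical helper: limit intra-op threads only while weights load,
    # then restore the previous thread count.
    prev = torch.get_num_threads()
    torch.set_num_threads(1)
    try:
        yield
    finally:
        torch.set_num_threads(prev)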
@@ -178,6 +179,7 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")

+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -188,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
-
+        self.dtype = self.vllm_model_config.dtype
+
+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -255,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message

         load_config = LoadConfig(load_format=load_format)

         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message

         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -293,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
@@ -315,7 +310,7 @@ class ModelRunner:
         self.model_config.path = model_path

         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."

     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
@@ -152,7 +152,7 @@ async def flush_cache():
 async def update_weights(obj: UpdateWeightReqInput, request: Request):

     success, message = await tokenizer_manager.update_weights(obj, request)
-    content = {"message": message, "success": str(success)}
+    content = {"success": success, "message": message}
     if success:
         return JSONResponse(
             content,
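With this change the endpoint reports success as a JSON boolean rather than a stringified value, which is what the updated test below relies on. A minimal client-side sketch of consuming the new response shape (the server address and the /update_weights route are assumptions for illustration; the payload mirrors the test):

import requests

# Hypothetical local server address; the "/update_weights" route and "model_path"
# field mirror the endpoint and test touched in this commit.
response = requests.post(
    "http://127.0.0.1:30000/update_weights",
    json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
)
ret = response.json()
assert isinstance(ret["success"], bool)  # previously serialized as a string
print(ret["message"])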
@@ -187,7 +187,7 @@ def allocate_init_ports(
         cur_port += 1

     if port is not None and ret_ports[0] != port:
-        logger.warn(
+        logger.warning(
             f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
@@ -623,56 +623,7 @@ def set_ulimit(target_soft_limit=65535):
     try:
         resource.setrlimit(resource_type, (target_soft_limit, current_hard))
     except ValueError as e:
-        logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
+        logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")


-def is_llama3_405b_fp8_head_16(model_config):
-    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
-    if (
-        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
-        and model_config.hf_config.hidden_size == 16384
-        and model_config.hf_config.intermediate_size == 53248
-        and model_config.hf_config.num_hidden_layers == 126
-        and model_config.hf_config.num_key_value_heads == 16
-        and hasattr(model_config.hf_config, "quantization_config")
-        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
-    ):
-        return True
-    return False
-
-
-def monkey_patch_vllm_qvk_linear_loader():
-    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
-    from vllm.model_executor.layers.linear import QKVParallelLinear
-
-    origin_weight_loader = QKVParallelLinear.weight_loader
-
-    def get_original_weight(loaded_weight, head_dim):
-        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
-        dim = loaded_weight.shape[1]
-        for i in range(n_kv_head):
-            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
-                2 * i * head_dim : (2 * i + 1) * head_dim, :
-            ]
-        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
-        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
-        return original_kv_weight
-
-    def weight_loader_srt(
-        self,
-        param: Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        if (
-            loaded_shard_id in ["k", "v"]
-            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
-        ):
-            loaded_weight = get_original_weight(loaded_weight, self.head_size)
-
-        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
-
-    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
-
-
 def add_api_key_middleware(app, api_key: str):
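For reviewers unfamiliar with the deleted hack: get_original_weight dropped the duplicated KV-head blocks that the Meta-Llama-3.1-405B-FP8 checkpoint stores. A self-contained sketch of that index arithmetic on a toy tensor (shapes chosen only for illustration):

import torch


def get_original_weight(loaded_weight, head_dim):
    # The checkpoint stores each KV head twice; keep every second block of
    # head_dim rows and return the first n_kv_head * head_dim rows.
    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
    dim = loaded_weight.shape[1]
    for i in range(n_kv_head):
        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
            2 * i * head_dim : (2 * i + 1) * head_dim, :
        ]
    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
    return original_kv_weight


# Toy example: 2 KV heads, head_dim=2, hidden dim=3 -> an 8x3 duplicated weight.
w = torch.arange(24, dtype=torch.float32).reshape(8, 3)
print(get_original_weight(w, head_dim=2).shape)  # torch.Size([4, 3])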
@@ -44,7 +44,6 @@ class TestReplaceWeights(unittest.TestCase):
         )
         print(json.dumps(response.json()))
         print("=" * 100)
-        # return the "text" in response
         text = response.json()["text"]
         return text
@@ -61,7 +60,9 @@ class TestReplaceWeights(unittest.TestCase):
                 "model_path": model_path,
             },
         )
+        ret = response.json()
         print(json.dumps(response.json()))
+        return ret

     def test_replace_weights(self):
         origin_model_path = self.get_model_info()
@@ -70,7 +71,8 @@ class TestReplaceWeights(unittest.TestCase):

         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert ret["success"]

         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
@@ -81,7 +83,9 @@ class TestReplaceWeights(unittest.TestCase):
         assert origin_response[:32] != updated_response[:32]

         # update weights back
-        self.run_update_weights(origin_model_path)
+        ret = self.run_update_weights(origin_model_path)
+        assert ret["success"]
+
         updated_model_path = self.get_model_info()
         assert updated_model_path == origin_model_path
@@ -95,7 +99,8 @@ class TestReplaceWeights(unittest.TestCase):

         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert not ret["success"]

         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")