################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
from torch import nn
from vllm.attention import Attention
from vllm.config import ModelConfig
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)


def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                  target_device: torch.device) -> None:
    """Invoke the post-load weight hooks of every submodule of ``model``.

    The model is walked twice:

    1. Every ``QKVCrossParallelLinear`` submodule has its own
       ``process_weights_after_loading()`` called. Every other submodule
       whose ``quant_method`` attribute is a ``QuantizeMethodBase`` has
       ``weight_type``, ``use_ds_mla`` and ``use_ds_mla_sparse`` copied
       from ``model_config`` onto that quant method, after which
       ``quant_method.process_weights_after_loading(submodule)`` runs.
    2. Every ``Attention`` submodule that exposes
       ``process_weights_after_loading`` is called with
       ``model_config.dtype``.

    ``torch.supa.empty_cache()`` is called after each hook invocation.
    NOTE(review): presumably this releases cached device memory between
    modules -- confirm against the ``torch.supa`` backend.

    Args:
        model: Root module whose submodules are visited.
        model_config: Source of ``weight_type``, ``use_ds_mla``,
            ``use_ds_mla_sparse`` and ``dtype``.
        target_device: Not read by this function; the
            ``device_loading_context`` scope that would consume it is
            commented out below.
    """
    # Pass 1: cross-QKV layers and layers carrying a quantization method.
    for submodule in model.modules():
        if isinstance(submodule, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            submodule.process_weights_after_loading()
            torch.supa.empty_cache()
            continue

        method = getattr(submodule, "quant_method", None)
        if not isinstance(method, QuantizeMethodBase):
            continue

        # When quant methods need to process weights after loading
        # (for repacking, quantizing, etc), they expect parameters
        # to be on the global target device. This scope is for the
        # case where cpu offloading is used, where we will move the
        # parameters onto device for processing and back off after.
        # with device_loading_context(module, target_device):
        #
        # The scope above is disabled, so the quant method is configured
        # from model_config and run directly on the submodule.
        method.weight_type = model_config.weight_type
        method.use_ds_mla = model_config.use_ds_mla
        method.use_ds_mla_sparse = model_config.use_ds_mla_sparse
        method.process_weights_after_loading(submodule)
        torch.supa.empty_cache()

    # Pass 2: attention layers.
    # Currently only used by MLA.
    # NOTE: This intentionally happens after other modules so we can easily
    # decompress the weights for MLA.
    for submodule in model.modules():
        if not isinstance(submodule, Attention):
            continue
        if not hasattr(submodule, "process_weights_after_loading"):
            continue
        # TODO(lucas): see if there is a way to unify the signatures
        # of process_weights_after_loading
        submodule.process_weights_after_loading(model_config.dtype)
        torch.supa.empty_cache()