################################################################################ # Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ################################################################################ from typing import Literal, Optional, Union import torch import torch.nn.functional as F import torch_br import torch_br.supa._debug as supa_debug from fastcore.basics import patch_to from torch.nn.parameter import Parameter from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, split_tensor_along_last_dim, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import logger from vllm.model_executor.layers.linear import (adjust_bitsandbytes_4bit_shard, adjust_marlin_shard, adjust_scalar_to_fused_array) from vllm_br import envs from vllm_br.utils import get_grandparent_pid from .br_utils import (_convert_to_crossed_numa_tensor, _convert_to_numa_tensor, _convert_to_numa_tensor_vit, is_br166_device) from vllm.model_executor.layers.linear import ( # isort:skip LinearBase, MergedColumnParallelLinear, QuantizationConfig, ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod, QKVParallelLinear) def _should_skip_linear_post_process(layer, use_ds_mla, use_ds_mla_sparse): """NOTE: SUPA: for MLA linears, we do process in MLA.process_weights_after_loading """ # TODO: Hard code for native dsa op if use_ds_mla_sparse: MLA_LINEAR_NAMES = [ "kv_b_proj", ] else: MLA_LINEAR_NAMES = [ "q_a_proj", "q_b_proj", # "q_proj", "kv_a_proj_with_mqa", "kv_b_proj", # "o_proj", ] if use_ds_mla and not use_ds_mla_sparse: MLA_LINEAR_NAMES.append("o_proj") skip = any(k in layer.prefix for k in MLA_LINEAR_NAMES) if skip: logger.debug( f'[SUPA] skip {layer.prefix} UnquantizedLinearMethod.process_weights_after_loading' # noqa: G004 ) return skip # NOTE: ReplicatedLinear, usually used in MoE as a gate module. # In DeepseekV3, it needs to be transposed. def process_weights_ReplicatedLinear( layer: ReplicatedLinear) -> Literal[True, False]: layer.weight.data = layer.weight.data.transpose(1, 0).contiguous() return True def process_share_expert_weight(layer: MergedColumnParallelLinear): gate_up_weight = layer.weight.transpose(1, 0).contiguous() gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1) die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM is_br166 = die_spc_num > 16 spc_num = die_spc_num // 2 if is_br166 else die_spc_num if is_br166: # 2&2 for 4spc, 4&8 for 12spc, 8&8 for 16spc spc_for_shared = 2 if spc_num == 4 else 8 spc_for_router = spc_num - spc_for_shared align_size = 32 weight_dtype = gate_weight.dtype hidden_size = gate_weight.shape[0] gate_d0, gate_d1 = torch.chunk(gate_weight, 2, dim=-1) up_d0, up_d1 = torch.chunk(up_weight, 2, dim=-1) im_size = gate_d0.shape[-1] n_align_size = (align_size * 2) * spc_for_shared swiglu_w_aligned = (( (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size region_size = swiglu_w_aligned // spc_for_shared block_nums = (region_size // (align_size * 2)) * spc_for_shared gate_d0_align = torch.nn.functional.pad( gate_d0, (0, (swiglu_w_aligned // 2) - im_size, 0, 0), mode='constant', value=0) gate_d1_align = torch.nn.functional.pad( gate_d1, (0, (swiglu_w_aligned // 2) - im_size, 0, 0), mode='constant', value=0) up_d0_align = torch.nn.functional.pad( up_d0, (0, (swiglu_w_aligned // 2) - im_size, 0, 0), mode='constant', value=0) up_d1_align = torch.nn.functional.pad( up_d1, (0, (swiglu_w_aligned // 2) - im_size, 0, 0), mode='constant', value=0) gate_weight_d0_reshape = gate_d0_align.reshape( hidden_size, block_nums, align_size).contiguous().to(weight_dtype) gate_weight_d1_reshape = gate_d1_align.reshape( hidden_size, block_nums, align_size).contiguous().to(weight_dtype) up_weight_d0_reshape = up_d0_align.reshape( hidden_size, block_nums, align_size).contiguous().to(weight_dtype) up_weight_d1_reshape = up_d1_align.reshape( hidden_size, block_nums, align_size).contiguous().to(weight_dtype) gate_up_weight_d0 = torch.zeros( [hidden_size, block_nums, align_size * 2], dtype=weight_dtype, device='supa') gate_up_weight_d0[:, :, 0:0 + align_size] = gate_weight_d0_reshape[:, :, 0:align_size] gate_up_weight_d0[:, :, align_size:align_size * 2] = up_weight_d0_reshape[:, :, 0:align_size] gate_up_weight_d0 = gate_up_weight_d0.reshape( hidden_size, spc_for_shared, region_size).permute(1, 0, 2).contiguous().to(weight_dtype) gate_up_d0_invalid = torch.zeros( [spc_for_router, hidden_size, region_size], dtype=weight_dtype, device='supa') # invalid regions gate_up_weight_d0_whole = torch.cat( [gate_up_weight_d0, gate_up_d0_invalid], dim=0) gate_up_weight_d1 = torch.zeros( [hidden_size, block_nums, align_size * 2], dtype=weight_dtype, device='supa') gate_up_weight_d1[:, :, 0:0 + align_size] = gate_weight_d1_reshape[:, :, 0:align_size] gate_up_weight_d1[:, :, align_size:align_size * 2] = up_weight_d1_reshape[:, :, 0:align_size] gate_up_weight_d1 = gate_up_weight_d1.reshape( hidden_size, spc_for_shared, region_size).permute(1, 0, 2).contiguous().to(weight_dtype) gate_up_d1_invalid = torch.zeros( [spc_for_router, hidden_size, region_size], dtype=weight_dtype, device='supa') # invalid regions gate_up_weight_d1_whole = torch.cat( [gate_up_weight_d1, gate_up_d1_invalid], dim=0) gate_up_weight_whole = torch.cat( [gate_up_weight_d0_whole, gate_up_weight_d1_whole], dim=0) gate_up_weight_supa = torch_br._empty_ut_only( size=gate_up_weight_whole.shape, dtype=gate_weight.dtype, is_numa=True, device="supa", tensor_type="colmajor") gate_up_weight_supa.copy_(gate_up_weight_whole) layer.weight.data = gate_up_weight_supa else: # 2&2 for 4spc, 4&8 for 12spc, 8&8 for 16spc spc_for_shared = 2 if spc_num == 4 else 8 spc_for_router = spc_num - spc_for_shared align_size = 32 weight_dtype = gate_weight.dtype hidden_size = gate_weight.shape[0] im_size = gate_weight.shape[-1] n_align_size = (align_size * 2) * spc_for_shared swiglu_w_aligned = (( (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size region_size = swiglu_w_aligned // spc_for_shared block_nums = (region_size // (align_size * 2)) * spc_for_shared gate_golden_align = torch.nn.functional.pad( gate_weight, (0, (swiglu_w_aligned // 2) - im_size, 0, 0), mode='constant', value=0) up_golden_align = torch.nn.functional.pad( up_weight, (0, (swiglu_w_aligned // 2) - im_size, 0, 0), mode='constant', value=0) gate_weight_golden_reshape = gate_golden_align.reshape( hidden_size, block_nums, align_size).contiguous().to(weight_dtype) up_weight_golden_reshape = up_golden_align.reshape( hidden_size, block_nums, align_size).contiguous().to(weight_dtype) gate_up_weight_golden = torch.zeros( [hidden_size, block_nums, align_size * 2], dtype=weight_dtype, device='supa') gate_up_weight_golden[:, :, 0:0 + align_size] = gate_weight_golden_reshape[:, :, 0: align_size] gate_up_weight_golden[:, :, align_size:align_size * 2] = up_weight_golden_reshape[:, :, 0:align_size] gate_up_weight_golden = gate_up_weight_golden.reshape( hidden_size, spc_for_shared, region_size).permute(1, 0, 2).contiguous().to(weight_dtype) gate_up_invalid = torch.zeros( [spc_for_router, hidden_size, region_size], dtype=weight_dtype, device='supa') # invalid regions gate_up_weight_whole = torch.cat( [gate_up_weight_golden, gate_up_invalid], dim=0) gate_up_weight_supa = torch_br._empty_ut_only( size=gate_up_weight_whole.shape, dtype=gate_weight.dtype, is_numa=True, device="supa", tensor_type="colmajor") gate_up_weight_supa.copy_(gate_up_weight_whole) layer.weight.data = gate_up_weight_supa # NOTE: MergedColumnParallelLinear, usually used in MergedGateUpMLPSiluL2 def process_weights_QuantMergedColumnParallelLinear( layer: MergedColumnParallelLinear): if 'shared_experts' not in layer.prefix: #NOTE: normal MLP gate_up, after load weight, convert to supa numa tensor if hasattr(layer, "qweight"): gate_weight, up_weight = torch.chunk(layer.qweight, 2, dim=-1) gate_up_weight_numa = _convert_to_crossed_numa_tensor( gate_weight, up_weight, envs.VLLM_BR_DEVICE_SPC_NUM, dim=-1, need_pad=True, do_transpose=False) layer.qweight.data = gate_up_weight_numa else: gate_up_weight = layer.weight.permute(1, 0).contiguous() gate_up_weight_numa = _convert_to_numa_tensor( gate_up_weight, 32, "colmajor", gate_up_weight.dtype, False, parallel_type="col_parallel") layer.weight.data = gate_up_weight_numa if hasattr(layer, "scales") and layer.scales is not None: gate_scales, up_scales = torch.chunk(layer.scales, 2, dim=-1) gate_up_scales_internleaved_numa = _convert_to_crossed_numa_tensor( gate_scales, up_scales, envs.VLLM_BR_DEVICE_SPC_NUM, dim=-1, need_pad=False, layout="linear_bias", do_transpose=False) layer.scales.data = gate_up_scales_internleaved_numa if hasattr(layer, "bias") and layer.bias is not None: gate_bias, up_bias = torch.chunk(layer.bias, 2, dim=-1) gate_up_bias_internleaved_numa = _convert_to_crossed_numa_tensor( gate_bias, up_bias, envs.VLLM_BR_DEVICE_SPC_NUM, dim=-1, need_pad=False, layout="linear_bias", do_transpose=False) layer.bias.data = gate_up_bias_internleaved_numa else: process_share_expert_weight(layer) def process_weights_MergedColumnParallelLinear( layer: MergedColumnParallelLinear): if 'shared_experts' not in layer.prefix: gate_up_weight = layer.weight.permute(1, 0).contiguous() if not (hasattr(layer, "no_need_cross") and layer.no_need_cross): gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1) gate_up_weight_internleaved_numa = _convert_to_crossed_numa_tensor( gate_weight, up_weight, envs.VLLM_BR_DEVICE_SPC_NUM, dim=-1, need_pad=True, do_transpose=False) layer.weight.data = gate_up_weight_internleaved_numa else: gate_up_weight_numa = _convert_to_numa_tensor( gate_up_weight, align_size=32, layout="colmajor", dtype=gate_up_weight.dtype, do_transpose=False) layer.weight.data = gate_up_weight_numa if hasattr(layer, "bias") and layer.bias is not None: gate_bias, up_bias = torch.chunk(layer.bias, 2, dim=-1) gate_up_bias_internleaved_numa = _convert_to_crossed_numa_tensor( gate_bias, up_bias, envs.VLLM_BR_DEVICE_SPC_NUM, dim=-1, need_pad=False, layout="linear_bias", do_transpose=False) layer.bias.data = gate_up_bias_internleaved_numa else: #NOTE: by default, gate module and shared_expert(1) module will be involved into calculation in 1 kernel process_share_expert_weight(layer) @patch_to(UnquantizedLinearMethod) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if _should_skip_linear_post_process( layer, self.use_ds_mla, self.use_ds_mla_sparse) or self.weight_type != "NUMA": return still_need_process = True do_transpose = True parallel_type = "col_parallel" # NOTE: all process_weights func should done before process_weights_after_loading match layer: case ReplicatedLinear(): process_weights_ReplicatedLinear(layer) still_need_process = not ("indexer" not in layer.prefix and ( layer.output_size == 64 or layer.output_size == 160 # Glm4-Moe or layer.output_size == 128 or layer.output_size == 256)) do_transpose = False case MergedColumnParallelLinear(): process_weights_MergedColumnParallelLinear(layer) still_need_process = False do_transpose = False case RowParallelLinear(): parallel_type = "row_parallel" case _: pass if not still_need_process or self.weight_type != "NUMA": return # process numa weight and bias if hasattr(layer, "weight") and len(layer.weight.shape) == 2: if 'vision' in layer.prefix and is_br166_device(): layer.weight.data = _convert_to_numa_tensor_vit( layer.weight, envs.VLLM_BR_DEVICE_WARP_SIZE, "colmajor", torch.bfloat16, do_transpose=do_transpose, wk=(layer.weight.data.shape[1] if do_transpose else layer.weight.data.shape[0]), wn=(layer.weight.data.shape[0] if do_transpose else layer.weight.data.shape[1]), parallel_type=parallel_type) # noqa: SIM210 else: layer.weight.data = _convert_to_numa_tensor( layer.weight, envs.VLLM_BR_DEVICE_WARP_SIZE, "colmajor", torch.bfloat16, do_transpose=do_transpose, wk=(layer.weight.data.shape[1] if do_transpose else layer.weight.data.shape[0]), wn=(layer.weight.data.shape[0] if do_transpose else layer.weight.data.shape[1]), parallel_type=parallel_type) # noqa: SIM210 if hasattr(layer, "bias") and layer.bias is not None: pad_zeros = (parallel_type == "row_parallel") if 'vision' in layer.prefix and is_br166_device(): if (pad_zeros and layer.reduce_results): return layer.bias.data = _convert_to_numa_tensor_vit( layer.bias, envs.VLLM_BR_DEVICE_WARP_SIZE, "linear_bias", torch.float32, parallel_type=parallel_type, pad_zeros=pad_zeros) else: layer.bias.data = _convert_to_numa_tensor( layer.bias, envs.VLLM_BR_DEVICE_WARP_SIZE, "linear_bias", torch.float32, parallel_type=parallel_type, pad_zeros=pad_zeros) @patch_to(UnquantizedLinearMethod) def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: # numa weight is 3-dims if 'vision' in layer.prefix and is_br166_device(): if len(layer.weight.shape) == 3: is_row = isinstance(layer, RowParallelLinear) output_size = (layer.output_size_per_partition if hasattr( layer, "output_size_per_partition") else layer.output_size) act_mode = "act_default" if isinstance(layer, MergedColumnParallelLinear) and not (hasattr( layer, "no_need_cross") and layer.no_need_cross): act_mode = "act_swiglu" output_size //= 2 if bias is None or (is_row and layer.reduce_results): # return torch_br.br_matmul_infer( # x, # layer.weight, # bias=None, # output_w=output_size, # ) return torch_br.br_fused_mlp_infer(x, [layer.weight], output_w=output_size, activation_mode=act_mode) else: return torch_br.br_matmul_infer(x, layer.weight, bias, output_size) supa_debug.set_enable_sublas_api(True) output = F.linear(x, layer.weight, bias) supa_debug.set_enable_sublas_api(False) return output if len(layer.weight.shape) == 3: output_size = (layer.output_size_per_partition if hasattr( layer, "output_size_per_partition") else layer.output_size) act_mode = "act_default" if isinstance(layer, MergedColumnParallelLinear) and not (hasattr( layer, "no_need_cross") and layer.no_need_cross): act_mode = "act_swiglu" output_size //= 2 bias = [bias] if bias is not None else None if isinstance(layer, RowParallelLinear): seq_len = x.shape[-2] tp_size = get_tensor_model_parallel_world_size() # TODO(CaoJun): This is WA, delete (16, 8) so that the test_vllm_model_accu_qwen25_72b_instruct can run through support_types = ((16, 4), (32, 2), (32, 4)) # bypass tp8 and tp4pp2 allreduce pp_size = get_pp_group().world_size all_rank = tp_size * pp_size layer.reduce_results = not ( all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types) if layer.reduce_results: return torch_br.br_fused_mlp_infer(x, [layer.weight], output_w=output_size, bias=bias, activation_mode=act_mode) else: tp_rank = get_tp_group().rank_in_group global_rank = get_tp_group().rank rank_i = global_rank % tp_size assert rank_i == tp_rank return torch_br.supa_fused_linear_allreduce_opt( x, layer.weight, output_size, tp_rank, tp_size, global_rank, 0) else: return torch_br.br_fused_mlp_infer(x, [layer.weight], output_w=output_size, bias=bias, activation_mode=act_mode) supa_debug.set_enable_sublas_api(True) output = F.linear(x, layer.weight, bias) supa_debug.set_enable_sublas_api(False) return output @patch_to(LinearBase) def __init__( self, input_size: int, output_size: int, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", *, return_bias: bool = True, disable_tp: bool = False, ): super(LinearBase, self).__init__() # Keep input parameters self.input_size = input_size self.output_size = output_size self.skip_bias_add = skip_bias_add if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype if quant_config is None: self.quant_method = UnquantizedLinearMethod() else: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias self.prefix = prefix self.tp_rank = (get_tensor_model_parallel_rank() if not disable_tp else 0) self.tp_size = (get_tensor_model_parallel_world_size() if not disable_tp else 1) @patch_to(RowParallelLinear) def forward( self, input_ ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr( self, "grandparent_pid"): self.grandparent_pid = get_grandparent_pid() if self.input_is_parallel: input_parallel = input_ else: tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. assert self.quant_method is not None # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) if self.reduce_results and self.tp_size > 1: # CPU all reduce will be applied. if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and self.tp_size >= 4 and output_parallel.shape[ 1] <= 32: tp_rank = get_tensor_model_parallel_rank() output = torch_br.supa_allreduce_pcie_infer( output_parallel, tp_rank, self.tp_size, self.grandparent_pid) else: output = tensor_model_parallel_all_reduce(output_parallel) else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None if not self.return_bias: return output return output, output_bias @patch_to(QKVParallelLinear) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) if is_gguf_weight_type: idx_map = {"q": 0, "k": 1, "v": 2} if loaded_shard_id is not None: param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) param.shard_weight_type[loaded_shard_id] = loaded_weight.item() else: param.shard_weight_type = { k: loaded_weight.item() for k in idx_map } return if is_gguf_weight: tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) shard_size = loaded_weight.size(output_dim) // tp_size start_idx = tp_rank * shard_size if loaded_shard_id is not None: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) return param_data = param.data output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) # Special case for per-tensor scales in fused case. needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) if loaded_shard_id is None: # Loaded weight is already fused on disk (qkv). # (e.g., Phi-3's qkv_proj). if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( param_data, loaded_weight, 0) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return shard_offsets = [ # (shard_id, shard_offset, shard_size) ("q", 0, self.total_num_heads * self.head_size), ("k", self.total_num_heads * self.head_size, self.total_num_kv_heads * self.head_size), ("v", (self.total_num_heads + self.total_num_kv_heads) * self.head_size, self.total_num_kv_heads * self.head_size), ] use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.total_num_heads * self.head_size), "k": (self.total_num_heads * self.head_size, self.total_num_kv_heads * self.head_size), "v": ((self.total_num_heads + self.total_num_kv_heads) * self.head_size, self.total_num_kv_heads * self.head_size), "total": ((self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size, 0) } shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( param, orig_qkv_offsets, shard_id) loaded_weight_shard = loaded_weight.narrow(output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) return tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 shard_size = self.num_heads * self.head_size elif loaded_shard_id == "k": shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size elif loaded_shard_id == "v": shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) is_sharded_weight = getattr(param, "is_sharded_weight", False) # bitsandbytes loads the weights of the specific portion # no need to narrow is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), "k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size), "v": ((self.num_heads + self.num_kv_heads) * self.head_size, self.num_kv_heads * self.head_size), "total": ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 0) } shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( param, orig_qkv_offsets, loaded_shard_id) if envs.VLLM_BR_DEVICE_SPC_NUM > 16: half_w = param_data.shape[output_dim] // 2 param_data = (param_data.narrow(output_dim, shard_offset // 2, shard_size // 2), param_data.narrow(output_dim, shard_offset // 2 + half_w, shard_size // 2)) else: param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": shard_id = tp_rank else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size if not is_sharded_weight: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 shard_size = loaded_weight.shape[0] shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) # Special case for per-tensor scales in fused case. elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( param_data, loaded_weight, loaded_shard_id) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: logger.warning( "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") if isinstance(param_data, tuple): half_w = loaded_weight.shape[output_dim] // 2 param_data[0].copy_(loaded_weight.narrow(output_dim, 0, half_w)) param_data[1].copy_(loaded_weight.narrow(output_dim, half_w, half_w)) else: assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight)