################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Weight-layout conversion helpers for Biren SUPA devices.

Converts ordinary (UMA) weight/bias tensors into the padded, per-SPC /
per-die NUMA layouts expected by the device, via ``torch_br._empty_ut_only``.
A device with more than 16 SPCs is treated as dual-die ("br166"): the SPC
count is halved and tensors are split/interleaved across the two dies.
"""
import torch
import torch_br
import torch_br.supa._debug as supa_debug
from vllm_br import envs


def align_n(n, align_size, spc_num=envs.VLLM_BR_DEVICE_SPC_NUM) -> int:
    """Return the per-SPC block size for `n` columns, rounded up to `align_size`.

    Splits `n` across `spc_num` SPCs (ceiling division), then rounds the
    per-SPC share up to a multiple of `align_size`.
    """
    n_block = (n + spc_num - 1) // spc_num
    n_block = (n_block + align_size - 1) // align_size * align_size
    return n_block


def _br_qweight_cvt(quant_method, qweight, qzeros, size_k, size_n,
                    override_group_size=None):
    """Convert a packed quantized weight into a plain int8 (size_k, size_n) tensor.

    Args:
        quant_method: provides ``quant_method.quant_config.group_size`` when
            `override_group_size` is not given.
        qweight: packed weight; reinterpreted byte-wise via ``view(torch.int8)``.
        qzeros: optional zero-points; when present and non-zero they are
            incremented by 1 and subtracted from the weight (per group).
        size_k / size_n: logical K / N dimensions of the weight.
        override_group_size: overrides the config's group size when truthy.

    Returns:
        int8 tensor of shape (size_k, size_n) on `qweight`'s original device.
    """
    group_size = override_group_size or quant_method.quant_config.group_size
    curr_dev = qweight.device
    # group_size <= 0 means "no grouping": treat the whole K dim as one group.
    group_num = size_k // group_size if group_size > 0 else 1
    # Unpack 4 int8 values per int32 word and regroup rows by quant group.
    # NOTE(review): assumes qweight holds size_k*size_n bytes packed 4-per-word
    # along K — confirm against the quant packer that produced it.
    qweight = qweight.cpu().view(torch.int8).reshape(
        size_k // 4, size_n, 4).permute(0, 2, 1).contiguous().reshape(
            group_num, size_k // group_num, size_n)
    # Skip the zero-point adjustment entirely when all zeros are 0.
    if qzeros is not None and not torch.all(qzeros == 0):
        # +1 matches the packer's zero-point offset convention (GPTQ-style);
        # widen to int32 for the subtraction, then narrow back to int8.
        qzeros = qzeros.cpu().view(torch.int8).to(torch.int32) + 1
        qweight = (qweight.to(torch.int32) - qzeros.unsqueeze(1)).to(
            torch.int8)
    qwei_int8 = qweight.reshape(size_k, size_n).to(curr_dev)
    return qwei_int8


def _numa_scales_cvt(scales, wn, spc_num):
    """Pad a 1-D scales tensor of width `wn` into shape (spc_num, wn_block).

    Uses the same per-SPC 32-aligned block size as ``align_n(wn, 32, spc_num)``;
    padding is zero-filled on the right.
    """
    align_size = 32
    wn_block = (wn + spc_num - 1) // spc_num
    wn_block = (wn_block + align_size - 1) // align_size * align_size
    cvt_scales = torch.nn.functional.pad(scales,
                                         (0, spc_num * wn_block - wn),
                                         mode='constant',
                                         value=0)
    cvt_scales = cvt_scales.reshape(spc_num, wn_block).contiguous()
    return cvt_scales


def cross_weight_32(t1, t2, spc_num, dim=1, need_pad=True):
    """Interleave `t1` and `t2` along `dim` in alternating 32-wide chunks.

    Both tensors are first zero-padded along the last dim so the interleave
    width is a multiple of 32 (per die on dual-die devices). When `need_pad`
    is True the interleaved result is additionally padded so the total width
    is a multiple of ``spc_num * 32 * 2`` (computed per die when
    ``spc_num > 16``).

    NOTE(review): the chunk/concat for interleaving uses `dim`, but all
    padding and the dual-die split hard-code ``dim=-1`` — behavior is only
    consistent when `dim` is the last dimension (the default for the 2-D
    callers); confirm before passing another `dim`.
    """
    width = t1.shape[dim]
    # NOTE: br166 must ensure dual-dies width are 32-aligned
    if spc_num > 16:
        # Dual-die: each half (one per die) must be 32-aligned independently.
        assert width % 2 == 0
        half_width = width // 2
        half_width_ = (half_width + 32 - 1) // 32 * 32
        half_pad = half_width_ - half_width
        if half_pad > 0:
            t10, t11 = torch.chunk(t1, 2, dim=-1)
            t10 = torch.nn.functional.pad(t10, (0, half_pad), "constant", 0)
            t11 = torch.nn.functional.pad(t11, (0, half_pad), "constant", 0)
            t1 = torch.cat([t10, t11], dim=-1)
            t20, t21 = torch.chunk(t2, 2, dim=-1)
            t20 = torch.nn.functional.pad(t20, (0, half_pad), "constant", 0)
            t21 = torch.nn.functional.pad(t21, (0, half_pad), "constant", 0)
            t2 = torch.cat([t20, t21], dim=-1)
            width = half_width_ * 2
    else:
        # Single die: round the whole width up to a multiple of 32.
        width_ = (width + 32 - 1) // 32 * 32
        t1 = torch.nn.functional.pad(t1, (0, width_ - width), "constant", 0)
        t2 = torch.nn.functional.pad(t2, (0, width_ - width), "constant", 0)
        width = width_
    # Interleave: [t1[0:32], t2[0:32], t1[32:64], t2[32:64], ...] along `dim`.
    cnt = width // 32
    t1_list = torch.chunk(t1, cnt, dim)
    t2_list = torch.chunk(t2, cnt, dim)
    tt = []
    for i in range(cnt):
        tt.append(t1_list[i])
        tt.append(t2_list[i])
    no_pad = torch.cat(tt, dim=dim)
    if not need_pad:
        return no_pad
    if spc_num > 16:
        # Pad each die's half of the interleaved tensor separately.
        align = (spc_num // 2) * 32 * 2
        width_align = (width + align - 1) // align * align
        pad_size = width_align - width
        out0, out1 = torch.chunk(no_pad, 2, dim=-1)
        out0 = torch.nn.functional.pad(out0, (0, pad_size), "constant", 0)
        out1 = torch.nn.functional.pad(out1, (0, pad_size), "constant", 0)
        out = torch.cat([out0, out1], dim=-1)
    else:
        align = spc_num * 32 * 2  # 768
        width_align = (width * 2 + align - 1) // align * align
        pad_size = width_align - width * 2
        out = torch.nn.functional.pad(no_pad, (0, pad_size), "constant", 0)
    return out


#
# NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
def _convert_to_uma_tensor(tensor,
                           align_size,
                           layout,
                           dtype,
                           do_transpose=False,
                           wk=None,
                           wn=None,
                           parallel_type="col_parallel"):
    """Copy `tensor` into a UMA device tensor (``is_numa=False``).

    Supported layouts (case-insensitive): ``"colmajor"`` (2-D weight) and
    ``"linear_bias"`` (bias/scales). `align_size` is accepted for signature
    parity with the NUMA converters but is not used here.

    Raises:
        ValueError: for any other layout.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    layout = layout.lower()
    if layout == "colmajor":
        # NOTE(review): here wk/wn default to shape[1]/shape[0] — the reverse
        # of the NUMA converters (shape[0]/shape[1]); presumably because the
        # input is already stored column-major. Confirm with callers.
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[0]
        d_shape = (wn, wk)
        if do_transpose:
            data = tensor.cpu().permute(1, 0).contiguous()
            d_shape = (wk, wn)
        else:
            data = tensor.cpu().contiguous()
        if parallel_type == "col_parallel":
            # Shard along axis 0 ("SB" sbp) for column-parallel layers.
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                sbp="SB",
                axis=0)
        else:
            # Row-parallel: shard along axis 1 instead.
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=1,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    elif layout == "linear_bias":
        # Normalize the bias/scales tensor to 1-D (wn,) or 2-D (wk, wn).
        axis = 0
        wn = wn or tensor.shape[-1]
        wk = 1
        data = tensor
        if len(data.shape) == 2 and data.shape[0] == 1:
            # (1, wn) -> flatten to (wn,)
            data = tensor.cpu().reshape(-1).contiguous()
        elif len(data.shape) == 2:
            # Genuinely 2-D (e.g. grouped scales): shard along axis 1.
            # NOTE(review): this branch leaves `data` on its original device
            # (no .cpu()/.contiguous()), unlike the others — confirm intended.
            axis = 1
            wk = data.shape[0]
        elif len(data.shape) == 3 and data.shape[1] == 1:
            # (wk, 1, wn) -> squeeze the middle dim to (wk, wn).
            data = tensor.cpu().reshape(
                (data.shape[0], data.shape[2])).contiguous()
            axis = 1
            wk = data.shape[0]
        d_shape = (wn, ) if axis == 0 else (wk, wn)
        if parallel_type == "row_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout)
        elif parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=axis,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    else:
        raise ValueError("uma tensor only support colmajor and linear_bias")
    return uma_tensor


def _convert_to_numa_tensor_vit(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Convert `tensor` to the NUMA layout used by the ViT path.

    Unlike ``_convert_to_numa_tensor``, the dual-die weight tensors here are
    created without `axis`/`sbp` arguments, and ``linear_bias`` is materialized
    as a true NUMA tensor (padded per SPC) rather than a UMA one.

    Args:
        tensor: weight (colmajor, 2-D) or bias/scales (linear_bias) tensor.
        align_size: per-SPC column block alignment (typically 32, warp size).
        layout: "colmajor" or "linear_bias" (case-insensitive).
        dtype: target dtype of the device tensor.
        do_transpose: transpose the 2-D weight before conversion.
        wk / wn: logical K / N sizes; inferred from `tensor` when omitted.
        parallel_type: "col_parallel" (split N across dies) or
            "row_parallel" (split K across dies).
        pad_zeros: for dual-die row-parallel bias only — pad die 2 with zeros
            instead of duplicating the bias on both dies.

    Returns:
        The populated device tensor from ``torch_br._empty_ut_only``.

    Raises:
        ValueError: for an unsupported layout.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Temporarily disable the force-UMA debug override while allocating the
    # NUMA tensor; the previous value is restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die device: SPC count is per die.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Single die: split N across SPCs -> (spc_num, wk, wn_block).
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            data = data.reshape(wk, spc_num,
                                wn_block).permute(1, 0, 2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Dual die, column-parallel: N is halved per die, then split
                # across that die's SPCs -> (die_spc_num, wk, wn_block).
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0, 2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Dual die, row-parallel: K is halved per die; full N is
                # split across SPCs -> (die_spc_num, wk // die_num, wn_block).
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        # For grouped scales the tensor is (group_num, wn); plain bias is 1-D.
        group_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        if die_num == 1:
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num * group_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            data = torch.nn.functional.pad(tensor.cpu(),
                                           (0, spc_num * wn_block - wn),
                                           mode='constant',
                                           value=0)
            if group_num > 1:
                # Interleave groups so each SPC's rows are contiguous.
                data = data.type(dtype).reshape(
                    group_num, spc_num,
                    wn_block).permute(1, 0, 2).contiguous().reshape(
                        spc_num * group_num, wn_block)
            else:
                data = data.type(dtype).reshape(spc_num, wn_block).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Split the bias across the two dies, then across SPCs.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = tensor.cpu().reshape(die_num, wn // die_num)
                bias = torch.nn.functional.pad(
                    bias, (0, spc_num * wn_block - wn // die_num, 0, 0),
                    mode='constant',
                    value=0)
                bias = bias.type(torch.float32).reshape(
                    die_spc_num, wn_block).contiguous()
                numa_tensor.copy_(bias)
            else:
                # Row-parallel bias: each die sees the full N range.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                bias = bias.type(torch.float32).reshape(
                    spc_num, wn_block).contiguous()
                if pad_zeros:
                    # Die 2 gets zeros so the bias is added only once.
                    bias_zeros_die2 = torch.zeros((spc_num, wn_block),
                                                  dtype=bias.dtype)
                    bias = torch.concat([bias, bias_zeros_die2], dim=0)
                else:
                    # Duplicate the bias on both dies.
                    bias = torch.concat([bias, bias], dim=0)
                numa_tensor.copy_(bias)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    # Restore the caller's force-UMA debug setting.
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor


def _convert_to_numa_tensor(tensor,
                            align_size,
                            layout,
                            dtype,
                            do_transpose=False,
                            wk=None,
                            wn=None,
                            parallel_type="col_parallel",
                            pad_zeros=False):
    """Convert `tensor` to the NUMA layout used by the LLM path.

    Differences from ``_convert_to_numa_tensor_vit``: dual-die colmajor
    weights are created with ``axis=0, sbp="SS"``, and ``linear_bias`` stays
    UMA (``is_numa=False``) so it can be used by fused linear kernels.

    Args:
        tensor: weight (colmajor, 2-D) or bias/scales (linear_bias) tensor.
        align_size: per-SPC column block alignment (typically 32, warp size).
        layout: "colmajor" or "linear_bias" (case-insensitive).
        dtype: target dtype of the device tensor.
        do_transpose: transpose the 2-D weight before conversion.
        wk / wn: logical K / N sizes; inferred from `tensor` when omitted.
        parallel_type: "col_parallel" (split N across dies) or
            "row_parallel" (split K across dies).
        pad_zeros: accepted for signature parity; unused in this variant.

    Returns:
        The populated device tensor from ``torch_br._empty_ut_only``.

    Raises:
        ValueError: for an unsupported layout.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Temporarily disable the force-UMA debug override while allocating the
    # NUMA tensor; the previous value is restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die device: SPC count is per die.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Single die: split N across SPCs -> (spc_num, wk, wn_block).
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            data = data.reshape(wk, spc_num,
                                wn_block).permute(1, 0, 2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Dual die, column-parallel: N halved per die, then split
                # across that die's SPCs -> (die_spc_num, wk, wn_block).
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0, 2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Dual die, row-parallel: K halved per die; full N split
                # across SPCs -> (die_spc_num, wk // die_num, wn_block).
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        # MoE scales/bias carry an expert dimension; plain bias is 1-D.
        expert_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        bias_shape = (expert_num, wn) if expert_num > 1 else (wn, )
        if die_num == 1:
            # Single die: plain UMA copy, no padding needed.
            numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                  dtype=dtype,
                                                  is_numa=False,
                                                  device=tensor.device,
                                                  tensor_type=layout)
            data = tensor.cpu().type(dtype)
            if expert_num > 1:
                data = data.reshape(expert_num, wn)
            else:
                data = data.reshape(wn).type(dtype)
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(tensor.device))
        else:
            if parallel_type == "col_parallel":
                # Shard the last dim ("SB") across dies; for per-expert
                # tensors the shard axis is 1, otherwise 0.
                axis = 1 if expert_num > 1 else 0
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type="buffer_any",
                                                      axis=axis,
                                                      sbp="SB")
                if expert_num == 1:
                    tensor = tensor.reshape(-1)
                # NOTE(review): this branch copies `tensor` without a dtype
                # cast, unlike the other bias branches — confirm intended.
                numa_tensor.copy_(tensor.to(torch.supa.current_device()))
            else:
                # Row-parallel: broadcast ("BB") the full bias to both dies.
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="linear_bias",
                    sbp="BB")
                bias = tensor.reshape(expert_num, wn).cpu().type(dtype)
                if expert_num == 1:
                    bias = bias.reshape(-1)
                numa_tensor.copy_(bias.to(torch.supa.current_device()))
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    # Restore the caller's force-UMA debug setting.
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor


def _convert_to_crossed_numa_tensor(t1,
                                    t2,
                                    spc_num,
                                    dim=1,
                                    need_pad=True,
                                    layout="colmajor",
                                    do_transpose=False):
    """Interleave `t1`/`t2` with ``cross_weight_32`` and convert the result to NUMA.

    Equivalent to the V0 pipeline of ``cross_weight_32`` followed by
    ``numa_weight_convert``/``_numa_weight_cvt``. The NUMA conversion uses a
    fixed align size of 32 and the interleaved tensor's own dtype.
    """
    uma_weight = cross_weight_32(t1, t2, spc_num, dim, need_pad)
    numa_weight = _convert_to_numa_tensor(uma_weight, 32, layout,
                                          uma_weight.dtype, do_transpose)
    return numa_weight


def _convert_to_numa_tensor_moe(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wb=None,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Convert a batched (per-expert) MoE weight to the dual-die NUMA layout.

    Only supports dual-die devices (asserts ``die_num == 2``) and the
    ``"colmajor"`` layout with a 3-D weight of shape (wb, wk, wn), where `wb`
    is the expert/batch dimension.

    Args:
        tensor: 3-D expert weight tensor.
        align_size: per-SPC column block alignment.
        layout: must be "colmajor" (case-insensitive).
        dtype: target dtype of the device tensor.
        do_transpose: transpose each expert's (wk, wn) slice first.
        wb / wk / wn: batch / K / N sizes; inferred from `tensor` when omitted.
        parallel_type: "col_parallel" (split N across dies) or
            "row_parallel" (split K across dies).
        pad_zeros: accepted for signature parity; unused here.

    Returns:
        Tuple of (device tensor, (die_spc_num, wk, wn_block)) describing the
        resulting per-die layout.

    Raises:
        ValueError: for an unsupported layout.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Temporarily disable the force-UMA debug override while allocating the
    # NUMA tensor; the previous value is restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    # The MoE path is only implemented for dual-die devices.
    assert die_num == 2
    if layout == "colmajor":
        wb = wb or tensor.shape[0]
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[2]
        if parallel_type == "col_parallel":
            # N halved per die, split across SPCs; experts stacked on dim 0.
            wn_block = (wn // die_num + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            else:
                weight = tensor.cpu().contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            weight = torch.nn.functional.pad(
                weight, (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0, 0,
                         0),
                mode='constant',
                value=0)
            weight = weight.reshape(wb, wk, die_spc_num, wn_block).permute(
                2, 0, 1, 3).reshape(wb * die_spc_num, wk,
                                    wn_block).contiguous()
            numa_tensor.copy_(weight)
        elif parallel_type == "row_parallel":
            # K halved per die; full N split across SPCs per die.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk // die_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous()
            else:
                weight = tensor.cpu().contiguous()
            weight = torch.nn.functional.pad(
                weight, (0, spc_num * wn_block - wn, 0, 0, 0, 0),
                mode='constant',
                value=0)
            weight = weight.reshape(wb, die_num, wk // die_num, spc_num,
                                    wn_block).permute(
                                        1, 3, 0, 2, 4).contiguous().reshape(
                                            die_spc_num * wb, wk // die_num,
                                            wn_block)
            numa_tensor.copy_(weight)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    # Restore the caller's force-UMA debug setting.
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor, (die_spc_num, wk, wn_block)


def is_br166_device() -> bool:
    """Return True when the current SUPA device is dual-die ("br166").

    A device reporting 17–32 compute units (SPCs) is treated as br166.
    """
    spc_num = torch_br.supa.get_device_properties(
        torch.device("supa")).max_compute_units
    return bool(spc_num > 16 and spc_num <= 32)