################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import torch_br
import torch_br.supa._debug as supa_debug
from vllm_br import envs
def align_n(n, align_size, spc_num=None) -> int:
    """Split ``n`` across ``spc_num`` partitions and round the per-partition
    block size up to a multiple of ``align_size``.

    Args:
        n: Total element count to distribute.
        align_size: Alignment unit for the per-partition block size.
        spc_num: Number of SPC partitions; when None, the current value of
            ``envs.VLLM_BR_DEVICE_SPC_NUM`` is used.

    Returns:
        The aligned per-partition block size.
    """
    # NOTE(review): the original signature evaluated
    # envs.VLLM_BR_DEVICE_SPC_NUM once at module-import time (mutable-ish
    # default pitfall); resolving it lazily here picks up the live value
    # while keeping the call signature backward-compatible.
    if spc_num is None:
        spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    n_block = (n + spc_num - 1) // spc_num  # ceil-divide across partitions
    n_block = (n_block + align_size - 1) // align_size * align_size  # round up
    return n_block
def _br_qweight_cvt(quant_method,
qweight,
qzeros,
size_k,
size_n,
override_group_size=None):
group_size = override_group_size or quant_method.quant_config.group_size
curr_dev = qweight.device
group_num = size_k // group_size if group_size > 0 else 1
qweight = qweight.cpu().view(torch.int8).reshape(
size_k // 4, size_n,
4).permute(0, 2, 1).contiguous().reshape(group_num,
size_k // group_num, size_n)
if qzeros is not None and not torch.all(qzeros == 0):
qzeros = qzeros.cpu().view(torch.int8).to(torch.int32) + 1
qweight = (qweight.to(torch.int32) - qzeros.unsqueeze(1)).to(
torch.int8)
qwei_int8 = qweight.reshape(size_k, size_n).to(curr_dev)
return qwei_int8
def _numa_scales_cvt(scales, wn, spc_num):
align_size = 32
wn_block = (wn + spc_num - 1) // spc_num
wn_block = (wn_block + align_size - 1) // align_size * align_size
cvt_scales = torch.nn.functional.pad(scales, (0, spc_num * wn_block - wn),
mode='constant',
value=0)
cvt_scales = cvt_scales.reshape(spc_num, wn_block).contiguous()
return cvt_scales
def cross_weight_32(t1, t2, spc_num, dim=1, need_pad=True):
    """Interleave two weight tensors in 32-column chunks along ``dim``.

    Both inputs are first zero-padded so their width (each die's half of it
    on dual-die parts, i.e. spc_num > 16) is a multiple of 32; their 32-wide
    chunks are then woven together (t1[0:32], t2[0:32], t1[32:64], ...).
    With ``need_pad`` the woven result is additionally right-padded up to
    the SPC alignment.  Padding always applies to the LAST axis, so ``dim``
    is expected to be the last dimension of the inputs.
    """

    def _ceil_to(value, unit):
        # Round `value` up to the next multiple of `unit`.
        return (value + unit - 1) // unit * unit

    def _pad_last(t, amount):
        # Zero-pad `amount` columns onto the last axis.
        return torch.nn.functional.pad(t, (0, amount), "constant", 0)

    width = t1.shape[dim]
    if spc_num > 16:
        # NOTE: br166 must ensure dual-dies width are 32-aligned
        assert width % 2 == 0
        half = width // 2
        half_aligned = _ceil_to(half, 32)
        extra = half_aligned - half
        if extra > 0:
            # Pad each die's half independently so both stay 32-aligned.
            a0, a1 = torch.chunk(t1, 2, dim=-1)
            t1 = torch.cat([_pad_last(a0, extra), _pad_last(a1, extra)],
                           dim=-1)
            b0, b1 = torch.chunk(t2, 2, dim=-1)
            t2 = torch.cat([_pad_last(b0, extra), _pad_last(b1, extra)],
                           dim=-1)
        width = half_aligned * 2
    else:
        aligned = _ceil_to(width, 32)
        t1 = _pad_last(t1, aligned - width)
        t2 = _pad_last(t2, aligned - width)
        width = aligned
    chunk_cnt = width // 32
    pieces_a = torch.chunk(t1, chunk_cnt, dim)
    pieces_b = torch.chunk(t2, chunk_cnt, dim)
    woven = []
    for piece_a, piece_b in zip(pieces_a, pieces_b):
        woven.append(piece_a)
        woven.append(piece_b)
    no_pad = torch.cat(woven, dim=dim)
    if not need_pad:
        return no_pad
    if spc_num > 16:
        # Dual-die: pad each die's half of the woven tensor separately.
        align = (spc_num // 2) * 32 * 2
        pad_size = _ceil_to(width, align) - width
        lo, hi = torch.chunk(no_pad, 2, dim=-1)
        out = torch.cat([_pad_last(lo, pad_size), _pad_last(hi, pad_size)],
                        dim=-1)
    else:
        align = spc_num * 32 * 2  # 768
        doubled = width * 2
        out = _pad_last(no_pad, _ceil_to(doubled, align) - doubled)
    return out
# # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
def _convert_to_uma_tensor(tensor,
                           align_size,
                           layout,
                           dtype,
                           do_transpose=False,
                           wk=None,
                           wn=None,
                           parallel_type="col_parallel"):
    """Allocate a UMA (non-NUMA) device tensor and copy ``tensor`` into it.

    Supports two layouts:
      * "colmajor": 2-D weight, optionally transposed before the copy.
      * "linear_bias": bias/scales tensor of rank 1-3, reshaped to (wn,)
        or (wk, wn) depending on its rank.

    Args:
        tensor: Source tensor (host or device).
        align_size: Accepted for signature parity with the NUMA converters;
            not used in this function.
        layout: "colmajor" or "linear_bias" (case-insensitive).
        dtype: Target dtype of the UMA tensor.
        do_transpose: colmajor only — transpose the 2-D weight first.
        wk, wn: Optional extent overrides; defaults come from tensor.shape.
        parallel_type: Chooses the sharding annotation (axis/sbp) passed to
            torch_br._empty_ut_only.

    Returns:
        The freshly allocated UMA tensor with the data copied in.

    Raises:
        ValueError: For any other ``layout``.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    layout = layout.lower()
    if layout == "colmajor":
        # NOTE(review): defaults read wk from shape[1] and wn from shape[0]
        # — the opposite of _convert_to_numa_tensor; presumably colmajor
        # storage here is (wn, wk). Confirm against callers.
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[0]
        d_shape = (wn, wk)
        if do_transpose:
            data = tensor.cpu().permute(1, 0).contiguous()
            d_shape = (wk, wn)
        else:
            data = tensor.cpu().contiguous()
        if parallel_type == "col_parallel":
            # Shard along axis 0 (sbp="SB" — semantics defined by torch_br).
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                sbp="SB",
                axis=0)
        else:
            # row_parallel: shard along axis 1 instead.
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=1,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    elif layout == "linear_bias":
        axis = 0
        wn = wn or tensor.shape[-1]
        wk = 1
        data = tensor
        # Normalize the bias to (wn,) for rank-1-like inputs, or (wk, wn)
        # when a genuine leading dimension is present.
        if len(data.shape) == 2 and data.shape[0] == 1:
            # (1, wn) -> flatten to (wn,).
            data = tensor.cpu().reshape(-1).contiguous()
        elif len(data.shape) == 2:
            # (wk, wn): keep 2-D, shard along axis 1.
            axis = 1
            wk = data.shape[0]
        elif len(data.shape) == 3 and data.shape[1] == 1:
            # (wk, 1, wn) -> squeeze the middle dim to (wk, wn).
            data = tensor.cpu().reshape(
                (data.shape[0], data.shape[2])).contiguous()
            axis = 1
            wk = data.shape[0]
        d_shape = (wn, ) if axis == 0 else (wk, wn)
        if parallel_type == "row_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout)
        elif parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=axis,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    else:
        raise ValueError("uma tensor only support colmajor and linear_bias")
    return uma_tensor
def _convert_to_numa_tensor_vit(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Convert ``tensor`` into the per-SPC NUMA layout used by the ViT path.

    "colmajor" weights of shape (wk, wn) are split along N into
    ``align_size``-aligned, zero-padded slices — one per SPC, and
    distributed across both dies on dual-die parts (spc_num > 16).
    "linear_bias" tensors are likewise padded per SPC and stored in NUMA
    form (unlike ``_convert_to_numa_tensor``, which keeps biases in UMA
    form).

    Args:
        tensor: Source tensor (host or device).
        align_size: Alignment unit for the per-SPC block width.
        layout: "colmajor" or "linear_bias" (case-insensitive).
        dtype: Target dtype of the NUMA tensor.
        do_transpose: colmajor only — transpose the 2-D weight first.
        wk, wn: Optional extent overrides; defaults come from tensor.shape.
        parallel_type: "col_parallel" splits N across dies;
            "row_parallel" splits K across dies.
        pad_zeros: linear_bias + row_parallel + dual-die only — give the
            second die zeros instead of a replicated bias.

    Returns:
        The freshly allocated NUMA tensor with the converted data.

    Raises:
        ValueError: For any other ``layout``.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Temporarily disable force-UMA so the NUMA allocations below are
    # honored; the previous setting is restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    # More than 16 SPCs means a dual-die part: SPCs split across 2 dies.
    if spc_num > 16:
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Per-SPC slice width along N, rounded up to align_size.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            # Zero-pad N to spc_num * wn_block, then move the SPC axis first.
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Split N across the two dies first, then across SPCs.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                # Fold (die, spc) into the leading NUMA axis.
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # row_parallel: keep full N per SPC, split K across dies.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        # Leading dim (when 2-D) is treated as a group count — presumably
        # quantization groups for scales; confirm against callers.
        group_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        if die_num == 1:
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num * group_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            data = torch.nn.functional.pad(tensor.cpu(),
                                           (0, spc_num * wn_block - wn),
                                           mode='constant',
                                           value=0)
            if group_num > 1:
                # Interleave so the SPC axis leads: (group, spc, block) ->
                # (spc * group, block).
                data = data.type(dtype).reshape(
                    group_num, spc_num,
                    wn_block).permute(1, 0, 2).contiguous().reshape(
                        spc_num * group_num, wn_block)
            else:
                data = data.type(dtype).reshape(spc_num, wn_block).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Split wn across dies, then across SPCs within a die.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = tensor.cpu().reshape(die_num, wn // die_num)
                bias = torch.nn.functional.pad(
                    bias, (0, spc_num * wn_block - wn // die_num, 0, 0),
                    mode='constant',
                    value=0)
                # NOTE(review): bias is stored as float32 regardless of
                # `dtype` — presumably the kernel accumulates in fp32;
                # confirm.
                bias = bias.type(torch.float32).reshape(die_spc_num,
                                                        wn_block).contiguous()
                numa_tensor.copy_(bias)
            else:
                # row_parallel: pad full wn per SPC.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                bias = bias.type(torch.float32).reshape(spc_num,
                                                        wn_block).contiguous()
                if pad_zeros:
                    # Second die gets zeros instead of a replicated bias.
                    bias_zeros_die2 = torch.zeros((spc_num, wn_block),
                                                  dtype=bias.dtype)
                    bias = torch.concat([bias, bias_zeros_die2], dim=0)
                else:
                    # Replicate the same bias on both dies.
                    bias = torch.concat([bias, bias], dim=0)
                numa_tensor.copy_(bias)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
def _convert_to_numa_tensor(tensor,
                            align_size,
                            layout,
                            dtype,
                            do_transpose=False,
                            wk=None,
                            wn=None,
                            parallel_type="col_parallel",
                            pad_zeros=False):
    """Convert ``tensor`` into the per-SPC NUMA layout for linear weights.

    "colmajor" weights of shape (wk, wn) are split along N into
    ``align_size``-aligned, zero-padded slices — one per SPC, distributed
    across both dies on dual-die parts (spc_num > 16).  "linear_bias"
    tensors (bias/scales, optionally with a leading expert dimension) are
    kept in UMA form (is_numa=False) so fused-linear kernels can consume
    them.

    Args:
        tensor: Source tensor (host or device).
        align_size: Alignment unit for the per-SPC block width.
        layout: "colmajor" or "linear_bias" (case-insensitive).
        dtype: Target dtype of the converted tensor.
        do_transpose: colmajor only — transpose the 2-D weight first.
        wk, wn: Optional extent overrides; defaults come from tensor.shape.
        parallel_type: "col_parallel" splits N across dies;
            "row_parallel" splits K across dies.
        pad_zeros: Unused in this variant (kept for signature parity with
            ``_convert_to_numa_tensor_vit``).

    Returns:
        The freshly allocated device tensor with the converted data.

    Raises:
        ValueError: For any other ``layout``.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Temporarily disable force-UMA so the NUMA allocations below are
    # honored; the previous setting is restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    # More than 16 SPCs means a dual-die part: SPCs split across 2 dies.
    if spc_num > 16:
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Per-SPC slice width along N, rounded up to align_size.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            # Zero-pad N to spc_num * wn_block, then move the SPC axis first.
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Split N across the two dies first, then across SPCs.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                # Fold (die, spc) into the leading NUMA axis.
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # row_parallel: keep full N per SPC, split K across dies.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        # Leading dim (when present) is the expert count for MoE biases.
        expert_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        bias_shape = (expert_num, wn) if expert_num > 1 else (wn, )
        if die_num == 1:
            numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                  dtype=dtype,
                                                  is_numa=False,
                                                  device=tensor.device,
                                                  tensor_type=layout)
            data = tensor.cpu().type(dtype)
            if expert_num > 1:
                data = data.reshape(expert_num, wn)
            else:
                data = data.reshape(wn).type(dtype)
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(tensor.device))
        else:
            if parallel_type == "col_parallel":
                # Shard along wn (axis 1 when an expert dim is present).
                axis = 1 if expert_num > 1 else 0
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type="buffer_any",
                                                      axis=axis,
                                                      sbp="SB")
                if expert_num == 1:
                    tensor = tensor.reshape(-1)
                numa_tensor.copy_(tensor.to(torch.supa.current_device()))
            else:
                # row_parallel: sbp="BB" — presumably replicated on both
                # dies; semantics defined by torch_br. Confirm.
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="linear_bias",
                    sbp="BB")
                bias = tensor.reshape(expert_num, wn).cpu().type(dtype)
                if expert_num == 1:
                    bias = bias.reshape(-1)
                numa_tensor.copy_(bias.to(torch.supa.current_device()))
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
def _convert_to_crossed_numa_tensor(t1,
                                    t2,
                                    spc_num,
                                    dim=1,
                                    need_pad=True,
                                    layout="colmajor",
                                    do_transpose=False):
    """Equals to V0: cross_weight_32 + numa_weight_convert/_numa_weight_cvt
    """
    # First weave the two tensors together in 32-column chunks, then lay
    # the combined weight out in the per-SPC NUMA format.
    crossed = cross_weight_32(t1, t2, spc_num, dim, need_pad)
    return _convert_to_numa_tensor(crossed, 32, layout, crossed.dtype,
                                   do_transpose)
def _convert_to_numa_tensor_moe(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wb=None,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Convert a batched per-expert MoE weight (wb, wk, wn) to NUMA layout.

    Each expert's (wk, wn) weight is split per SPC like
    ``_convert_to_numa_tensor``, with the expert dimension folded into the
    leading NUMA axis.  Only dual-die parts are supported (spc_num > 16 is
    asserted).

    Args:
        tensor: Source weight of shape (wb, wk, wn) — one slice per expert.
        align_size: Alignment unit for the per-SPC block width.
        layout: Must be "colmajor" (case-insensitive).
        dtype: Target dtype of the NUMA tensor.
        do_transpose: Transpose each expert's (wk, wn) before splitting.
        wb, wk, wn: Optional extent overrides; defaults from tensor.shape.
        parallel_type: "col_parallel" splits N across dies;
            "row_parallel" splits K across dies.
        pad_zeros: Unused in this variant (kept for signature parity).

    Returns:
        Tuple (numa_tensor, (die_spc_num, wk, wn_block)); the triple
        describes the per-expert NUMA block shape.

    Raises:
        ValueError: For any layout other than "colmajor".
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Temporarily disable force-UMA so the NUMA allocations below are
    # honored; the previous setting is restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    # MoE conversion is only implemented for dual-die devices.
    assert die_num == 2
    if layout == "colmajor":
        wb = wb or tensor.shape[0]
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[2]
        if parallel_type == "col_parallel":
            # Split N across dies first, then across SPCs within a die.
            wn_block = (wn // die_num + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            else:
                weight = tensor.cpu().contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            weight = torch.nn.functional.pad(
                weight,
                (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0, 0, 0),
                mode='constant',
                value=0)
            # Fold (die*spc, expert) into the leading NUMA axis.
            weight = weight.reshape(wb, wk, die_spc_num, wn_block).permute(
                2, 0, 1, 3).reshape(wb * die_spc_num, wk,
                                    wn_block).contiguous()
            numa_tensor.copy_(weight)
        elif parallel_type == "row_parallel":
            # Keep full N per SPC, split K across dies.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk // die_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous()
            else:
                weight = tensor.cpu().contiguous()
            weight = torch.nn.functional.pad(
                weight, (0, spc_num * wn_block - wn, 0, 0, 0, 0),
                mode='constant',
                value=0)
            weight = weight.reshape(wb, die_num, wk // die_num, spc_num,
                                    wn_block).permute(1, 3, 0, 2,
                                                      4).contiguous().reshape(
                                                          die_spc_num * wb,
                                                          wk // die_num,
                                                          wn_block)
            numa_tensor.copy_(weight)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor, (die_spc_num, wk, wn_block)
def is_br166_device():
    """Return True when the current supa device is a br166 part.

    The check is based on the reported compute-unit count: a value in
    (16, 32] is treated as br166 (dual-die).
    """
    units = torch_br.supa.get_device_properties(
        torch.device("supa")).max_compute_units
    return 16 < units <= 32