620 lines
26 KiB
Python
620 lines
26 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
|
|
import torch
|
|
import torch_br
|
|
import torch_br.supa._debug as supa_debug
|
|
|
|
from vllm_br import envs
|
|
|
|
|
|
def align_n(n, align_size, spc_num=None) -> int:
    """Split ``n`` across ``spc_num`` partitions and round the per-partition
    size up to a multiple of ``align_size``.

    Args:
        n: total element count to distribute.
        align_size: alignment granularity for the per-partition block.
        spc_num: number of partitions; defaults to the current value of
            ``envs.VLLM_BR_DEVICE_SPC_NUM``.

    Returns:
        The aligned per-partition block size.
    """
    # Resolve the default at call time rather than at import time, so a
    # change to VLLM_BR_DEVICE_SPC_NUM after this module is imported is
    # honored (the old default froze the value at import).
    if spc_num is None:
        spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    n_block = (n + spc_num - 1) // spc_num  # ceil(n / spc_num)
    # Round up to the next multiple of align_size.
    n_block = (n_block + align_size - 1) // align_size * align_size
    return n_block
|
|
|
|
|
|
def _br_qweight_cvt(quant_method,
|
|
qweight,
|
|
qzeros,
|
|
size_k,
|
|
size_n,
|
|
override_group_size=None):
|
|
group_size = override_group_size or quant_method.quant_config.group_size
|
|
curr_dev = qweight.device
|
|
group_num = size_k // group_size if group_size > 0 else 1
|
|
qweight = qweight.cpu().view(torch.int8).reshape(
|
|
size_k // 4, size_n,
|
|
4).permute(0, 2, 1).contiguous().reshape(group_num,
|
|
size_k // group_num, size_n)
|
|
if qzeros is not None and not torch.all(qzeros == 0):
|
|
qzeros = qzeros.cpu().view(torch.int8).to(torch.int32) + 1
|
|
qweight = (qweight.to(torch.int32) - qzeros.unsqueeze(1)).to(
|
|
torch.int8)
|
|
qwei_int8 = qweight.reshape(size_k, size_n).to(curr_dev)
|
|
return qwei_int8
|
|
|
|
|
|
def _numa_scales_cvt(scales, wn, spc_num):
|
|
align_size = 32
|
|
wn_block = (wn + spc_num - 1) // spc_num
|
|
wn_block = (wn_block + align_size - 1) // align_size * align_size
|
|
cvt_scales = torch.nn.functional.pad(scales, (0, spc_num * wn_block - wn),
|
|
mode='constant',
|
|
value=0)
|
|
cvt_scales = cvt_scales.reshape(spc_num, wn_block).contiguous()
|
|
return cvt_scales
|
|
|
|
|
|
def cross_weight_32(t1, t2, spc_num, dim=1, need_pad=True):
    """Interleave ``t1`` and ``t2`` along ``dim`` in 32-column groups
    ([t1[0:32], t2[0:32], t1[32:64], t2[32:64], ...]), padding widths so
    the groups divide evenly, and optionally tail-pad the result to the
    device's SPC alignment.

    NOTE(review): the chunking honors ``dim`` but all padding targets the
    last dimension; this assumes ``dim`` is the last axis — confirm callers.
    """
    pad_tail = torch.nn.functional.pad
    width = t1.shape[dim]

    if spc_num > 16:
        # NOTE: br166 must ensure dual-dies width are 32-aligned, so each
        # half (one per die) is aligned to 32 independently.
        assert width % 2 == 0
        half = width // 2
        half_aligned = (half + 32 - 1) // 32 * 32
        extra = half_aligned - half
        if extra > 0:
            fixed = []
            for src in (t1, t2):
                lo, hi = torch.chunk(src, 2, dim=-1)
                lo = pad_tail(lo, (0, extra), "constant", 0)
                hi = pad_tail(hi, (0, extra), "constant", 0)
                fixed.append(torch.cat([lo, hi], dim=-1))
            t1, t2 = fixed
            width = half_aligned * 2
    else:
        # Single die: align the full width to 32.
        aligned = (width + 32 - 1) // 32 * 32
        t1 = pad_tail(t1, (0, aligned - width), "constant", 0)
        t2 = pad_tail(t2, (0, aligned - width), "constant", 0)
        width = aligned

    # Interleave the two sources in 32-wide slices.
    groups = width // 32
    a_parts = torch.chunk(t1, groups, dim)
    b_parts = torch.chunk(t2, groups, dim)
    woven = [p for pair in zip(a_parts, b_parts) for p in pair]
    no_pad = torch.cat(woven, dim=dim)
    if not need_pad:
        return no_pad

    if spc_num > 16:
        # Dual die: pad each half of the interleaved result separately.
        align = (spc_num // 2) * 32 * 2
        target = (width + align - 1) // align * align
        tail = target - width
        lo, hi = torch.chunk(no_pad, 2, dim=-1)
        lo = pad_tail(lo, (0, tail), "constant", 0)
        hi = pad_tail(hi, (0, tail), "constant", 0)
        return torch.cat([lo, hi], dim=-1)

    # Single die: pad the combined (doubled) width to spc_num*32*2.
    align = spc_num * 32 * 2  # 768
    target = (width * 2 + align - 1) // align * align
    return pad_tail(no_pad, (0, target - width * 2), "constant", 0)
|
|
|
|
|
|
# NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
def _convert_to_uma_tensor(tensor,
                           align_size,
                           layout,
                           dtype,
                           do_transpose=False,
                           wk=None,
                           wn=None,
                           parallel_type="col_parallel"):
    """Copy ``tensor`` into a freshly allocated UMA (non-NUMA) device
    tensor via ``torch_br._empty_ut_only``.

    Args:
        tensor: source tensor (staged through CPU before the device copy).
        align_size: unused here; kept for signature parity with the NUMA
            converters in this file.
        layout: "colmajor" for weights or "linear_bias" for bias/scale
            vectors (case-insensitive).
        dtype: dtype of the allocated device tensor.
        do_transpose: colmajor only — transpose the 2-D source first.
        wk, wn: logical dims; inferred from ``tensor.shape`` when None.
        parallel_type: "col_parallel" or "row_parallel"; selects the
            sharding axis/SBP passed to the allocator.

    Returns:
        The allocated UMA tensor with the data copied in.

    Raises:
        ValueError: for any layout other than colmajor / linear_bias.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    layout = layout.lower()
    if layout == "colmajor":
        # NOTE(review): here wk is read from shape[1] and wn from shape[0],
        # the opposite of the NUMA converters below — presumably the source
        # is stored (wn, wk); confirm against callers.
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[0]
        d_shape = (wn, wk)
        if do_transpose:
            data = tensor.cpu().permute(1, 0).contiguous()
            d_shape = (wk, wn)
        else:
            data = tensor.cpu().contiguous()
        if parallel_type == "col_parallel":
            # Column-parallel weights shard along axis 0 of (wn, wk).
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                sbp="SB",
                axis=0)
        else:
            # Row-parallel weights shard along axis 1.
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=1,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    elif layout == "linear_bias":
        # Normalize the bias/scale source to 1-D (axis=0) or 2-D (axis=1):
        #   (1, wn)    -> flattened to (wn,)
        #   (wk, wn)   -> kept 2-D, sharded on axis 1
        #   (wk, 1, wn)-> squeezed to (wk, wn), sharded on axis 1
        axis = 0
        wn = wn or tensor.shape[-1]
        wk = 1
        data = tensor
        if len(data.shape) == 2 and data.shape[0] == 1:
            data = tensor.cpu().reshape(-1).contiguous()
        elif len(data.shape) == 2:
            axis = 1
            wk = data.shape[0]
        elif len(data.shape) == 3 and data.shape[1] == 1:
            data = tensor.cpu().reshape(
                (data.shape[0], data.shape[2])).contiguous()
            axis = 1
            wk = data.shape[0]
        d_shape = (wn, ) if axis == 0 else (wk, wn)
        if parallel_type == "row_parallel":
            # Row-parallel bias is replicated (no axis/sbp arguments).
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout)
        elif parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=axis,
                sbp="SB")

        torch.supa.synchronize()
        # NOTE(review): in the (1-D, axis=0) fast path ``data`` may still be
        # the original device tensor (no .cpu() taken); the .to() below makes
        # the copy valid either way.
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    else:
        raise ValueError("uma tensor only support colmajor and linear_bias")
    return uma_tensor
|
|
|
|
|
|
def _convert_to_numa_tensor_vit(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """ViT variant: lay ``tensor`` out as a NUMA device tensor, splitting
    the wn dimension across SPCs (and across dies on dual-die parts).

    Differs from ``_convert_to_numa_tensor`` in that linear_bias data is
    also converted to the per-SPC NUMA layout (not kept as UMA), and the
    dual-die allocations here pass no axis/sbp arguments.

    Args:
        tensor: source weight/bias tensor.
        align_size: per-SPC block alignment (callers in this file use 32).
        layout: "colmajor" or "linear_bias" (case-insensitive).
        dtype: dtype of the allocated device tensor.
        do_transpose: colmajor only — transpose the 2-D source first.
        wk, wn: logical dims; inferred from ``tensor.shape`` when None.
        parallel_type: "col_parallel" or "row_parallel".
        pad_zeros: dual-die row-parallel bias only — fill die 1's rows
            with zeros instead of duplicating the bias.

    Returns:
        The allocated NUMA tensor with the rearranged data copied in.

    Raises:
        ValueError: for any other layout.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    # Temporarily disable force-UMA so the allocator really produces a
    # NUMA tensor; restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: spc_num counts both dies, so halve it per die.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num

    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Single die: split wn into spc_num aligned blocks, one per SPC.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            # Zero-pad wn up to spc_num * wn_block, then make each SPC's
            # slab contiguous: (wk, wn) -> (spc_num, wk, wn_block).
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Dual die, column-parallel: wn is split per die first,
                # then per SPC within each die.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Dual die, row-parallel: wk is split across dies; wn is
                # split across SPCs (full spc_num, not per-die).
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)

    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        # A 2-D source is treated as (group_num, wn) grouped scales.
        group_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        if die_num == 1:
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num * group_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            data = torch.nn.functional.pad(tensor.cpu(),
                                           (0, spc_num * wn_block - wn),
                                           mode='constant',
                                           value=0)
            if group_num > 1:
                # Interleave so each SPC's rows for all groups are adjacent.
                data = data.type(dtype).reshape(
                    group_num, spc_num,
                    wn_block).permute(1, 0, 2).contiguous().reshape(
                        spc_num * group_num, wn_block)
            else:
                data = data.type(dtype).reshape(spc_num, wn_block).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Dual die, column-parallel bias: wn split per die, then
                # per SPC; stored as float32.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = tensor.cpu().reshape(die_num, wn // die_num)
                bias = torch.nn.functional.pad(
                    bias, (0, spc_num * wn_block - wn // die_num, 0, 0),
                    mode='constant',
                    value=0)
                bias = bias.type(torch.float32).reshape(die_spc_num,
                                                        wn_block).contiguous()
                numa_tensor.copy_(bias)
            else:
                # Dual die, row-parallel bias: the full bias is laid out
                # once per SPC, then either duplicated on die 1 or (with
                # pad_zeros) zero-filled there so the dies can be summed.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                bias = bias.type(torch.float32).reshape(spc_num,
                                                        wn_block).contiguous()
                if pad_zeros:
                    bias_zeros_die2 = torch.zeros((spc_num, wn_block),
                                                  dtype=bias.dtype)
                    bias = torch.concat([bias, bias_zeros_die2], dim=0)
                else:
                    bias = torch.concat([bias, bias], dim=0)
                numa_tensor.copy_(bias)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
|
|
|
|
|
|
def _convert_to_numa_tensor(tensor,
                            align_size,
                            layout,
                            dtype,
                            do_transpose=False,
                            wk=None,
                            wn=None,
                            parallel_type="col_parallel",
                            pad_zeros=False):
    """Lay a weight ``tensor`` out as a NUMA device tensor, splitting the
    wn dimension across SPCs (and dies on dual-die parts).

    Unlike ``_convert_to_numa_tensor_vit``, linear_bias data here stays in
    its original (UMA, is_numa=False) layout; only colmajor weights get the
    per-SPC rearrangement. The dual-die colmajor allocations also pass
    axis=0 / sbp="SS" to the allocator.

    Args:
        tensor: source weight/bias tensor.
        align_size: per-SPC block alignment (callers in this file use 32).
        layout: "colmajor" or "linear_bias" (case-insensitive).
        dtype: dtype of the allocated device tensor.
        do_transpose: colmajor only — transpose the 2-D source first.
        wk, wn: logical dims; inferred from ``tensor.shape`` when None.
        parallel_type: "col_parallel" or "row_parallel".
        pad_zeros: unused here; kept for signature parity with the ViT
            variant.

    Returns:
        The allocated device tensor with the data copied in.

    Raises:
        ValueError: for any other layout.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    # Temporarily disable force-UMA so the allocator really produces a
    # NUMA tensor; restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: spc_num counts both dies, so halve it per die.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num

    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Single die: split wn into spc_num aligned blocks, one per SPC.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            # Zero-pad wn up to spc_num * wn_block, then make each SPC's
            # slab contiguous: (wk, wn) -> (spc_num, wk, wn_block).
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Dual die, column-parallel: wn is split per die first,
                # then per SPC within each die.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Dual die, row-parallel: wk is split across dies; wn is
                # split across SPCs (full spc_num, not per-die).
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)

    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        # A 2-D source is treated as per-expert bias: (expert_num, wn).
        expert_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        bias_shape = (expert_num, wn) if expert_num > 1 else (wn, )
        if die_num == 1:
            # Single die: keep the bias as a plain UMA tensor.
            numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                  dtype=dtype,
                                                  is_numa=False,
                                                  device=tensor.device,
                                                  tensor_type=layout)
            data = tensor.cpu().type(dtype)
            if expert_num > 1:
                data = data.reshape(expert_num, wn)
            else:
                data = data.reshape(wn).type(dtype)
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(tensor.device))

        else:
            if parallel_type == "col_parallel":
                # Dual die, column-parallel: shard the bias along wn
                # (axis 1 for per-expert bias, axis 0 for a flat vector).
                axis = 1 if expert_num > 1 else 0
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type="buffer_any",
                                                      axis=axis,
                                                      sbp="SB")
                if expert_num == 1:
                    tensor = tensor.reshape(-1)
                numa_tensor.copy_(tensor.to(torch.supa.current_device()))
            else:
                # Dual die, row-parallel: broadcast ("BB") bias.
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="linear_bias",
                    sbp="BB")
                bias = tensor.reshape(expert_num, wn).cpu().type(dtype)
                if expert_num == 1:
                    bias = bias.reshape(-1)
                numa_tensor.copy_(bias.to(torch.supa.current_device()))

    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
|
|
|
|
|
|
def _convert_to_crossed_numa_tensor(t1,
                                    t2,
                                    spc_num,
                                    dim=1,
                                    need_pad=True,
                                    layout="colmajor",
                                    do_transpose=False):
    """Interleave two weights in 32-wide groups, then lay the result out
    as a NUMA tensor with 32-element alignment.

    Equals to V0: cross_weight_32 + numa_weight_convert/_numa_weight_cvt
    """
    crossed = cross_weight_32(t1, t2, spc_num, dim, need_pad)
    return _convert_to_numa_tensor(crossed, 32, layout, crossed.dtype,
                                   do_transpose)
|
|
|
|
|
|
def _convert_to_numa_tensor_moe(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wb=None,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """MoE variant: lay a batched (expert-stacked) weight tensor out as a
    NUMA device tensor, splitting wn across SPCs on a dual-die part.

    Args:
        tensor: 3-D source of shape (wb, wk, wn) — expert batch first.
        align_size: per-SPC block alignment.
        layout: only "colmajor" is supported (case-insensitive).
        dtype: dtype of the allocated device tensor.
        do_transpose: transpose each expert's (wk, wn) slice first.
        wb, wk, wn: logical dims; inferred from ``tensor.shape`` when None.
        parallel_type: "col_parallel" or "row_parallel".
        pad_zeros: unused here; kept for signature parity with the other
            converters.

    Returns:
        Tuple of (numa_tensor, (die_spc_num, wk, wn_block)) — the second
        element describes the per-expert device layout. Note wk here is
        the full (undivided) wk even on the row-parallel path.

    Raises:
        AssertionError: on single-die parts (spc_num <= 16) — this helper
            only supports the dual-die layout.
        ValueError: for any layout other than colmajor.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    # Temporarily disable force-UMA so the allocator really produces a
    # NUMA tensor; restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: spc_num counts both dies, so halve it per die.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    # MoE conversion is only implemented for the dual-die layout.
    assert die_num == 2
    if layout == "colmajor":
        wb = wb or tensor.shape[0]
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[2]
        if parallel_type == "col_parallel":
            # wn split per die first, then per SPC within each die.
            wn_block = (wn // die_num + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            else:
                weight = tensor.cpu().contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            weight = torch.nn.functional.pad(
                weight,
                (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0, 0, 0),
                mode='constant',
                value=0)
            # SPC-major ordering: all experts' blocks for one SPC adjacent.
            weight = weight.reshape(wb, wk, die_spc_num, wn_block).permute(
                2, 0, 1, 3).reshape(wb * die_spc_num, wk,
                                    wn_block).contiguous()
            numa_tensor.copy_(weight)
        elif parallel_type == "row_parallel":
            # wk split across dies; wn split across SPCs (full spc_num).
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk // die_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous()
            else:
                weight = tensor.cpu().contiguous()
            weight = torch.nn.functional.pad(
                weight, (0, spc_num * wn_block - wn, 0, 0, 0, 0),
                mode='constant',
                value=0)
            weight = weight.reshape(wb, die_num, wk // die_num, spc_num,
                                    wn_block).permute(1, 3, 0, 2,
                                                      4).contiguous().reshape(
                                                          die_spc_num * wb,
                                                          wk // die_num,
                                                          wn_block)
            numa_tensor.copy_(weight)

    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor, (die_spc_num, wk, wn_block)
|
|
|
|
|
|
def is_br166_device():
    """Return True when the current SUPA device looks like a br166
    (dual-die) part, i.e. 16 < max_compute_units <= 32."""
    props = torch_br.supa.get_device_properties(torch.device("supa"))
    return 16 < props.max_compute_units <= 32
|