first commit
This commit is contained in:
619
vllm_br/model_executor/layers/br_utils.py
Normal file
619
vllm_br/model_executor/layers/br_utils.py
Normal file
@@ -0,0 +1,619 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
import torch_br.supa._debug as supa_debug
|
||||
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
def align_n(n, align_size, spc_num=envs.VLLM_BR_DEVICE_SPC_NUM) -> int:
    """Return the per-SPC block size for splitting ``n`` across ``spc_num``
    partitions, rounded up to a multiple of ``align_size``.

    NOTE(review): the default ``spc_num`` is captured from ``envs`` at import
    time — later changes to the env value do not affect the default.
    """
    per_spc = -(-n // spc_num)          # ceil(n / spc_num)
    return -(-per_spc // align_size) * align_size   # round up to alignment
|
||||
|
||||
|
||||
def _br_qweight_cvt(quant_method,
|
||||
qweight,
|
||||
qzeros,
|
||||
size_k,
|
||||
size_n,
|
||||
override_group_size=None):
|
||||
group_size = override_group_size or quant_method.quant_config.group_size
|
||||
curr_dev = qweight.device
|
||||
group_num = size_k // group_size if group_size > 0 else 1
|
||||
qweight = qweight.cpu().view(torch.int8).reshape(
|
||||
size_k // 4, size_n,
|
||||
4).permute(0, 2, 1).contiguous().reshape(group_num,
|
||||
size_k // group_num, size_n)
|
||||
if qzeros is not None and not torch.all(qzeros == 0):
|
||||
qzeros = qzeros.cpu().view(torch.int8).to(torch.int32) + 1
|
||||
qweight = (qweight.to(torch.int32) - qzeros.unsqueeze(1)).to(
|
||||
torch.int8)
|
||||
qwei_int8 = qweight.reshape(size_k, size_n).to(curr_dev)
|
||||
return qwei_int8
|
||||
|
||||
|
||||
def _numa_scales_cvt(scales, wn, spc_num):
|
||||
align_size = 32
|
||||
wn_block = (wn + spc_num - 1) // spc_num
|
||||
wn_block = (wn_block + align_size - 1) // align_size * align_size
|
||||
cvt_scales = torch.nn.functional.pad(scales, (0, spc_num * wn_block - wn),
|
||||
mode='constant',
|
||||
value=0)
|
||||
cvt_scales = cvt_scales.reshape(spc_num, wn_block).contiguous()
|
||||
return cvt_scales
|
||||
|
||||
|
||||
def cross_weight_32(t1, t2, spc_num, dim=1, need_pad=True):
    """Interleave ``t1`` and ``t2`` in alternating 32-wide column groups
    along ``dim``, optionally zero-padding the result up to the device's
    SPC alignment.

    On br166-class parts (``spc_num > 16``) each input is treated as two
    die halves and every half is 32-aligned independently before the
    interleave.

    NOTE(review): padding always acts on the last dimension while the
    interleave uses ``dim`` — callers appear to pass ``dim`` as the last
    dimension; confirm before reusing with other dims.
    """

    def _round_up(value, multiple):
        return (value + multiple - 1) // multiple * multiple

    pad_fn = torch.nn.functional.pad
    width = t1.shape[dim]

    if spc_num > 16:
        # br166 must keep both die halves 32-aligned on their own.
        assert width % 2 == 0
        half = width // 2
        half_aligned = _round_up(half, 32)
        extra = half_aligned - half
        if extra > 0:
            t1 = torch.cat(
                [pad_fn(h, (0, extra), "constant", 0)
                 for h in torch.chunk(t1, 2, dim=-1)],
                dim=-1)
            t2 = torch.cat(
                [pad_fn(h, (0, extra), "constant", 0)
                 for h in torch.chunk(t2, 2, dim=-1)],
                dim=-1)
        width = half_aligned * 2
    else:
        aligned = _round_up(width, 32)
        t1 = pad_fn(t1, (0, aligned - width), "constant", 0)
        t2 = pad_fn(t2, (0, aligned - width), "constant", 0)
        width = aligned

    # Alternate 32-wide groups: t1[0:32], t2[0:32], t1[32:64], t2[32:64], ...
    groups = width // 32
    interleaved = []
    for a, b in zip(torch.chunk(t1, groups, dim), torch.chunk(t2, groups, dim)):
        interleaved.append(a)
        interleaved.append(b)
    no_pad = torch.cat(interleaved, dim=dim)
    if not need_pad:
        return no_pad

    if spc_num > 16:
        align = (spc_num // 2) * 32 * 2
        pad_size = _round_up(width, align) - width
        halves = [pad_fn(h, (0, pad_size), "constant", 0)
                  for h in torch.chunk(no_pad, 2, dim=-1)]
        return torch.cat(halves, dim=-1)

    align = spc_num * 32 * 2  # 768 when spc_num == 12
    pad_size = _round_up(width * 2, align) - width * 2
    return pad_fn(no_pad, (0, pad_size), "constant", 0)
|
||||
|
||||
|
||||
# # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
|
||||
def _convert_to_uma_tensor(tensor,
                           align_size,
                           layout,
                           dtype,
                           do_transpose=False,
                           wk=None,
                           wn=None,
                           parallel_type="col_parallel"):
    """Copy ``tensor`` into a freshly allocated UMA device tensor.

    ``layout`` selects the handling: ``"colmajor"`` for 2-D weights (with an
    optional transpose) or ``"linear_bias"`` for bias/scale vectors.
    ``parallel_type`` picks the sharding axis/SBP passed to the allocator.
    Raises ``ValueError`` for any other layout.

    NOTE(review): ``align_size`` is accepted but never used in this function.
    NOTE(review): in the 2-D "colmajor" case ``wk`` defaults to shape[1] and
    ``wn`` to shape[0] — the opposite of the NUMA converters below; presumably
    intentional for this layout, but confirm against callers.
    """

    assert parallel_type in ("col_parallel", "row_parallel")

    layout = layout.lower()
    if layout == "colmajor":
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[0]
        d_shape = (wn, wk)
        if do_transpose:
            # Transpose on CPU; the allocated shape follows the transposed data.
            data = tensor.cpu().permute(1, 0).contiguous()
            d_shape = (wk, wn)
        else:
            data = tensor.cpu().contiguous()
        if parallel_type == "col_parallel":
            # Column-parallel: shard along axis 0.
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                sbp="SB",
                axis=0)
        else:
            # Row-parallel: shard along axis 1.
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=1,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    elif layout == "linear_bias":
        # Flatten/normalize the bias so that d_shape matches:
        #   axis 0 -> flat (wn,), axis 1 -> (wk, wn).
        axis = 0
        wn = wn or tensor.shape[-1]
        wk = 1
        data = tensor
        if len(data.shape) == 2 and data.shape[0] == 1:
            # (1, wn) collapses to a flat vector.
            data = tensor.cpu().reshape(-1).contiguous()
        elif len(data.shape) == 2:
            # True 2-D bias/scales: keep as (wk, wn), shard along axis 1.
            axis = 1
            wk = data.shape[0]
        elif len(data.shape) == 3 and data.shape[1] == 1:
            # (wk, 1, wn) squeezes the middle dim.
            data = tensor.cpu().reshape(
                (data.shape[0], data.shape[2])).contiguous()
            axis = 1
            wk = data.shape[0]
        # NOTE(review): a 3-D input with shape[1] != 1 falls through with
        # ``data`` still 3-D and not moved to CPU — likely unreachable for
        # current callers; verify.
        d_shape = (wn, ) if axis == 0 else (wk, wn)
        if parallel_type == "row_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout)
        elif parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=axis,
                sbp="SB")

        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    else:
        raise ValueError("uma tensor only support colmajor and linear_bias")
    return uma_tensor
|
||||
|
||||
|
||||
def _convert_to_numa_tensor_vit(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Lay out a weight or bias tensor in NUMA form, one aligned block per
    SPC (ViT variant: biases are also converted to the blocked NUMA layout).

    The N dimension is split into per-SPC blocks zero-padded up to a
    multiple of ``align_size``; when ``VLLM_BR_DEVICE_SPC_NUM > 16`` the
    SPC count is halved and the data is additionally split across
    ``die_num == 2`` (presumably the two dies of a br166 part — see
    ``is_br166_device``).

    ``pad_zeros`` only affects the dual-die row-parallel "linear_bias"
    branch: the second die gets zeros instead of a copy of the bias.
    Forced-UMA debug mode is disabled during the conversion and restored
    before returning.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: per-die SPC count and two dies.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num

    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Single die: split N into spc_num aligned blocks.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            # Pad N so it divides evenly into spc_num blocks of wn_block.
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Each die owns wn // die_num columns, split per SPC.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Row-parallel: K is split across dies, N across SPCs.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)

    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        # Optional leading group dimension (e.g. per-group scales).
        group_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        if die_num == 1:
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num * group_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            data = torch.nn.functional.pad(tensor.cpu(),
                                           (0, spc_num * wn_block - wn),
                                           mode='constant',
                                           value=0)
            if group_num > 1:
                # Interleave so each SPC's rows are contiguous across groups.
                data = data.type(dtype).reshape(
                    group_num, spc_num,
                    wn_block).permute(1, 0, 2).contiguous().reshape(
                        spc_num * group_num, wn_block)
            else:
                data = data.type(dtype).reshape(spc_num, wn_block).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = tensor.cpu().reshape(die_num, wn // die_num)
                bias = torch.nn.functional.pad(
                    bias, (0, spc_num * wn_block - wn // die_num, 0, 0),
                    mode='constant',
                    value=0)
                # NOTE(review): bias is widened to float32 here regardless of
                # ``dtype`` — confirm the allocator expects float32 blocks.
                bias = bias.type(torch.float32).reshape(die_spc_num,
                                                        wn_block).contiguous()
                numa_tensor.copy_(bias)
            else:
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                bias = bias.type(torch.float32).reshape(spc_num,
                                                        wn_block).contiguous()
                if pad_zeros:
                    # Second die gets zeros instead of a duplicate bias.
                    bias_zeros_die2 = torch.zeros((spc_num, wn_block),
                                                  dtype=bias.dtype)
                    bias = torch.concat([bias, bias_zeros_die2], dim=0)
                else:
                    bias = torch.concat([bias, bias], dim=0)
                numa_tensor.copy_(bias)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
|
||||
|
||||
|
||||
def _convert_to_numa_tensor(tensor,
                            align_size,
                            layout,
                            dtype,
                            do_transpose=False,
                            wk=None,
                            wn=None,
                            parallel_type="col_parallel",
                            pad_zeros=False):
    """Lay out a weight tensor in NUMA form, one aligned block per SPC.

    ``"colmajor"`` weights are split along N into per-SPC blocks padded to a
    multiple of ``align_size`` (additionally split across two dies when
    ``VLLM_BR_DEVICE_SPC_NUM > 16``). Unlike ``_convert_to_numa_tensor_vit``,
    ``"linear_bias"`` inputs are NOT blocked here — they are kept in their
    natural (flat or per-expert) shape in a non-NUMA tensor.

    NOTE(review): ``pad_zeros`` is accepted for signature parity but never
    used in this function.
    Forced-UMA debug mode is disabled during the conversion and restored
    before returning.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: per-die SPC count and two dies.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num

    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Single die: split N into spc_num aligned blocks.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Each die owns wn // die_num columns, split per SPC.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Row-parallel: K is split across dies, N across SPCs.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)

    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        # Optional leading expert dimension (MoE per-expert bias/scales).
        expert_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        bias_shape = (expert_num, wn) if expert_num > 1 else (wn, )
        if die_num == 1:
            numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                  dtype=dtype,
                                                  is_numa=False,
                                                  device=tensor.device,
                                                  tensor_type=layout)
            data = tensor.cpu().type(dtype)
            if expert_num > 1:
                data = data.reshape(expert_num, wn)
            else:
                data = data.reshape(wn).type(dtype)
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(tensor.device))

        else:
            if parallel_type == "col_parallel":
                # Shard along the last data axis; experts stay broadcast.
                axis = 1 if expert_num > 1 else 0
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type="buffer_any",
                                                      axis=axis,
                                                      sbp="SB")
                if expert_num == 1:
                    tensor = tensor.reshape(-1)
                numa_tensor.copy_(tensor.to(torch.supa.current_device()))
            else:
                # Row-parallel bias is fully replicated (sbp="BB").
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="linear_bias",
                    sbp="BB")
                bias = tensor.reshape(expert_num, wn).cpu().type(dtype)
                if expert_num == 1:
                    bias = bias.reshape(-1)
                numa_tensor.copy_(bias.to(torch.supa.current_device()))

    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
|
||||
|
||||
|
||||
def _convert_to_crossed_numa_tensor(t1,
                                    t2,
                                    spc_num,
                                    dim=1,
                                    need_pad=True,
                                    layout="colmajor",
                                    do_transpose=False):
    """Interleave two weights in 32-wide groups, then convert the crossed
    result to the NUMA layout.

    Equals to V0: cross_weight_32 + numa_weight_convert/_numa_weight_cvt
    """
    crossed = cross_weight_32(t1, t2, spc_num, dim, need_pad)
    return _convert_to_numa_tensor(crossed, 32, layout, crossed.dtype,
                                   do_transpose)
|
||||
|
||||
|
||||
def _convert_to_numa_tensor_moe(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wb=None,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Lay out a batched (per-expert) MoE weight tensor in NUMA form.

    Only ``layout == "colmajor"`` is supported; ``wb`` is the expert/batch
    dimension. The N dimension is split into per-SPC blocks padded to a
    multiple of ``align_size``. Requires a dual-die part
    (``VLLM_BR_DEVICE_SPC_NUM > 16``) — asserted below.

    Returns ``(numa_tensor, (die_spc_num, wk, wn_block))``.

    NOTE(review): ``pad_zeros`` is accepted for signature parity but never
    used in this function.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: per-die SPC count and two dies.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    # MoE conversion is only implemented for the dual-die case.
    assert die_num == 2
    if layout == "colmajor":
        wb = wb or tensor.shape[0]
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[2]
        if parallel_type == "col_parallel":
            # Each die owns wn // die_num columns, split per SPC.
            wn_block = (wn // die_num + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            else:
                weight = tensor.cpu().contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            weight = torch.nn.functional.pad(
                weight,
                (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0, 0, 0),
                mode='constant',
                value=0)
            # Move the (die, spc) axis to the front so each SPC's blocks for
            # all experts are contiguous.
            weight = weight.reshape(wb, wk, die_spc_num, wn_block).permute(
                2, 0, 1, 3).reshape(wb * die_spc_num, wk,
                                    wn_block).contiguous()
            numa_tensor.copy_(weight)
        elif parallel_type == "row_parallel":
            # Row-parallel: K is split across dies, N across SPCs.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk // die_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous()
            else:
                weight = tensor.cpu().contiguous()
            weight = torch.nn.functional.pad(
                weight, (0, spc_num * wn_block - wn, 0, 0, 0, 0),
                mode='constant',
                value=0)
            weight = weight.reshape(wb, die_num, wk // die_num, spc_num,
                                    wn_block).permute(1, 3, 0, 2,
                                                      4).contiguous().reshape(
                                                          die_spc_num * wb,
                                                          wk // die_num,
                                                          wn_block)
            numa_tensor.copy_(weight)

    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor, (die_spc_num, wk, wn_block)
|
||||
|
||||
|
||||
def is_br166_device():
    """Return True when the current SUPA device reports a br166-class SPC
    count (more than 16 and at most 32 compute units)."""
    props = torch_br.supa.get_device_properties(torch.device("supa"))
    return 16 < props.max_compute_units <= 32
|
||||
Reference in New Issue
Block a user