first commit
This commit is contained in:
25
vllm_br/model_executor/layers/__init__.py
Normal file
25
vllm_br/model_executor/layers/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import vllm_br.model_executor.layers.activation
|
||||
import vllm_br.model_executor.layers.fused_moe
|
||||
import vllm_br.model_executor.layers.layernorm
|
||||
import vllm_br.model_executor.layers.linear
|
||||
import vllm_br.model_executor.layers.logits_processor
|
||||
import vllm_br.model_executor.layers.quantization
|
||||
import vllm_br.model_executor.layers.rotary_embedding
|
||||
import vllm_br.model_executor.layers.utils
|
||||
import vllm_br.model_executor.layers.vocab_parallel_embedding # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/layers/__pycache__/linear.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/layers/__pycache__/linear.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/layers/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/layers/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
31
vllm_br/model_executor/layers/activation.py
Normal file
31
vllm_br/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,31 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
|
||||
|
||||
|
||||
@patch_to(SiluAndMul)
def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
    """SUPA out-of-tree forward for SiluAndMul.

    Splits the last dimension of ``x`` in half and runs the fused
    silu-multiply kernel, i.e. ``silu(x[..., :d]) * x[..., d:]``.
    """
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return torch_br.supa_silumul(gate, up)  # type: ignore
@patch_to(QuickGELU)
def quick_gelu_forward_oot(self, x: torch.Tensor) -> torch.Tensor:  # noqa:F811
    # No dedicated SUPA kernel for QuickGELU here; delegate to the
    # upstream pure-PyTorch implementation.
    return self.forward_native(x)
619
vllm_br/model_executor/layers/br_utils.py
Normal file
619
vllm_br/model_executor/layers/br_utils.py
Normal file
@@ -0,0 +1,619 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
import torch_br.supa._debug as supa_debug
|
||||
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
def align_n(n, align_size, spc_num=envs.VLLM_BR_DEVICE_SPC_NUM) -> int:
    """Split ``n`` across ``spc_num`` SPCs and round the per-SPC share up
    to a multiple of ``align_size``.

    Returns the padded per-SPC block size.
    """
    per_spc = -(-n // spc_num)  # ceil(n / spc_num)
    return -(-per_spc // align_size) * align_size  # round up to align_size
def _br_qweight_cvt(quant_method,
|
||||
qweight,
|
||||
qzeros,
|
||||
size_k,
|
||||
size_n,
|
||||
override_group_size=None):
|
||||
group_size = override_group_size or quant_method.quant_config.group_size
|
||||
curr_dev = qweight.device
|
||||
group_num = size_k // group_size if group_size > 0 else 1
|
||||
qweight = qweight.cpu().view(torch.int8).reshape(
|
||||
size_k // 4, size_n,
|
||||
4).permute(0, 2, 1).contiguous().reshape(group_num,
|
||||
size_k // group_num, size_n)
|
||||
if qzeros is not None and not torch.all(qzeros == 0):
|
||||
qzeros = qzeros.cpu().view(torch.int8).to(torch.int32) + 1
|
||||
qweight = (qweight.to(torch.int32) - qzeros.unsqueeze(1)).to(
|
||||
torch.int8)
|
||||
qwei_int8 = qweight.reshape(size_k, size_n).to(curr_dev)
|
||||
return qwei_int8
|
||||
|
||||
|
||||
def _numa_scales_cvt(scales, wn, spc_num):
|
||||
align_size = 32
|
||||
wn_block = (wn + spc_num - 1) // spc_num
|
||||
wn_block = (wn_block + align_size - 1) // align_size * align_size
|
||||
cvt_scales = torch.nn.functional.pad(scales, (0, spc_num * wn_block - wn),
|
||||
mode='constant',
|
||||
value=0)
|
||||
cvt_scales = cvt_scales.reshape(spc_num, wn_block).contiguous()
|
||||
return cvt_scales
|
||||
|
||||
|
||||
def cross_weight_32(t1, t2, spc_num, dim=1, need_pad=True):
    """Interleave two weight tensors in 32-column groups along ``dim``.

    The result alternates 32-wide chunks of ``t1`` and ``t2``; both inputs
    are first zero-padded so their width is a multiple of 32 (per die half
    on dual-die parts).  When ``need_pad`` is True, the interleaved result
    is additionally zero-padded to the device alignment.

    NOTE(review): the pre-alignment step chunks/pads along ``dim=-1`` while
    the interleave step uses ``dim`` — these only agree when ``dim`` is the
    last dimension of a 2-D tensor; confirm callers never pass anything
    else.
    """
    width = t1.shape[dim]
    # NOTE: br166 must ensure dual-dies width are 32-aligned
    if spc_num > 16:
        # Dual-die: each half of the width must independently be a
        # multiple of 32, so pad each half separately.
        assert width % 2 == 0
        half_width = width // 2
        half_width_ = (half_width + 32 - 1) // 32 * 32
        half_pad = half_width_ - half_width
        if half_pad > 0:
            t10, t11 = torch.chunk(t1, 2, dim=-1)
            t10 = torch.nn.functional.pad(t10, (0, half_pad), "constant", 0)
            t11 = torch.nn.functional.pad(t11, (0, half_pad), "constant", 0)
            t1 = torch.cat([t10, t11], dim=-1)
            t20, t21 = torch.chunk(t2, 2, dim=-1)
            t20 = torch.nn.functional.pad(t20, (0, half_pad), "constant", 0)
            t21 = torch.nn.functional.pad(t21, (0, half_pad), "constant", 0)
            t2 = torch.cat([t20, t21], dim=-1)
            width = half_width_ * 2
    else:
        # Single die: pad the full width up to a multiple of 32.
        width_ = (width + 32 - 1) // 32 * 32
        t1 = torch.nn.functional.pad(t1, (0, width_ - width), "constant", 0)
        t2 = torch.nn.functional.pad(t2, (0, width_ - width), "constant", 0)
        width = width_

    # Interleave: [t1[:32], t2[:32], t1[32:64], t2[32:64], ...]
    cnt = width // 32
    t1_list = torch.chunk(t1, cnt, dim)
    t2_list = torch.chunk(t2, cnt, dim)
    tt = []
    for i in range(cnt):
        tt.append(t1_list[i])
        tt.append(t2_list[i])
    no_pad = torch.cat(tt, dim=dim)
    if not need_pad:
        return no_pad

    # Final alignment pad of the interleaved result.
    if spc_num > 16:
        # Dual-die: pad each half to a multiple of (spc_num/2)*64.
        align = (spc_num // 2) * 32 * 2
        width_align = (width + align - 1) // align * align
        pad_size = width_align - width
        out0, out1 = torch.chunk(no_pad, 2, dim=-1)
        out0 = torch.nn.functional.pad(out0, (0, pad_size), "constant", 0)
        out1 = torch.nn.functional.pad(out1, (0, pad_size), "constant", 0)
        out = torch.cat([out0, out1], dim=-1)
    else:
        align = spc_num * 32 * 2  # 768
        # NOTE(review): alignment here is computed on width * 2 (the
        # interleaved width of both inputs) — confirm this is intended.
        width_align = (width * 2 + align - 1) // align * align
        pad_size = width_align - width * 2
        out = torch.nn.functional.pad(no_pad, (0, pad_size), "constant", 0)
    return out
# # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
def _convert_to_uma_tensor(tensor,
                           align_size,
                           layout,
                           dtype,
                           do_transpose=False,
                           wk=None,
                           wn=None,
                           parallel_type="col_parallel"):
    """Copy ``tensor`` into a freshly allocated UMA device tensor.

    Supported ``layout`` values (case-insensitive): ``"colmajor"`` for 2-D
    weights and ``"linear_bias"`` for 1-D/2-D/3-D bias-or-scale tensors.
    ``parallel_type`` selects the split/broadcast axis metadata passed to
    ``torch_br._empty_ut_only``.  Raises ValueError for any other layout.

    NOTE(review): ``align_size`` is accepted but never used in this
    function — confirm whether that is intentional.
    """

    assert parallel_type in ("col_parallel", "row_parallel")

    layout = layout.lower()
    if layout == "colmajor":
        # NOTE(review): dims are read as shape[1] -> wk, shape[0] -> wn,
        # the opposite of the NUMA converters below — verify the caller's
        # expected orientation.
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[0]
        d_shape = (wn, wk)
        if do_transpose:
            data = tensor.cpu().permute(1, 0).contiguous()
            d_shape = (wk, wn)
        else:
            data = tensor.cpu().contiguous()
        if parallel_type == "col_parallel":
            # Column parallel: split/broadcast along axis 0.
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                sbp="SB",
                axis=0)
        else:
            # Row parallel: split/broadcast along axis 1.
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=1,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    elif layout == "linear_bias":
        # Normalize the bias/scale tensor to 1-D (wn,) or 2-D (wk, wn).
        axis = 0
        wn = wn or tensor.shape[-1]
        wk = 1
        data = tensor
        if len(data.shape) == 2 and data.shape[0] == 1:
            # (1, wn) -> flatten to (wn,)
            data = tensor.cpu().reshape(-1).contiguous()
        elif len(data.shape) == 2:
            # Genuine 2-D (e.g. grouped scales): split along axis 1.
            axis = 1
            wk = data.shape[0]
        elif len(data.shape) == 3 and data.shape[1] == 1:
            # (wk, 1, wn) -> squeeze the middle dim to (wk, wn).
            data = tensor.cpu().reshape(
                (data.shape[0], data.shape[2])).contiguous()
            axis = 1
            wk = data.shape[0]
        d_shape = (wn, ) if axis == 0 else (wk, wn)
        if parallel_type == "row_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout)
        elif parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=axis,
                sbp="SB")

        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    else:
        raise ValueError("uma tensor only support colmajor and linear_bias")
    return uma_tensor
def _convert_to_numa_tensor_vit(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Convert a host/device tensor to the NUMA on-device layout (ViT
    variant).

    Splits the ``wn`` dimension across SPCs (and across dies on dual-die
    parts), zero-padding each per-SPC block up to ``align_size``.  Supports
    ``layout`` in {"colmajor", "linear_bias"}; raises ValueError otherwise.

    NOTE(review): near-duplicate of ``_convert_to_numa_tensor`` below; the
    differences are the colmajor dual-die allocation kwargs (no axis/sbp
    here) and a completely different linear_bias path — consider unifying.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    # Temporarily disable the force-UMA debug override so the allocations
    # below really are NUMA; restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: per-die SPC count is half the total.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num

    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Single die: split wn across spc_num blocks of wn_block.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            # (wk, spc_num * wn_block) -> (spc_num, wk, wn_block)
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Dual die, column parallel: each die gets wn // die_num
                # columns, split across its SPCs.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Dual die, row parallel: wk is split across dies instead.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)

    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        group_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        if die_num == 1:
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num * group_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            data = torch.nn.functional.pad(tensor.cpu(),
                                           (0, spc_num * wn_block - wn),
                                           mode='constant',
                                           value=0)
            if group_num > 1:
                # Grouped scales: interleave so each SPC's groups are
                # contiguous.
                data = data.type(dtype).reshape(
                    group_num, spc_num,
                    wn_block).permute(1, 0, 2).contiguous().reshape(
                        spc_num * group_num, wn_block)
            else:
                data = data.type(dtype).reshape(spc_num, wn_block).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = tensor.cpu().reshape(die_num, wn // die_num)
                bias = torch.nn.functional.pad(
                    bias, (0, spc_num * wn_block - wn // die_num, 0, 0),
                    mode='constant',
                    value=0)
                # NOTE(review): bias is cast to float32 even though the
                # destination was allocated with ``dtype`` — copy_ will
                # convert; confirm this is the intended precision path.
                bias = bias.type(torch.float32).reshape(die_spc_num,
                                                        wn_block).contiguous()
                numa_tensor.copy_(bias)
            else:
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                bias = bias.type(torch.float32).reshape(spc_num,
                                                        wn_block).contiguous()
                if pad_zeros:
                    # Second die gets zeros (bias applied only once).
                    bias_zeros_die2 = torch.zeros((spc_num, wn_block),
                                                  dtype=bias.dtype)
                    bias = torch.concat([bias, bias_zeros_die2], dim=0)
                else:
                    # Second die gets a replica of the bias.
                    bias = torch.concat([bias, bias], dim=0)
                numa_tensor.copy_(bias)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
def _convert_to_numa_tensor(tensor,
                            align_size,
                            layout,
                            dtype,
                            do_transpose=False,
                            wk=None,
                            wn=None,
                            parallel_type="col_parallel",
                            pad_zeros=False):
    """Convert a host/device tensor to the NUMA on-device layout.

    For ``layout == "colmajor"`` the ``wn`` dimension is split across SPCs
    (and dies, on dual-die parts), each per-SPC block zero-padded up to
    ``align_size``.  For ``layout == "linear_bias"`` the tensor is kept in
    UMA form (``is_numa=False``) with only SBP metadata attached.  Raises
    ValueError for any other layout.

    NOTE(review): ``pad_zeros`` is accepted but unused in this variant
    (unlike ``_convert_to_numa_tensor_vit``) — confirm intent.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    # Temporarily disable the force-UMA debug override so the allocations
    # below really are NUMA; restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: per-die SPC count is half the total.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num

    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Single die: split wn across spc_num blocks of wn_block.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            # (wk, spc_num * wn_block) -> (spc_num, wk, wn_block)
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Dual die, column parallel: each die gets wn // die_num
                # columns, split across its SPCs (sbp="SS": split on both).
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Dual die, row parallel: wk is split across dies instead.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)

    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        expert_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        bias_shape = (expert_num, wn) if expert_num > 1 else (wn, )
        if die_num == 1:
            numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                  dtype=dtype,
                                                  is_numa=False,
                                                  device=tensor.device,
                                                  tensor_type=layout)
            data = tensor.cpu().type(dtype)
            if expert_num > 1:
                data = data.reshape(expert_num, wn)
            else:
                data = data.reshape(wn).type(dtype)
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(tensor.device))

        else:
            if parallel_type == "col_parallel":
                # Column parallel bias: split along wn (axis depends on
                # whether there is an expert dimension), broadcast on die.
                axis = 1 if expert_num > 1 else 0
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type="buffer_any",
                                                      axis=axis,
                                                      sbp="SB")
                if expert_num == 1:
                    tensor = tensor.reshape(-1)
                numa_tensor.copy_(tensor.to(torch.supa.current_device()))
            else:
                # Row parallel bias: full broadcast (sbp="BB").
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="linear_bias",
                    sbp="BB")
                bias = tensor.reshape(expert_num, wn).cpu().type(dtype)
                if expert_num == 1:
                    bias = bias.reshape(-1)
                numa_tensor.copy_(bias.to(torch.supa.current_device()))

    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
def _convert_to_crossed_numa_tensor(t1,
                                    t2,
                                    spc_num,
                                    dim=1,
                                    need_pad=True,
                                    layout="colmajor",
                                    do_transpose=False):
    """Interleave two weights in 32-column groups, then lay the result out
    as a NUMA tensor.

    Equals to V0: cross_weight_32 + numa_weight_convert/_numa_weight_cvt
    """
    crossed = cross_weight_32(t1, t2, spc_num, dim, need_pad)
    return _convert_to_numa_tensor(crossed, 32, layout, crossed.dtype,
                                   do_transpose)
def _convert_to_numa_tensor_moe(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wb=None,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Convert a batched (per-expert) 3-D MoE weight to the NUMA layout.

    ``tensor`` is (wb experts, wk, wn) after optional transpose.  Only
    ``layout == "colmajor"`` on dual-die parts is supported; the expert
    dimension is folded into the SPC dimension of the result.

    Returns ``(numa_tensor, (die_spc_num, wk, wn_block))``.
    NOTE(review): the returned shape tuple does not include ``wb`` and for
    row_parallel the actual tensor is (die_spc_num * wb, wk // die_num,
    wn_block) — confirm callers interpret it correctly.
    NOTE(review): ``pad_zeros`` is accepted but unused here.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    # Temporarily disable the force-UMA debug override so the allocation
    # below really is NUMA; restored before returning.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    # This MoE path is only implemented for dual-die devices.
    assert die_num == 2
    if layout == "colmajor":
        wb = wb or tensor.shape[0]
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[2]
        if parallel_type == "col_parallel":
            # Each die gets wn // die_num columns, split across its SPCs.
            wn_block = (wn // die_num + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            else:
                weight = tensor.cpu().contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            weight = torch.nn.functional.pad(
                weight,
                (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0, 0, 0),
                mode='constant',
                value=0)
            # (wb, wk, die_spc_num, wn_block) -> SPC-major with experts
            # contiguous per SPC.
            weight = weight.reshape(wb, wk, die_spc_num, wn_block).permute(
                2, 0, 1, 3).reshape(wb * die_spc_num, wk,
                                    wn_block).contiguous()
            numa_tensor.copy_(weight)
        elif parallel_type == "row_parallel":
            # wk is split across dies; wn is split across SPCs.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk // die_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous()
            else:
                weight = tensor.cpu().contiguous()
            weight = torch.nn.functional.pad(
                weight, (0, spc_num * wn_block - wn, 0, 0, 0, 0),
                mode='constant',
                value=0)
            weight = weight.reshape(wb, die_num, wk // die_num, spc_num,
                                    wn_block).permute(1, 3, 0, 2,
                                                      4).contiguous().reshape(
                                                          die_spc_num * wb,
                                                          wk // die_num,
                                                          wn_block)
            numa_tensor.copy_(weight)

    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor, (die_spc_num, wk, wn_block)
def is_br166_device():
    """Return True when the current SUPA device looks like a dual-die
    BR166 part, i.e. it reports between 17 and 32 compute units."""
    props = torch_br.supa.get_device_properties(torch.device("supa"))
    return 16 < props.max_compute_units <= 32
23
vllm_br/model_executor/layers/fused_moe/__init__.py
Normal file
23
vllm_br/model_executor/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import layer, supa_moe # noqa: E402
|
||||
from .layer import * # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"layer",
|
||||
"supa_moe",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
413
vllm_br/model_executor/layers/fused_moe/layer.py
Normal file
413
vllm_br/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,413 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm_br import envs
|
||||
from ..br_utils import (_convert_to_crossed_numa_tensor,
|
||||
_convert_to_numa_tensor, align_n, cross_weight_32)
|
||||
from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn,
|
||||
fused_oss_moe_dyn)
|
||||
|
||||
|
||||
@patch_to(UnquantizedFusedMoEMethod)
def forward_oot(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    """Forward for UnquantizedFusedMoEMethod with SUPA out-of-tree support.

    Dispatches to one of three SUPA fused-MoE kernels:
    ``fused_oss_moe_dyn`` for the "swigluoai" activation, otherwise
    ``fused_moe_quant_dyn`` (prefill) or ``fused_moe_quant_device``
    (decode), chosen by the batch/sequence length threshold
    ``VLLM_BR_STATIC_MOE_DECODER_MAX_LEN``.
    """
    # "swigluoai" activation: biased fused kernel, biases consumed directly.
    if activation == "swigluoai":
        return fused_oss_moe_dyn(
            x,
            layer.w13_weight,
            layer.w13_bias,
            layer.w2_weight,
            layer.w2_bias,
            router_logits,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size)

    b_seq = x.shape[0]
    # NOTE(review): despite the ``torch.Tensor`` annotation,
    # ``router_logits`` is unpacked here as a 3-tuple of (gating weight,
    # shared-expert gate_up weight, shared-expert down weight) — confirm
    # the caller contract.
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
    if b_seq > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
        # prefill
        return fused_moe_quant_dyn(
            x,
            shared_gate_up_weight,
            shared_down_weight,
            layer.w13_weight,
            layer.w2_weight,
            None,
            None,
            gating_weight,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            tp_rank=get_tp_group().rank_in_group,
            global_rank=get_tp_group().rank,
            tp_size=get_tensor_model_parallel_world_size(),
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size)
    else:
        # decoder
        return fused_moe_quant_device(
            x,
            shared_gate_up_weight,
            shared_down_weight,
            layer.w13_weight,
            layer.w2_weight,
            None,
            None,
            gating_weight,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            tp_rank=get_tp_group().rank_in_group,
            global_rank=get_tp_group().rank,
            tp_size=get_tensor_model_parallel_world_size(),
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size)
@patch_to(UnquantizedFusedMoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate the per-expert MoE parameters on CPU.

    Weights are created on CPU and only moved/re-laid-out for the SUPA
    device later in process_weights_after_loading.
    """

    def _register(name: str, tensor: torch.Tensor) -> None:
        # Wrap as a frozen Parameter, attach to the layer and propagate
        # the loader attributes (weight_loader etc.).
        param = torch.nn.Parameter(tensor, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    with_bias = self.moe.has_bias

    # Fused gate_up_proj (column parallel): gate and up halves share dim 1.
    _register(
        "w13_weight",
        torch.empty(num_experts,
                    2 * intermediate_size_per_partition,
                    hidden_size,
                    device="cpu",
                    dtype=params_dtype))
    if with_bias:
        _register(
            "w13_bias",
            torch.zeros(num_experts,
                        2 * intermediate_size_per_partition,
                        device="cpu",
                        dtype=params_dtype))

    # down_proj (row parallel).
    _register(
        "w2_weight",
        torch.empty(num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                    device="cpu",
                    dtype=params_dtype))
    if with_bias:
        _register(
            "w2_bias",
            torch.zeros(num_experts,
                        hidden_size,
                        device="cpu",
                        dtype=params_dtype))
|
||||
|
||||
|
||||
@patch_to(UnquantizedFusedMoEMethod)
def process_weights_after_loading(self: UnquantizedFusedMoEMethod,
                                  layer: FusedMoE) -> None:
    """Re-lay-out the CPU-loaded expert weights into SUPA NUMA colmajor
    tensors (and float32 bias tensors) expected by the fused MoE kernels.

    Mutates layer.w13_weight / w13_bias / w2_weight / w2_bias in place by
    replacing their .data with device-resident tensors.
    """
    cur_device = torch.supa.current_device()
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # presumably: >16 SPCs means a dual-die part, each die holding half
    # the SPCs — TODO confirm against hardware docs
    die_num = 1 if die_spc_num <= 16 else 2
    spc_num = die_spc_num // die_num
    # swigluoai weights skip the gate/up interleave, so a smaller
    # alignment is enough (see the per-expert loop below).
    align_size = 32 if layer.activation == "swigluoai" else 64
    is_dual_die = (die_spc_num > 16)

    # NOTE: w13_weight
    # after _load_w13, w13_weight is a colparallel weight, shape
    # [num_experts, 2 * intermediate_size_per_partition, hidden_size]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block] (wn = aligned(2 * intermediate_size_per_partition, align_size=64))
    wk = layer.hidden_size
    wn_block = align_n((layer.intermediate_size_per_partition * 2) // die_num,
                       align_size=align_size,
                       spc_num=spc_num)

    supa_w13_weight = torch_br._empty_ut_only(
        size=(die_spc_num * layer.local_num_experts, wk, wn_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=cur_device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if is_dual_die else None)

    for expert_id in range(layer.local_num_experts):
        # Checkpoint layout is [out, in]; kernels want [in, out] (colmajor).
        expert_w13 = layer.w13_weight[expert_id].transpose(0, 1).contiguous()
        # swigluoai activation, no need do interweave
        if layer.activation and layer.activation == "swigluoai":
            pad_expert_w13 = _convert_to_numa_tensor(expert_w13, align_size,
                                                     'COLMAJOR',
                                                     expert_w13.dtype)
            pad_expert_w13_shape = pad_expert_w13.shape
            hw_size = pad_expert_w13_shape[-2] * pad_expert_w13_shape[-1]
            # View the per-expert slice of the big NUMA buffer and fill it.
            narrow_data = supa_w13_weight.view_as_usharp(
                "COLMAJOR", pad_expert_w13_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)
        else:
            # Split the fused weight back into gate (w1) / up (w3) halves
            # and interleave them across SPCs.
            expert_1, expert_3 = expert_w13.chunk(2, dim=1)
            pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
                                                             expert_3,
                                                             die_spc_num,
                                                             dim=1,
                                                             need_pad=True,
                                                             layout='COLMAJOR')
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)

    layer.w13_weight.data = supa_w13_weight

    # NOTE: w13_bias
    if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
        wn = layer.intermediate_size_per_partition * 2
        supa_w13_bias = torch_br._empty_ut_only(
            size=(layer.local_num_experts, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="linear_bias",
            sbp="BB" if is_dual_die else None)
        for expert_id in range(layer.local_num_experts):
            expert_w13_bias = layer.w13_bias[expert_id]
            # swigluoai activation, no need do interweave
            if layer.activation and layer.activation == "swigluoai":
                narrow_data = supa_w13_bias[expert_id]
                narrow_data.copy_(expert_w13_bias)
            else:
                # Interleave the gate/up bias halves to match the crossed
                # weight layout built above.
                expert_1_bias, expert_3_bias = expert_w13_bias.chunk(2, dim=-1)
                crossed_expert_w13_bias = cross_weight_32(
                    expert_1_bias,
                    expert_3_bias,
                    die_spc_num,
                    dim=0,
                    need_pad=False,
                )
                narrow_data = supa_w13_bias[expert_id]
                narrow_data.copy_(crossed_expert_w13_bias)
        layer.w13_bias.data = supa_w13_bias

    # NOTE: w2_weight
    # after _load_w2, w2_weight is a rowparallel weight, shape
    # [num_experts, hidden_size, intermediate_size_per_partition]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block]
    align_size = 32
    wk = layer.intermediate_size_per_partition
    wn_block = align_n(layer.hidden_size,
                       align_size=align_size,
                       spc_num=spc_num)

    supa_w2_weight = torch_br._empty_ut_only(
        size=(die_spc_num * layer.local_num_experts, wk // die_num, wn_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=cur_device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if is_dual_die else None)

    for expert_id in range(layer.local_num_experts):
        expert_w2 = layer.w2_weight[expert_id].transpose(0, 1).contiguous()
        pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
                                                align_size,
                                                'COLMAJOR',
                                                expert_w2.dtype,
                                                parallel_type="row_parallel")
        pad_expert_w2_shape = pad_expert_w2.shape
        hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
        narrow_data = supa_w2_weight.view_as_usharp("COLMAJOR",
                                                    pad_expert_w2_shape,
                                                    Sbp.ss(0),
                                                    expert_id * hw_size)
        narrow_data.copy_(pad_expert_w2)

    layer.w2_weight.data = supa_w2_weight

    # NOTE: w2_bias — kept as a plain device tensor (no NUMA layout),
    # promoted to float32.
    if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
        wn = layer.hidden_size
        supa_w2_bias = torch.zeros((layer.local_num_experts, wn),
                                   dtype=torch.float32,
                                   device=cur_device)
        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_bias[expert_id]
            narrow_data = supa_w2_bias[expert_id]
            narrow_data.copy_(expert_w2)

        layer.w2_bias.data = supa_w2_bias
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
def forward(self: FusedMoE, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """
    ! router_logits is a tuple of gate, shared_experts.gate_up_proj,
    shared_experts.down_proj weights.
    """
    assert self.quant_method is not None
    assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes'

    # Delegate the whole routed+shared computation to the quant method.
    out = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
    )

    # NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the
    # output so downstream layers can skip the tensor-parallel all-reduce.
    # (SPC count, tp_size) pairs for which the fused allreduce kernel exists:
    support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
    tp_size = get_tensor_model_parallel_world_size()
    if (hidden_states.shape[0] <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
            and envs.VLLM_BR_QUANT_METHOD != "INT4"
            and envs.VLLM_BR_USE_FUSED_ALLREDUCE
            and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types):
        out.all_reduced = True

    return out
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str,
              loaded_weight: torch.Tensor, tp_rank: int):
    """Copy one TP shard of gate_proj ("w1") or up_proj ("w3") into the
    fused w13 expert weight.

    gate_up_proj is "MergedColumnParallel": expert_data's shard_dim holds
    [gate | up] halves of equal size, and TP sharding slices the loaded
    checkpoint tensor along the same dim.
    """
    half = expert_data.shape[shard_dim] // 2
    # Select this rank's shard from the full checkpoint weight.
    src = loaded_weight.narrow(shard_dim, half * tp_rank, half)
    # w1 (gate_proj) fills the first half of w13, w3 (up_proj) the second.
    if shard_id == "w1":
        dst = expert_data.narrow(shard_dim, 0, half)
    else:
        assert shard_id == "w3"
        dst = expert_data.narrow(shard_dim, half, half)
    # Destination parameter lives on CPU at load time, so force the
    # source there before copying.
    dst.copy_(src.cpu())
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
def _load_w2(self,
             expert_data: torch.Tensor,
             shard_dim: int,
             loaded_weight: torch.Tensor,
             tp_rank: int,
             load_full: bool = False):
    """Copy the down_proj ("w2") weight for one expert.

    down_proj is "RowParallel", so TP sharding slices the input dim of
    the checkpoint tensor. When load_full is True the loaded tensor is
    assumed to already be this rank's shard and is copied verbatim.
    """
    shard = expert_data.shape[shard_dim]
    src = loaded_weight
    if not load_full:
        src = src.narrow(shard_dim, shard * tp_rank, shard)
    # Parameter is staged on CPU at load time; move the source to match.
    expert_data.copy_(src.cpu())
|
||||
|
||||
|
||||
def wrapper_FusedMoE_init(fn):
    """Wrap FusedMoE.__init__ so e_score_correction_bias is promoted to
    float32 immediately after construction (the SUPA router kernels take a
    float32 gating bias — see the torch.float32 fallback tensors built in
    fused_moe_quant_device)."""

    @wraps(fn)
    def wrapped_init(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        bias = self.e_score_correction_bias
        if bias is not None:
            # In-place .data swap keeps the Parameter object identity.
            bias.data = bias.float()

    return wrapped_init
|
||||
|
||||
|
||||
# Patch FusedMoE.__init__ in place so every construction runs the
# bias-dtype promotion defined in wrapper_FusedMoE_init above.
FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__)  # noqa: E501
|
||||
518
vllm_br/model_executor/layers/fused_moe/supa_moe.py
Normal file
518
vllm_br/model_executor/layers/fused_moe/supa_moe.py
Normal file
@@ -0,0 +1,518 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
# gpt-oss moe forward version
|
||||
def fused_oss_moe_dyn(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w13_bias: torch.Tensor,
    w2: torch.Tensor,
    w2_bias: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """gpt-oss style MoE forward: route tokens, run the biased swiglu-oai
    expert FFNs on the SUPA device, and unpermute the results.

    Returns a tensor of shape [1, *hidden_states.shape] (a leading dim is
    added by the final unsqueeze). Tokens routed only to remote EP ranks
    contribute zeros.

    NOTE(review): renormalize / inplace / use_grouped_topk /
    custom_routing_function are accepted for interface compatibility but
    unused here.
    """
    total_expert_num = gating_weight.shape[-2]
    probs_supa, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
        hidden_states,
        gating_weight,
        topk,
        ep_size,
        ep_rank,
        gating_bias=e_score_correction_bias)

    # Router outputs come back transposed; round-trip through the CPU to
    # transpose them (presumably the device views can't be permuted in
    # place — TODO confirm).
    cur_device = hidden_states.device
    probs_supa = probs_supa.cpu().permute(1, 0).contiguous().to(cur_device)
    indices_supa = indices_supa.cpu().permute(1, 0).contiguous().to(cur_device)
    indice_per_expert = indice_per_expert.cpu().permute(
        1, 0).contiguous().to(cur_device)
    prob_per_expert = prob_per_expert.cpu().permute(
        1, 0).contiguous().to(cur_device)

    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    local_expert_num = total_expert_num // ep_size  # type: ignore
    b_seq = hidden_states.shape[0]
    # Scratch index buffer, second dim padded up to a multiple of 64.
    indices_trans_supa = torch_br._empty_ut_only(
        size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
        dtype=torch.int32,
        is_numa=False,
        device=hidden_states.device,
        tensor_type="colmajor",
        sbp="BB" if is_dual_die else None)
    tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
                                            minlength=total_expert_num)

    # Keep only the counts for the experts owned by this EP rank.
    tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
    )[ep_rank * local_expert_num:(ep_rank + 1) *  # type: ignore
      local_expert_num]
    # Number of local experts that actually received tokens.
    topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)

    if topk_per_expert > 0:
        # Gather each expert's tokens into contiguous per-expert buffers.
        expert_tokens = torch_br.supa_permutation_infer(
            global_hidden_states=hidden_states,
            indices=indice_per_expert,
            tokens_per_expert=tokens_per_expert_list,
            indices_trans=indices_trans_supa)

        assert len(
            expert_tokens) == local_expert_num, "Number of experts mismatch"

        gate_up_outputs = []
        down_outputs = []
        cur_device = expert_tokens[0].device
        hidden_size = expert_tokens[0].shape[-1]
        for i in range(local_expert_num):
            if tokens_per_expert_list[i] == 0:
                # Empty placeholders keep list positions aligned with
                # expert ids for the fused kernels below.
                gate_up_outputs.append(
                    torch.empty(size=(0, intermediate_size),
                                dtype=torch.bfloat16,
                                device=cur_device))
                down_outputs.append(
                    torch.empty(size=(0, hidden_size),
                                dtype=torch.float32,
                                device=cur_device))
                continue

            gate_up_output = torch_br._empty_ut_only(
                size=(tokens_per_expert_list[i], intermediate_size),
                dtype=torch.bfloat16,
                is_numa=False,
                device=cur_device,
                tensor_type="colmajor",
                sbp="SB" if is_dual_die else None,
                axis=1)
            gate_up_outputs.append(gate_up_output)

            down_output = torch_br._empty_ut_only(
                size=(tokens_per_expert_list[i], hidden_size),
                dtype=torch.float32,
                is_numa=False,
                device=cur_device,
                tensor_type="colmajor",
                sbp="BB" if is_dual_die else None)
            down_outputs.append(down_output)

        # gate_up projection + gpt-oss swiglu activation, all experts fused.
        torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
                                              expert_tokens,
                                              w13,
                                              tokens_per_expert_list,
                                              max(tokens_per_expert_list),
                                              bias=w13_bias,
                                              act_mode="act_swiglu_oai")

        # down projection, no activation.
        torch_br.supa_moe_fused_ffn_dyn_infer(down_outputs,
                                              gate_up_outputs,
                                              w2,
                                              tokens_per_expert_list,
                                              max(tokens_per_expert_list),
                                              bias=w2_bias,
                                              act_mode="act_default")

        # Scatter expert outputs back to token order, weighted by probs.
        output = torch_br.supa_unpermutation_infer(
            input_list=down_outputs,
            indices=indices_trans_supa,
            probs=prob_per_expert,
            tokens_per_expert=tokens_per_expert_list)
    else:
        # No local expert was hit on this EP rank.
        output = torch.zeros_like(hidden_states)

    return output.unsqueeze(0)
|
||||
|
||||
|
||||
def fused_moe_quant_dyn(
    hidden_states: torch.Tensor,
    shared_gate_up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    tp_rank: Optional[int] = None,
    global_rank: Optional[int] = None,
    tp_size: Optional[int] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """MoE prefill path on the SUPA device: route, permute tokens per
    expert, run the (optionally quantized) expert FFNs, unpermute, and add
    the shared-expert output when present.

    Parameters:
    - hidden_states: input tokens, shape [num_tokens, hidden_size].
    - shared_gate_up_weight / down_weight: shared-expert weights; required
      (and only allowed) when use_grouped_topk is True, must be None
      otherwise.
    - w13 / w2: fused gate_up and down expert weights in SUPA layout.
    - w13_scale / w2_scale: optional dequant scales; None for bf16 weights.
    - gating_weight: router weight, last dim = total expert count.
    - topk: experts selected per token.
    - intermediate_size: per-partition FFN intermediate size.
    - tp_rank / global_rank / tp_size / ep_rank / ep_size: parallel layout.

    NOTE(review): renormalize, inplace and custom_routing_function are
    accepted for interface compatibility but unused here.

    Returns:
    - torch.Tensor of shape [1, num_tokens, hidden_size].
    """
    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    total_expert_num = gating_weight.shape[-1]
    cur_device = hidden_states.device
    if use_grouped_topk:
        # Grouped top-k router fused with the shared-expert FFN.
        assert num_expert_group is not None and topk_group is not None
        shared_output, _, indices_supa, indice_per_expert, prob_per_expert = torch_br.supa_fused_shared_router_prefill_v2_infer(
            hidden_states,
            shared_gate_up_weight,
            down_weight,
            gating_weight,
            intermediate_size,
            topk,
            num_expert_group,
            topk_group,
            ep_size,
            ep_rank,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)
        if is_dual_die:
            # Materialize the shared output in a plain colmajor buffer so
            # it can be added to the routed output later.
            shared_tmp = torch_br._empty_ut_only(size=shared_output.shape,
                                                 dtype=shared_output.dtype,
                                                 is_numa=False,
                                                 device=shared_output.device,
                                                 tensor_type="colmajor")
            shared_tmp.copy_(shared_output)
            shared_output = shared_tmp
    else:
        assert topk_group is None, "Only support non group topk router"
        assert shared_gate_up_weight is None and down_weight is None
        _, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
            hidden_states,
            gating_weight.permute(1, 0).contiguous(),
            topk,
            ep_size,
            ep_rank,
            gating_bias=e_score_correction_bias)
        shared_output = None

    # Router outputs come back transposed.
    indices_supa = indices_supa.permute(1, 0).contiguous()
    indice_per_expert = indice_per_expert.permute(1, 0).contiguous()
    prob_per_expert = prob_per_expert.permute(1, 0).contiguous()

    local_expert_num = total_expert_num // ep_size  # type: ignore
    b_seq = hidden_states.shape[0]
    # Scratch index buffer, second dim padded up to a multiple of 64.
    indices_trans_supa = torch_br._empty_ut_only(
        size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
        dtype=torch.int32,
        is_numa=False,
        device=hidden_states.device,
        tensor_type="colmajor",
        sbp="BB" if is_dual_die else None)
    tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
                                            minlength=total_expert_num)

    # Keep only the counts for the experts owned by this EP rank.
    tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
    )[ep_rank * local_expert_num:(ep_rank + 1) *  # type: ignore
      local_expert_num]
    topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)

    if topk_per_expert > 0:
        expert_tokens = torch_br.supa_permutation_infer(
            global_hidden_states=hidden_states,
            indices=indice_per_expert,
            tokens_per_expert=tokens_per_expert_list,
            indices_trans=indices_trans_supa)

        assert len(
            expert_tokens) == local_expert_num, "Number of experts mismatch"

        supa_device = torch.supa.current_device()
        spc_num = torch_br.supa.get_device_properties(
            supa_device).max_compute_units

        out_expert_tokens = []
        use_moe_fused_ffn_dyn = True
        if not use_moe_fused_ffn_dyn or total_expert_num == 128:
            # Per-expert fallback path: one fused linear pair per expert.
            w13_hw = w13.shape[-2] * w13.shape[-1]
            w2_hw = w2.shape[-2] * w2.shape[-1]

            for i in range(local_expert_num):
                expert_token = expert_tokens[i]
                if tokens_per_expert_list[i] == 0:
                    out_expert_tokens.append(expert_token)
                    continue

                expert_gate_up_weight = w13.view_as_usharp(
                    "COLMAJOR", (spc_num, w13.shape[-2], w13.shape[-1]),
                    Sbp.ss(0), i * w13_hw)

                # Renamed from `down_weight` to avoid shadowing the
                # shared-expert `down_weight` parameter.
                expert_down_weight = w2.view_as_usharp(
                    "COLMAJOR", (spc_num, w2.shape[-2], w2.shape[-1]),
                    Sbp.ss(0), i * w2_hw)

                expert_gate_up_scale = w13_scale[
                    i] if w13_scale is not None else None
                down_scale = w2_scale[i] if w2_scale is not None else None

                gate_up_output = torch_br._empty_ut_only(
                    size=(expert_token.shape[0], intermediate_size),
                    dtype=torch.bfloat16,
                    is_numa=False,
                    device=expert_token.device,
                    tensor_type="colmajor",
                    sbp="SB" if is_dual_die else None,
                    axis=1)
                torch_br.supa_fused_linear_infer(gate_up_output,
                                                 expert_token,
                                                 expert_gate_up_weight,
                                                 expert_gate_up_scale,
                                                 act_mode="act_swiglu")

                down_output = torch_br._empty_ut_only(
                    size=expert_token.shape,
                    dtype=torch.float32,
                    is_numa=False,
                    device=gate_up_output.device,
                    tensor_type="colmajor",
                    sbp="BB" if is_dual_die else None)

                torch_br.supa_fused_linear_infer(down_output, gate_up_output,
                                                 expert_down_weight,
                                                 down_scale)

                out_expert_tokens.append(down_output)
        else:
            # Fused path: allocate all outputs then run two batched kernels.
            gate_up_outputs = []
            cur_device = expert_tokens[0].device
            hidden_size = expert_tokens[0].shape[-1]
            for i in range(local_expert_num):
                if tokens_per_expert_list[i] == 0:
                    # Empty placeholders keep list positions aligned with
                    # expert ids.
                    gate_up_outputs.append(
                        torch.empty(size=(0, intermediate_size),
                                    dtype=torch.bfloat16,
                                    device=cur_device))
                    out_expert_tokens.append(
                        torch.empty(size=(0, hidden_size),
                                    dtype=torch.float32,
                                    device=cur_device))
                    continue

                gate_up_output = torch_br._empty_ut_only(
                    size=(tokens_per_expert_list[i], intermediate_size),
                    dtype=torch.bfloat16,
                    is_numa=False,
                    device=cur_device,
                    tensor_type="colmajor",
                    sbp="SB" if is_dual_die else None,
                    axis=1)
                gate_up_outputs.append(gate_up_output)

                down_output = torch_br._empty_ut_only(
                    size=(tokens_per_expert_list[i], hidden_size),
                    dtype=torch.float32,
                    is_numa=False,
                    device=cur_device,
                    tensor_type="colmajor",
                    sbp="BB" if is_dual_die else None)
                out_expert_tokens.append(down_output)
            torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
                                                  expert_tokens,
                                                  w13,
                                                  tokens_per_expert_list,
                                                  max(tokens_per_expert_list),
                                                  scales=w13_scale,
                                                  act_mode="act_swiglu")
            torch_br.supa_moe_fused_ffn_dyn_infer(out_expert_tokens,
                                                  gate_up_outputs,
                                                  w2,
                                                  tokens_per_expert_list,
                                                  max(tokens_per_expert_list),
                                                  scales=w2_scale,
                                                  act_mode="act_default")

        # Scatter expert outputs back to token order, weighted by probs.
        out_states = torch_br.supa_unpermutation_infer(
            input_list=out_expert_tokens,
            indices=indices_trans_supa,
            probs=prob_per_expert,
            tokens_per_expert=tokens_per_expert_list)

        output = out_states if shared_output is None else out_states + shared_output
    else:
        # No local expert was hit; only the shared experts contribute.
        output = torch.zeros_like(
            hidden_states) if shared_output is None else shared_output

    return output.unsqueeze(0)
|
||||
|
||||
|
||||
def fused_moe_quant_device(
    hidden_states: torch.Tensor,
    shared_gate_up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    tp_rank: Optional[int] = None,
    global_rank: Optional[int] = None,
    tp_size: Optional[int] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """MoE decode path on the SUPA device: fully fused router + expert FFN
    kernels that accumulate into the shared-expert output in place.

    The expert FFN writes its result into `shared_output`, so the return
    value is routed + shared combined, with a leading dim added.

    NOTE(review): renormalize, inplace and custom_routing_function are
    accepted for interface compatibility but unused here.
    """
    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    expert_num = gating_weight.shape[-1]
    b_seq = hidden_states.shape[-2]
    if topk_group is None:
        # Plain top-k router; no shared experts on this path.
        assert shared_gate_up_weight is None and down_weight is None
        shared_output, masked_probs, hitted_experts = torch_br.supa_moe_router_decoder_infer(
            hidden_states, gating_weight, topk, ep_size, ep_rank)
    else:
        # Grouped top-k router fused with the shared-expert FFN.
        assert use_grouped_topk is True, "Only support group topk router"
        assert num_expert_group is not None and topk_group is not None
        if ep_size > 1:  # type: ignore
            shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_v2_infer(
                hidden_states,
                shared_gate_up_weight,
                down_weight,
                gating_weight,
                intermediate_size,
                topk,
                num_expert_group,
                topk_group,
                ep_size,
                ep_rank,
                scoring_func=scoring_func,
                # Kernel requires a bias tensor; substitute an unread
                # float32 placeholder when none is configured.
                e_score_correction_bias=e_score_correction_bias
                if e_score_correction_bias is not None else torch.empty(
                    (expert_num),
                    dtype=torch.float32,
                    device=hidden_states.device))
        else:
            shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_infer(
                hidden_states,
                shared_gate_up_weight,
                down_weight,
                gating_weight,
                intermediate_size,
                topk,
                num_expert_group,
                topk_group,
                scoring_func,
                e_score_correction_bias=e_score_correction_bias
                if e_score_correction_bias is not None else torch.empty(
                    (expert_num),
                    dtype=torch.float32,
                    device=hidden_states.device))
        if is_dual_die:
            # Broadcast SBP view so both dies see the shared output.
            shared_output = shared_output.view_as_usharp(
                "COLMAJOR", shared_output.shape, Sbp.bb())

    if w13.dtype == torch.int32:
        # int32 storage indicates packed INT4 weights; use the s4 kernel.
        torch_br.supa_moe_fused_ffn_s4_infer(shared_output, hidden_states, w13,
                                             w2, hitted_experts, masked_probs,
                                             w13_scale, w2_scale)
    else:
        # (SPC count, tp_size) pairs the fused ffn+allreduce kernel supports.
        support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
        if envs.VLLM_BR_USE_FUSED_ALLREDUCE and b_seq <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and (
                envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types:
            # ffn+allreduce only support tp 4|8 and 16spc
            torch_br.supa_moe_fused_ffn_allreduce(shared_output, hidden_states,
                                                  w13, w2, hitted_experts,
                                                  masked_probs, tp_rank,
                                                  tp_size, global_rank, 0,
                                                  w13_scale, w2_scale)
        else:
            torch_br.supa_moe_fused_ffn_infer(shared_output, hidden_states,
                                              w13, w2, hitted_experts,
                                              masked_probs, w13_scale,
                                              w2_scale)

    return shared_output.unsqueeze(0)
|
||||
67
vllm_br/model_executor/layers/layernorm.py
Normal file
67
vllm_br/model_executor/layers/layernorm.py
Normal file
@@ -0,0 +1,67 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import os
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch import Tensor, nn
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
@patch_to(RMSNorm)
def forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """RMSNorm forward using the Biren SUPA kernels.

    With a residual, returns (normed, x + residual) from the fused
    add+rmsnorm kernel; otherwise returns the normed tensor.
    """
    # Lazily promote the weight to float32 on first use — the SUPA
    # rmsnorm kernels consume a float32 weight. Mutates the module.
    if self.weight.data.dtype == torch.bfloat16:
        self.weight.data = self.weight.data.to(torch.float32)

    if residual is not None:
        y_supa, add_out_supa = torch_br.supa_add_rmsnorm_infer(  # type: ignore
            x, residual, self.weight.data, self.variance_epsilon)
        return y_supa, add_out_supa
    else:
        # Kernel expects a 3-D input: promote 2-D, demote 4-D
        # (presumably a leading batch dim of 1 — TODO confirm).
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
        if len(x.shape) == 4:
            x = x.squeeze(0)

        x = torch_br.supa_rmsnorm_infer(
            x,
            self.weight.data,
            self.variance_epsilon  # type: ignore
        )
        return x
|
||||
|
||||
|
||||
@patch_to(RMSNorm)
def enabled(cls) -> bool:
    """Report that the out-of-tree (SUPA) RMSNorm implementation is active.

    vLLM consults this classmethod-style hook to decide whether to route
    through ``forward_oot``; always True for this plugin.
    """
    return True
|
||||
|
||||
|
||||
@patch_to(nn.LayerNorm)
def forward(self, input: Tensor) -> Tensor:
    """LayerNorm forward with an optional Biren fused-kernel fast path.

    The USE_BR_FUSED_LAYERNORM environment variable gates the fused kernel:
    any value other than 'false'/'0'/'' (case-insensitive) enables it;
    otherwise the stock PyTorch implementation is used.
    """
    flag = os.environ.get("USE_BR_FUSED_LAYERNORM", 'False').lower()
    if flag in {'false', '0', ''}:
        # Default: reference PyTorch layer_norm.
        return nn.functional.layer_norm(input, self.normalized_shape,
                                        self.weight, self.bias, self.eps)
    # Opt-in: Biren fused layernorm kernel.
    return torch_br.fused_layernorm(input, self.weight, self.bias, self.eps)
|
||||
767
vllm_br/model_executor/layers/linear.py
Normal file
767
vllm_br/model_executor/layers/linear.py
Normal file
@@ -0,0 +1,767 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_br
|
||||
import torch_br.supa._debug as supa_debug
|
||||
from fastcore.basics import patch_to
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import (adjust_bitsandbytes_4bit_shard,
|
||||
adjust_marlin_shard,
|
||||
adjust_scalar_to_fused_array)
|
||||
from vllm_br import envs
|
||||
from vllm_br.utils import get_grandparent_pid
|
||||
from .br_utils import (_convert_to_crossed_numa_tensor,
|
||||
_convert_to_numa_tensor, _convert_to_numa_tensor_vit,
|
||||
is_br166_device)
|
||||
|
||||
from vllm.model_executor.layers.linear import ( # isort:skip
|
||||
LinearBase, MergedColumnParallelLinear, QuantizationConfig,
|
||||
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod,
|
||||
QKVParallelLinear)
|
||||
|
||||
|
||||
def _should_skip_linear_post_process(layer, use_ds_mla, use_ds_mla_sparse):
|
||||
"""NOTE: SUPA: for MLA linears, we do process in MLA.process_weights_after_loading """
|
||||
# TODO: Hard code for native dsa op
|
||||
if use_ds_mla_sparse:
|
||||
MLA_LINEAR_NAMES = [
|
||||
"kv_b_proj",
|
||||
]
|
||||
else:
|
||||
MLA_LINEAR_NAMES = [
|
||||
"q_a_proj",
|
||||
"q_b_proj",
|
||||
# "q_proj",
|
||||
"kv_a_proj_with_mqa",
|
||||
"kv_b_proj",
|
||||
# "o_proj",
|
||||
]
|
||||
if use_ds_mla and not use_ds_mla_sparse:
|
||||
MLA_LINEAR_NAMES.append("o_proj")
|
||||
|
||||
skip = any(k in layer.prefix for k in MLA_LINEAR_NAMES)
|
||||
if skip:
|
||||
logger.debug(
|
||||
f'[SUPA] skip {layer.prefix} UnquantizedLinearMethod.process_weights_after_loading' # noqa: G004
|
||||
)
|
||||
return skip
|
||||
|
||||
|
||||
# NOTE: ReplicatedLinear, usually used in MoE as a gate module.
# In DeepseekV3, it needs to be transposed.
def process_weights_ReplicatedLinear(
        layer: ReplicatedLinear) -> Literal[True, False]:
    """Transpose the replicated (gate) weight in place for the SUPA kernels.

    Always returns True to signal that the weight was processed.
    """
    transposed = layer.weight.data.transpose(1, 0)
    layer.weight.data = transposed.contiguous()
    return True
|
||||
|
||||
|
||||
def process_share_expert_weight(layer: MergedColumnParallelLinear):
    """Repack a shared-expert gate_up weight into the SUPA SPC-interleaved layout.

    The gate and up halves are padded to an alignment boundary, interleaved in
    (gate, up) 32-column blocks, partitioned across the SPCs reserved for the
    shared expert, and padded with zero "invalid" regions for the SPCs that
    run the routed experts. On BR166-class devices (spc count > 16, two dies)
    each half of the weight is packed per die and the results concatenated.
    The result replaces ``layer.weight.data`` as a NUMA col-major tensor.
    """
    # Weight arrives as (2 * intermediate, hidden); kernels want (hidden, N).
    gate_up_weight = layer.weight.transpose(1, 0).contiguous()

    gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1)

    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # NOTE(review): >16 SPCs is treated as the dual-die BR166 part — confirm.
    is_br166 = die_spc_num > 16
    spc_num = die_spc_num // 2 if is_br166 else die_spc_num

    if is_br166:
        # 2&2 for 4spc, 4&8 for 12spc, 8&8 for 16spc
        spc_for_shared = 2 if spc_num == 4 else 8
        spc_for_router = spc_num - spc_for_shared

        align_size = 32
        weight_dtype = gate_weight.dtype
        hidden_size = gate_weight.shape[0]

        # Split each half again across the two dies (d0 / d1).
        gate_d0, gate_d1 = torch.chunk(gate_weight, 2, dim=-1)
        up_d0, up_d1 = torch.chunk(up_weight, 2, dim=-1)
        im_size = gate_d0.shape[-1]
        # Round the interleaved (gate+up) width up to a multiple of
        # (2 * align_size * spc_for_shared) so every SPC region is whole.
        n_align_size = (align_size * 2) * spc_for_shared
        swiglu_w_aligned = ((
            (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size
        region_size = swiglu_w_aligned // spc_for_shared
        block_nums = (region_size // (align_size * 2)) * spc_for_shared

        # Zero-pad each quarter out to the aligned width.
        gate_d0_align = torch.nn.functional.pad(
            gate_d0, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        gate_d1_align = torch.nn.functional.pad(
            gate_d1, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        up_d0_align = torch.nn.functional.pad(
            up_d0, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        up_d1_align = torch.nn.functional.pad(
            up_d1, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        # View each padded quarter as align_size-wide blocks for interleaving.
        gate_weight_d0_reshape = gate_d0_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        gate_weight_d1_reshape = gate_d1_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        up_weight_d0_reshape = up_d0_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        up_weight_d1_reshape = up_d1_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)

        # Die 0: interleave (gate, up) per block, then partition per SPC.
        gate_up_weight_d0 = torch.zeros(
            [hidden_size, block_nums, align_size * 2],
            dtype=weight_dtype,
            device='supa')

        gate_up_weight_d0[:, :, 0:0 +
                          align_size] = gate_weight_d0_reshape[:, :,
                                                               0:align_size]

        gate_up_weight_d0[:, :, align_size:align_size *
                          2] = up_weight_d0_reshape[:, :, 0:align_size]
        gate_up_weight_d0 = gate_up_weight_d0.reshape(
            hidden_size, spc_for_shared,
            region_size).permute(1, 0, 2).contiguous().to(weight_dtype)

        gate_up_d0_invalid = torch.zeros(
            [spc_for_router, hidden_size, region_size],
            dtype=weight_dtype,
            device='supa')  # invalid regions
        gate_up_weight_d0_whole = torch.cat(
            [gate_up_weight_d0, gate_up_d0_invalid], dim=0)

        # Die 1: same interleave/partition as die 0.
        gate_up_weight_d1 = torch.zeros(
            [hidden_size, block_nums, align_size * 2],
            dtype=weight_dtype,
            device='supa')

        gate_up_weight_d1[:, :, 0:0 +
                          align_size] = gate_weight_d1_reshape[:, :,
                                                               0:align_size]

        gate_up_weight_d1[:, :, align_size:align_size *
                          2] = up_weight_d1_reshape[:, :, 0:align_size]
        gate_up_weight_d1 = gate_up_weight_d1.reshape(
            hidden_size, spc_for_shared,
            region_size).permute(1, 0, 2).contiguous().to(weight_dtype)

        gate_up_d1_invalid = torch.zeros(
            [spc_for_router, hidden_size, region_size],
            dtype=weight_dtype,
            device='supa')  # invalid regions
        gate_up_weight_d1_whole = torch.cat(
            [gate_up_weight_d1, gate_up_d1_invalid], dim=0)

        # Concatenate per-die layouts and copy into a NUMA col-major tensor.
        gate_up_weight_whole = torch.cat(
            [gate_up_weight_d0_whole, gate_up_weight_d1_whole], dim=0)
        gate_up_weight_supa = torch_br._empty_ut_only(
            size=gate_up_weight_whole.shape,
            dtype=gate_weight.dtype,
            is_numa=True,
            device="supa",
            tensor_type="colmajor")
        gate_up_weight_supa.copy_(gate_up_weight_whole)

        layer.weight.data = gate_up_weight_supa
    else:
        # Single-die variant of the same packing (no d0/d1 split).
        # 2&2 for 4spc, 4&8 for 12spc, 8&8 for 16spc
        spc_for_shared = 2 if spc_num == 4 else 8
        spc_for_router = spc_num - spc_for_shared

        align_size = 32
        weight_dtype = gate_weight.dtype
        hidden_size = gate_weight.shape[0]
        im_size = gate_weight.shape[-1]
        n_align_size = (align_size * 2) * spc_for_shared
        swiglu_w_aligned = ((
            (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size
        region_size = swiglu_w_aligned // spc_for_shared
        block_nums = (region_size // (align_size * 2)) * spc_for_shared

        gate_golden_align = torch.nn.functional.pad(
            gate_weight, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        up_golden_align = torch.nn.functional.pad(
            up_weight, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        gate_weight_golden_reshape = gate_golden_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        up_weight_golden_reshape = up_golden_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)

        # Interleave (gate, up) per 32-column block.
        gate_up_weight_golden = torch.zeros(
            [hidden_size, block_nums, align_size * 2],
            dtype=weight_dtype,
            device='supa')

        gate_up_weight_golden[:, :, 0:0 +
                              align_size] = gate_weight_golden_reshape[:, :, 0:
                                                                       align_size]

        gate_up_weight_golden[:, :, align_size:align_size *
                              2] = up_weight_golden_reshape[:, :, 0:align_size]
        gate_up_weight_golden = gate_up_weight_golden.reshape(
            hidden_size, spc_for_shared,
            region_size).permute(1, 0, 2).contiguous().to(weight_dtype)

        gate_up_invalid = torch.zeros(
            [spc_for_router, hidden_size, region_size],
            dtype=weight_dtype,
            device='supa')  # invalid regions
        gate_up_weight_whole = torch.cat(
            [gate_up_weight_golden, gate_up_invalid], dim=0)

        gate_up_weight_supa = torch_br._empty_ut_only(
            size=gate_up_weight_whole.shape,
            dtype=gate_weight.dtype,
            is_numa=True,
            device="supa",
            tensor_type="colmajor")
        gate_up_weight_supa.copy_(gate_up_weight_whole)

        layer.weight.data = gate_up_weight_supa
|
||||
|
||||
|
||||
# NOTE: MergedColumnParallelLinear, usually used in MergedGateUpMLPSiluL2
def process_weights_QuantMergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Post-load repack for a quantized merged gate_up linear.

    For normal (non-shared-expert) layers the quantized weight, its scales,
    and bias are each split back into gate/up halves and converted to the
    SUPA crossed-NUMA layout; shared-expert layers are delegated to
    ``process_share_expert_weight``.
    """
    if 'shared_experts' not in layer.prefix:
        #NOTE: normal MLP gate_up, after load weight, convert to supa numa tensor
        if hasattr(layer, "qweight"):
            # Quantized path: interleave gate/up halves of the packed weight.
            gate_weight, up_weight = torch.chunk(layer.qweight, 2, dim=-1)
            gate_up_weight_numa = _convert_to_crossed_numa_tensor(
                gate_weight,
                up_weight,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=True,
                do_transpose=False)
            layer.qweight.data = gate_up_weight_numa
        else:
            # Unquantized fallback: plain col-major NUMA conversion.
            gate_up_weight = layer.weight.permute(1, 0).contiguous()
            gate_up_weight_numa = _convert_to_numa_tensor(
                gate_up_weight,
                32,
                "colmajor",
                gate_up_weight.dtype,
                False,
                parallel_type="col_parallel")
            layer.weight.data = gate_up_weight_numa

        # Quantization scales follow the same gate/up interleave, but laid
        # out like a linear bias (1-D per-channel).
        if hasattr(layer, "scales") and layer.scales is not None:
            gate_scales, up_scales = torch.chunk(layer.scales, 2, dim=-1)
            gate_up_scales_internleaved_numa = _convert_to_crossed_numa_tensor(
                gate_scales,
                up_scales,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=False,
                layout="linear_bias",
                do_transpose=False)
            layer.scales.data = gate_up_scales_internleaved_numa

        if hasattr(layer, "bias") and layer.bias is not None:
            gate_bias, up_bias = torch.chunk(layer.bias, 2, dim=-1)
            gate_up_bias_internleaved_numa = _convert_to_crossed_numa_tensor(
                gate_bias,
                up_bias,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=False,
                layout="linear_bias",
                do_transpose=False)
            layer.bias.data = gate_up_bias_internleaved_numa
    else:
        process_share_expert_weight(layer)
|
||||
|
||||
|
||||
def process_weights_MergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Post-load repack for an unquantized merged gate_up linear.

    Normal layers get their gate/up halves crossed into the SPC-interleaved
    NUMA layout (or a plain col-major NUMA tensor when ``no_need_cross`` is
    set); shared-expert layers go through ``process_share_expert_weight``.
    """
    if 'shared_experts' not in layer.prefix:
        # Kernels expect (hidden, 2*intermediate); loader gives the transpose.
        gate_up_weight = layer.weight.permute(1, 0).contiguous()
        if not (hasattr(layer, "no_need_cross") and layer.no_need_cross):
            gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1)
            gate_up_weight_internleaved_numa = _convert_to_crossed_numa_tensor(
                gate_weight,
                up_weight,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=True,
                do_transpose=False)
            layer.weight.data = gate_up_weight_internleaved_numa
        else:
            # Layer opted out of gate/up crossing; straight NUMA conversion.
            gate_up_weight_numa = _convert_to_numa_tensor(
                gate_up_weight,
                align_size=32,
                layout="colmajor",
                dtype=gate_up_weight.dtype,
                do_transpose=False)
            layer.weight.data = gate_up_weight_numa

        if hasattr(layer, "bias") and layer.bias is not None:
            gate_bias, up_bias = torch.chunk(layer.bias, 2, dim=-1)
            gate_up_bias_internleaved_numa = _convert_to_crossed_numa_tensor(
                gate_bias,
                up_bias,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=False,
                layout="linear_bias",
                do_transpose=False)
            layer.bias.data = gate_up_bias_internleaved_numa

    else:
        #NOTE: by default, gate module and shared_expert(1) module will be involved into calculation in 1 kernel
        process_share_expert_weight(layer)
|
||||
|
||||
|
||||
@patch_to(UnquantizedLinearMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Convert a linear layer's weights to the SUPA NUMA layout after load.

    Dispatches on the concrete linear subclass: replicated gates get a
    transpose, merged gate_up linears get the crossed-NUMA repack, and row-
    parallel linears are tagged for the row-parallel NUMA conversion. MLA
    linears (handled by MLA.process_weights_after_loading) and non-NUMA
    weight types are skipped entirely.
    """
    if _should_skip_linear_post_process(
            layer, self.use_ds_mla,
            self.use_ds_mla_sparse) or self.weight_type != "NUMA":
        return
    still_need_process = True
    do_transpose = True
    parallel_type = "col_parallel"
    # NOTE: all process_weights func should done before process_weights_after_loading
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # Small router/gate projections (64/128/160/256 outputs, e.g.
            # Glm4-Moe) are already fully handled by the transpose above;
            # indexer layers still need the NUMA conversion.
            still_need_process = not ("indexer" not in layer.prefix and (
                layer.output_size == 64 or layer.output_size == 160  # Glm4-Moe
                or layer.output_size == 128 or layer.output_size == 256))
            do_transpose = False
        case MergedColumnParallelLinear():
            process_weights_MergedColumnParallelLinear(layer)
            still_need_process = False
            do_transpose = False
        case RowParallelLinear():
            parallel_type = "row_parallel"
        case _:
            pass

    if not still_need_process or self.weight_type != "NUMA":
        return

    # process numa weight and bias
    if hasattr(layer, "weight") and len(layer.weight.shape) == 2:
        # Vision-tower weights on BR166 use a dedicated ViT conversion.
        if 'vision' in layer.prefix and is_br166_device():
            layer.weight.data = _convert_to_numa_tensor_vit(
                layer.weight,
                envs.VLLM_BR_DEVICE_WARP_SIZE,
                "colmajor",
                torch.bfloat16,
                do_transpose=do_transpose,
                wk=(layer.weight.data.shape[1]
                    if do_transpose else layer.weight.data.shape[0]),
                wn=(layer.weight.data.shape[0]
                    if do_transpose else layer.weight.data.shape[1]),
                parallel_type=parallel_type)  # noqa: SIM210
        else:
            layer.weight.data = _convert_to_numa_tensor(
                layer.weight,
                envs.VLLM_BR_DEVICE_WARP_SIZE,
                "colmajor",
                torch.bfloat16,
                do_transpose=do_transpose,
                wk=(layer.weight.data.shape[1]
                    if do_transpose else layer.weight.data.shape[0]),
                wn=(layer.weight.data.shape[0]
                    if do_transpose else layer.weight.data.shape[1]),
                parallel_type=parallel_type)  # noqa: SIM210

    if hasattr(layer, "bias") and layer.bias is not None:
        # Row-parallel biases are zero-padded on non-zero ranks so the
        # all-reduce adds the bias exactly once.
        pad_zeros = (parallel_type == "row_parallel")
        if 'vision' in layer.prefix and is_br166_device():
            # Reduced row-parallel vision biases are added post-all-reduce;
            # leave them untouched here.
            if (pad_zeros and layer.reduce_results):
                return
            layer.bias.data = _convert_to_numa_tensor_vit(
                layer.bias,
                envs.VLLM_BR_DEVICE_WARP_SIZE,
                "linear_bias",
                torch.float32,
                parallel_type=parallel_type,
                pad_zeros=pad_zeros)
        else:
            layer.bias.data = _convert_to_numa_tensor(
                layer.bias,
                envs.VLLM_BR_DEVICE_WARP_SIZE,
                "linear_bias",
                torch.float32,
                parallel_type=parallel_type,
                pad_zeros=pad_zeros)
|
||||
|
||||
|
||||
@patch_to(UnquantizedLinearMethod)
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Linear forward dispatching to SUPA fused kernels.

    A 3-D weight marks a NUMA-converted layer and routes to the Biren fused
    MLP / matmul kernels (with optional swiglu activation for crossed
    gate_up weights, and an optional fused linear+allreduce for row-parallel
    decode). 2-D weights fall back to F.linear with the sublas API enabled.

    NOTE(review): this method mutates ``layer.reduce_results`` for
    row-parallel layers as a side effect — callers (RowParallelLinear.forward)
    rely on the post-call value.
    """
    # numa weight is 3-dims
    if 'vision' in layer.prefix and is_br166_device():
        if len(layer.weight.shape) == 3:
            is_row = isinstance(layer, RowParallelLinear)
            output_size = (layer.output_size_per_partition if hasattr(
                layer, "output_size_per_partition") else layer.output_size)
            act_mode = "act_default"
            # Crossed gate_up weights fuse the swiglu activation; the packed
            # width is twice the logical output width.
            if isinstance(layer, MergedColumnParallelLinear) and not (hasattr(
                    layer, "no_need_cross") and layer.no_need_cross):
                act_mode = "act_swiglu"
                output_size //= 2
            if bias is None or (is_row and layer.reduce_results):
                # return torch_br.br_matmul_infer(
                #     x,
                #     layer.weight,
                #     bias=None,
                #     output_w=output_size,
                # )
                return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                                   output_w=output_size,
                                                   activation_mode=act_mode)
            else:
                return torch_br.br_matmul_infer(x, layer.weight, bias,
                                                output_size)
        # 2-D weight: reference path through F.linear with sublas enabled.
        supa_debug.set_enable_sublas_api(True)
        output = F.linear(x, layer.weight, bias)
        supa_debug.set_enable_sublas_api(False)
        return output
    if len(layer.weight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear) and not (hasattr(
                layer, "no_need_cross") and layer.no_need_cross):
            act_mode = "act_swiglu"
            output_size //= 2

        # The fused MLP kernel takes biases as a list.
        bias = [bias] if bias is not None else None
        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # TODO(CaoJun): This is WA, delete (16, 8) so that the test_vllm_model_accu_qwen25_72b_instruct can run through
            support_types = ((16, 4), (32, 2), (32, 4))
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pp_group().world_size
            all_rank = tp_size * pp_size
            # Decide per-call whether the fused linear+allreduce kernel can
            # replace the separate all-reduce (short decode sequences on
            # supported (spc, tp) combinations only).
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if layer.reduce_results:
                return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                                   output_w=output_size,
                                                   bias=bias,
                                                   activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x, layer.weight, output_size, tp_rank, tp_size,
                    global_rank, 0)
        else:
            return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                               output_w=output_size,
                                               bias=bias,
                                               activation_mode=act_mode)
    # 2-D weight fallback (weights never NUMA-converted).
    supa_debug.set_enable_sublas_api(True)
    output = F.linear(x, layer.weight, bias)
    supa_debug.set_enable_sublas_api(False)
    return output
|
||||
|
||||
|
||||
@patch_to(LinearBase)
def __init__(
    self,
    input_size: int,
    output_size: int,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
):
    """LinearBase constructor override.

    Mirrors upstream vLLM's LinearBase.__init__ so the SUPA patches have a
    stable set of attributes (prefix, tp_rank/tp_size, quant_method) to
    rely on, regardless of the upstream version in use.
    """
    super(LinearBase, self).__init__()

    # Layer geometry and bias handling.
    self.input_size = input_size
    self.output_size = output_size
    self.skip_bias_add = skip_bias_add
    self.params_dtype = (torch.get_default_dtype()
                         if params_dtype is None else params_dtype)

    # Quantization: fall back to the unquantized method when no config.
    if quant_config is not None:
        self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
    else:
        self.quant_method = UnquantizedLinearMethod()

    self.return_bias = return_bias
    self.prefix = prefix

    # Tensor-parallel coordinates (neutral values when TP is disabled).
    self.tp_rank = 0 if disable_tp else get_tensor_model_parallel_rank()
    self.tp_size = 1 if disable_tp else get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
@patch_to(RowParallelLinear)
def forward(
    self, input_
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
    """RowParallelLinear forward with an optional PCIe/CPU all-reduce path.

    Beyond the upstream flow (split input if needed, GEMM, TP all-reduce),
    small decode outputs (inner dim <= 32) on tp_size >= 4 can use the
    SUPA PCIe all-reduce kernel when VLLM_BR_USE_CPU_ALL_REDUCE is set.
    """
    # The PCIe all-reduce kernel keys shared memory off the grandparent
    # process id; cache it lazily on first use.
    if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
            self, "grandparent_pid"):
        self.grandparent_pid = get_grandparent_pid()
    if self.input_is_parallel:
        input_parallel = input_
    else:
        tp_rank = get_tensor_model_parallel_rank()
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size)
        input_parallel = splitted_input[tp_rank].contiguous()

    # Matrix multiply.
    assert self.quant_method is not None
    # Only fuse bias add into GEMM for rank 0 (this ensures that
    # bias will not get added more than once in TP>1 case)
    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
    # NOTE: quant_method.apply may rewrite self.reduce_results (see the
    # patched UnquantizedLinearMethod.apply) — it must run before the
    # reduce_results check below.
    output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
    if self.reduce_results and self.tp_size > 1:
        # CPU all reduce will be applied.
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and self.tp_size >= 4 and output_parallel.shape[
                1] <= 32:
            tp_rank = get_tensor_model_parallel_rank()
            output = torch_br.supa_allreduce_pcie_infer(
                output_parallel, tp_rank, self.tp_size, self.grandparent_pid)
        else:
            output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel

    output_bias = self.bias if self.skip_bias_add else None

    if not self.return_bias:
        return output
    return output, output_bias
|
||||
|
||||
|
||||
@patch_to(QKVParallelLinear)
def weight_loader(self,
                  param: Parameter,
                  loaded_weight: torch.Tensor,
                  loaded_shard_id: Optional[str] = None):
    """Load a q/k/v shard into the fused QKV parameter.

    Follows the upstream vLLM loader (GGUF, AQLM metadata, bitsandbytes,
    Marlin and per-tensor-scale special cases), with one Biren addition:
    on dual-die devices (VLLM_BR_DEVICE_SPC_NUM > 16) each shard is split
    in half and written into the two per-die halves of the parameter.
    """

    # Special case for GGUF
    # initialize GGUF param after we know the quantize type
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        idx_map = {"q": 0, "k": 1, "v": 2}
        if loaded_shard_id is not None:
            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
        else:
            param.shard_weight_type = {
                k: loaded_weight.item()
                for k in idx_map
            }
        return

    if is_gguf_weight:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        output_dim = getattr(param, "output_dim", None)
        shard_size = loaded_weight.size(output_dim) // tp_size
        start_idx = tp_rank * shard_size

        if loaded_shard_id is not None:
            # Stash the TP slice in the side container; materialized later.
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
            param.shard_id.append(loaded_shard_id)
            param.shard_id_map[loaded_shard_id] = len(param.data_container)
            param.data_container.append(loaded_weight)
            return

    param_data = param.data
    output_dim = getattr(param, "output_dim", None)
    # Special case for AQLM codebooks.
    is_metadata = getattr(param, "is_metadata", False)

    # Special case for per-tensor scales in fused case.
    needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv).
        # (e.g., Phi-3's qkv_proj).
        if output_dim is None:
            if needs_scalar_to_array:
                param_data, loaded_weight = adjust_scalar_to_fused_array(
                    param_data, loaded_weight, 0)

            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
        # Split the fused tensor into per-shard slices and recurse once per
        # shard so each goes through the sharded path below.
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        packed_dim = getattr(param, "packed_dim", None)
        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.total_num_heads * self.head_size),
                    "k": (self.total_num_heads * self.head_size,
                          self.total_num_kv_heads * self.head_size),
                    "v":
                    ((self.total_num_heads + self.total_num_kv_heads) *
                     self.head_size, self.total_num_kv_heads * self.head_size),
                    "total":
                    ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                     self.head_size, 0)
                }

                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, shard_id)

            loaded_weight_shard = loaded_weight.narrow(output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader(param, loaded_weight_shard, shard_id)
        return

    tp_rank = get_tensor_model_parallel_rank()
    assert loaded_shard_id in ["q", "k", "v"]

    # If output dim is defined, use the default loading process.
    if output_dim is not None:
        # Destination offset/size of this shard within the fused param.
        if loaded_shard_id == "q":
            shard_offset = 0
            shard_size = self.num_heads * self.head_size
        elif loaded_shard_id == "k":
            shard_offset = self.num_heads * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        elif loaded_shard_id == "v":
            shard_offset = (self.num_heads +
                            self.num_kv_heads) * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        # Special case for Quantized Weights.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        packed_dim = getattr(param, "packed_dim", None)
        if packed_dim == output_dim:
            shard_size = shard_size // param.pack_factor
            shard_offset = shard_offset // param.pack_factor

            # Special case for Marlin.
            shard_size, shard_offset = adjust_marlin_shard(
                param, shard_size, shard_offset)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        if use_bitsandbytes_4bit:
            orig_qkv_offsets = {
                "q": (0, self.num_heads * self.head_size),
                "k": (self.num_heads * self.head_size,
                      self.num_kv_heads * self.head_size),
                "v": ((self.num_heads + self.num_kv_heads) * self.head_size,
                      self.num_kv_heads * self.head_size),
                "total":
                ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 0)
            }
            shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                param, orig_qkv_offsets, loaded_shard_id)

        # Biren dual-die layout: the fused param holds two per-die halves,
        # so each shard is split in two and written into both halves.
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            half_w = param_data.shape[output_dim] // 2
            param_data = (param_data.narrow(output_dim, shard_offset // 2,
                                            shard_size // 2),
                          param_data.narrow(output_dim,
                                            shard_offset // 2 + half_w,
                                            shard_size // 2))
        else:
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)

        # Source slice of the checkpoint tensor for this TP rank (k/v heads
        # may be replicated across ranks).
        if loaded_shard_id == "q":
            shard_id = tp_rank
        else:
            shard_id = tp_rank // self.num_kv_head_replicas
        start_idx = shard_id * shard_size

        if not is_sharded_weight:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

    # Special case for for AQLM codebooks.
    elif is_metadata:
        # metadata indicates fixed size concatenated along dim 0
        shard_size = loaded_weight.shape[0]
        shard_index = ["q", "k", "v"].index(loaded_shard_id)
        param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
    # Special case for per-tensor scales in fused case.
    elif needs_scalar_to_array:
        param_data, loaded_weight = adjust_scalar_to_fused_array(
            param_data, loaded_weight, loaded_shard_id)
    else:
        ignore_warning = getattr(param, "ignore_warning", False)
        if not ignore_warning:
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions.")

    # Tuple destination means the dual-die split above: copy each half.
    if isinstance(param_data, tuple):
        half_w = loaded_weight.shape[output_dim] // 2
        param_data[0].copy_(loaded_weight.narrow(output_dim, 0, half_w))
        param_data[1].copy_(loaded_weight.narrow(output_dim, half_w, half_w))
    else:
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
|
||||
72
vllm_br/model_executor/layers/logits_processor.py
Normal file
72
vllm_br/model_executor/layers/logits_processor.py
Normal file
@@ -0,0 +1,72 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
# TODO(shouqing): need to open this patch when fix hang in mtp
@patch_to(LogitsProcessor)
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Compute full-vocabulary logits via zero-padded all-reduce.

    [PatchNote] Replaces the upstream gather-based logits path: each TP
    rank writes its vocab shard into a zero tensor spanning the full
    (padded) vocab width, then the shards are combined with a single
    all-reduce.

    Args:
        hidden_states: activations for the positions being sampled.
        lm_head: vocab-parallel projection producing this rank's shard.
        embedding_bias: optional bias forwarded to the quant method.

    Returns:
        Logits truncated to ``self.org_vocab_size`` (vocab padding removed).
    """
    # Get the logits for the next tokens (this rank's vocab shard only).
    logits = lm_head.quant_method.apply(lm_head,
                                        hidden_states,
                                        bias=embedding_bias)

    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    if spc_num > 16:
        # spc_num > 16 presumably indicates a dual-die device — TODO confirm.
        bb_input = torch_br._empty_ut_only(size=logits.shape,
                                           dtype=logits.dtype,
                                           is_numa=False,
                                           device=logits.device,
                                           tensor_type="colmajor")

        # work around the hang in s1b copy to bb
        bb_input.copy_(logits)
        logits = bb_input

    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # Zero buffer covering all shards; only this rank's slice is filled,
    # so the all-reduce below effectively concatenates the shards.
    logits_ = torch.zeros((logits.shape[0], logits.shape[-1] * tp_size),
                          dtype=logits.dtype,
                          device=logits.device)

    start = logits.shape[-1] * tp_rank
    end = start + logits.shape[-1]
    logits_[:, start:end].copy_(logits)
    logits = tensor_model_parallel_all_reduce(logits_)

    # Remove paddings in vocab (if any).
    if logits is not None:
        logits = logits[..., :self.org_vocab_size]
    return logits
|
||||
19
vllm_br/model_executor/layers/quantization/__init__.py
Normal file
19
vllm_br/model_executor/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# Import submodules for their side effects (they patch vLLM classes on
# import) and re-export them. Quote style made consistent (double quotes,
# matching the import order) — the original mixed "..." and '...'.
from . import compressed_tensors, gptq

__all__ = ["compressed_tensors", "gptq"]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,18 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from .compressed_tensors import *
|
||||
from .compressed_tensors_moe import *
|
||||
from .compressed_tensors_wNa16 import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,64 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, cast
|
||||
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||
CompressedTensorsConfig)
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsConfig, cls_method=True)
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
    """
    [PatchNote] add qkv_quantized param support

    Build a ``CompressedTensorsConfig`` from a raw quantization config
    dict, additionally honouring an optional ``qkv_quantized`` key
    (defaults to ``True`` when absent).
    """
    scheme_map = cls._quantization_scheme_map_from_config(config=config)
    sparsity_map, sparsity_ignore = cls._parse_sparsity_config(config=config)

    return cls(
        target_scheme_map=scheme_map,
        ignore=cast(list[str], config.get("ignore", [])),
        quant_format=cast(str, config.get("format")),
        sparsity_scheme_map=sparsity_map,
        sparsity_ignore_list=sparsity_ignore,
        config=config,
        transform_config=config.get("transform_config"),
        # Extra knob consumed by the wrapped __init__ (see module below).
        qkv_quantized=cls.get_from_keys_or(config, ["qkv_quantized"],
                                           default=True),
    )
|
||||
|
||||
|
||||
def wrapper_CompressedTensorsConfig_init(fn):
    """Wrap ``CompressedTensorsConfig.__init__`` to accept an extra
    ``qkv_quantized`` keyword (default ``True``) and store it on the
    instance once the wrapped ``__init__`` has completed."""

    @wraps(fn)
    def wrapped_init(self, *args, **kwargs):
        # Strip the extra kwarg first so the original __init__ never
        # sees an unexpected keyword argument.
        enable_qkv = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = enable_qkv

    return wrapped_init
|
||||
|
||||
|
||||
# Apply the wrapper at import time so every construction of
# CompressedTensorsConfig accepts the extra `qkv_quantized` keyword.
CompressedTensorsConfig.__init__ = wrapper_CompressedTensorsConfig_init(
    CompressedTensorsConfig.__init__)  # noqa: E501
|
||||
@@ -0,0 +1,594 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from compressed_tensors.compressors.quantized_compressors import (
|
||||
unpack_from_int32)
|
||||
from fastcore.basics import patch_to
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
WNA16_SUPPORTED_BITS, CompressedTensorsMoEMethod,
|
||||
CompressedTensorsWNA16MoEMethod, CompressionFormat)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm_br import envs
|
||||
from ...br_utils import (_convert_to_crossed_numa_tensor,
|
||||
_convert_to_numa_tensor, align_n, cross_weight_32)
|
||||
from ...fused_moe.supa_moe import fused_moe_quant_device, fused_moe_quant_dyn
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsMoEMethod)
def get_moe_method(
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        layer: torch.nn.Module,
) -> "CompressedTensorsMoEMethod":
    """NOTE:
    1. SUPA only supports CompressedTensorsWNA16MoEMethod without Marlin
    2. Only Linear targets are supported for MoE layers

    Select the MoE quantization method for *layer*. Only the WNA16
    group/channel scheme is accepted; any other scheme raises RuntimeError.
    """
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    keys = list(quant_config.target_scheme_map.keys())
    assert len(keys) > 0, ("No valid quant key!!!")
    # assert "Linear" in quant_config.target_scheme_map
    # [Patch]: Only Linear target is supported for MoE layers; for temporary
    # compatibility, re-key the first entry of target_scheme_map as "Linear".
    # NOTE(review): this mutates the shared quant_config in place — any
    # other key at position 0 is renamed, and subsequent callers observe it.
    quant_config.target_scheme_map[
        "Linear"] = quant_config.target_scheme_map.pop(keys[0])
    target_key = "Linear"
    # target_key = keys[0] # normal only one key
    weight_quant = quant_config.target_scheme_map[target_key].get("weights")
    input_quant = quant_config.target_scheme_map[target_key].get(
        "input_activations")

    if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
        logger.info_once("Using CompressedTensorsWNA16MoEMethod")
        return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
    else:
        raise RuntimeError(
            f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        moe: FusedMoEConfig,
):
    """Initialize the SUPA WNA16 MoE method from the parsed quant config.

    [PatchNote] Unlike upstream, both "group" and "channel" strategies are
    accepted (the upstream ``strategy == "group"`` assert is disabled) and
    grouped act-order is not rejected here.

    Args:
        quant_config: parsed compressed-tensors config; the "Linear"
            target's weight scheme supplies bits/strategy/group size.
        moe: fused-MoE layer configuration forwarded to the base class.

    Raises:
        AssertionError: if the weight scheme is not symmetric.
        ValueError: if the format is not pack-quantized or the bit width
            is not in ``WNA16_SUPPORTED_BITS``.
    """
    super(CompressedTensorsWNA16MoEMethod, self).__init__(moe)
    self.quant_config = quant_config
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    config = self.quant_config.target_scheme_map["Linear"].get("weights")
    self.num_bits = config.num_bits
    # Number of quantized values packed into one int32 word.
    self.packed_factor = 32 // config.num_bits
    self.strategy = config.strategy
    # channelwise is not supported by this kernel
    # [Patch]: SUPA use CompressedTensorsWNA16MoEMethod for both
    # channel/group strategies
    # assert config.strategy == "group"
    self.group_size = config.group_size
    # grouped actorder isn't supported by this kernel
    # assert config.actorder != "group"
    assert config.symmetric, (
        "Only symmetric quantization is supported for MoE")

    if not (self.quant_config.quant_format
            == CompressionFormat.pack_quantized.value
            and self.num_bits in WNA16_SUPPORTED_BITS):
        # Fixed: the original passed four comma-separated arguments to
        # ValueError, which renders as a tuple of fragments instead of a
        # single readable message. Concatenate into one string.
        raise ValueError(
            "For Fused MoE layers, only "
            f"{CompressionFormat.pack_quantized.value} "
            "is supported for the following bits: "
            f"{WNA16_SUPPORTED_BITS}")
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register packed weights, scales, shapes and g_idx buffers on *layer*.

    All parameters are allocated on CPU; `process_weights_after_loading`
    later converts them to device-resident layouts. Packed weights are
    int32 (``packed_factor`` values per word); scale group counts depend
    on ``self.strategy`` ("channel" forces a single group and sets
    ``self.group_size = -1``).
    """
    # Will transpose the loaded weight along the
    # intermediate and hidden dim sizes. Will
    # shard for TP along the transposed dims
    extra_weight_attrs.update({
        "is_transposed": True,
        "quant_method": self.strategy
    })
    # Fused gate+up projection, packed along the hidden (input) dim.
    w13_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        hidden_size // self.packed_factor,
        2 * intermediate_size_per_partition,
        dtype=torch.int32,
        device="cpu",
    ),
                                    requires_grad=False)
    layer.register_parameter("w13_weight_packed", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    # Down projection, packed along the intermediate (input) dim.
    w2_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        intermediate_size_per_partition // self.packed_factor,
        hidden_size,
        dtype=torch.int32,
        device="cpu"),
                                   requires_grad=False)
    layer.register_parameter("w2_weight_packed", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    w2_scales_size = intermediate_size_per_partition

    if self.strategy == "channel":
        # One scale per output channel: a single group, group_size -1.
        num_groups_w2 = num_groups_w13 = 1
        self.group_size = -1
    else:
        num_groups_w2 = w2_scales_size // self.group_size
        num_groups_w13 = hidden_size // self.group_size

    w13_scale = torch.nn.Parameter(torch.ones(
        num_experts,
        num_groups_w13,
        2 * intermediate_size_per_partition,
        dtype=params_dtype,
        device="cpu",
    ),
                                   requires_grad=False)
    layer.register_parameter("w13_weight_scale", w13_scale)
    set_weight_attrs(w13_scale, extra_weight_attrs)

    w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                             num_groups_w2,
                                             hidden_size,
                                             dtype=params_dtype,
                                             device="cpu"),
                                  requires_grad=False)
    layer.register_parameter("w2_weight_scale", w2_scale)
    set_weight_attrs(w2_scale, extra_weight_attrs)
    set_weight_attrs(w2_scale, {"load_full_w2": False})

    # Per-expert original (unpadded) weight shapes recorded by the loader.
    w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts,
                                                     2,
                                                     device="cpu"),
                                         requires_grad=False)
    layer.register_parameter("w2_weight_shape", w2_weight_shape)
    set_weight_attrs(w2_weight_shape, extra_weight_attrs)
    w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts,
                                                      2,
                                                      device="cpu"),
                                          requires_grad=False)

    layer.register_parameter("w13_weight_shape", w13_weight_shape)
    set_weight_attrs(w13_weight_shape, extra_weight_attrs)

    # Act-order group indices and their sort permutations. These are
    # registered so checkpoint loading succeeds, then dropped in
    # process_weights_after_loading to reclaim memory.
    w13_g_idx = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight_g_idx", w13_g_idx)
    set_weight_attrs(w13_g_idx, extra_weight_attrs)

    w2_g_idx = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight_g_idx", w2_g_idx)
    set_weight_attrs(w2_g_idx, extra_weight_attrs)

    w13_g_idx_sort_indices = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
    set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)

    w2_g_idx_sort_indices = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
    set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

    # No activation scales for weight-only (WNA16) quantization.
    layer.a13_scale = None
    layer.a2_scale = None
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def process_weights_after_loading(self: CompressedTensorsWNA16MoEMethod,
                                  layer: FusedMoE) -> None:
    """Convert CPU-loaded MoE weights/scales to SUPA device layouts.

    For 8-bit weights the packed int32 checkpoint data is unpacked to int8
    and crossed/padded into NUMA colmajor buffers; for 4-bit weights the
    packed int32 data is rearranged directly. Unused loader buffers
    (shapes, g_idx, sort indices) are released at the end.

    Raises:
        ValueError: if ``self.num_bits`` is neither 4 nor 8.
    """
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # die_spc_num > 16 presumably means a dual-die part — TODO confirm.
    die_num = 1 if die_spc_num <= 16 else 2
    spc_num = die_spc_num // die_num
    cur_device = torch.supa.current_device()
    is_dual_die = (die_spc_num > 16)

    if self.num_bits == 8:
        # NOTE: w13_weight
        # after _load_w13, w13_weight is a colparallel weight, shape
        # [num_experts, hidden_size // 4, 2 * intermediate_size_per_partition] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT8
        wk = layer.hidden_size
        wn = layer.intermediate_size_per_partition * 2
        align_size = 64
        wn_block = align_n(wn // die_num,
                           align_size=align_size,
                           spc_num=spc_num)

        supa_w13_weight_packed = torch_br._empty_ut_only(
            size=(die_spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int8,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor",
            axis=0,
            sbp="SS" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w13 = layer.w13_weight_packed[
                expert_id]  # hidden_size // 4, 2 * intermediate_size_per_partition
            # Split the fused buffer back into gate (w1) and up (w3) halves.
            expert_1, expert_3 = expert_w13.chunk(
                2, dim=1)  # each is a packed int4 weight

            unpacked_expert_1 = unpack_from_int32(
                expert_1, self.num_bits,
                torch.Size(
                    [layer.hidden_size,
                     layer.intermediate_size_per_partition]), 0)

            unpacked_expert_3 = unpack_from_int32(
                expert_3, self.num_bits,
                torch.Size(
                    [layer.hidden_size,
                     layer.intermediate_size_per_partition]), 0)

            pad_expert_w13 = _convert_to_crossed_numa_tensor(
                unpacked_expert_1,
                unpacked_expert_3,
                die_spc_num,
                dim=1,
                need_pad=True,
                layout='COLMAJOR',
                do_transpose=False)
            # Copy this expert's converted data into its slot of the
            # device buffer via a zero-copy usharp view.
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)

        layer.w13_weight_packed.data = supa_w13_weight_packed

        # NOTE: w13_scale
        # after _load_w13, w13_weight is a colparallel weight, shape
        # S8: [num_experts, 1, 2 * intermediate_size_per_partition]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, wn]
        supa_w13_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="linear_bias",
            sbp="BB" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w13_scales = layer.w13_weight_scale[expert_id]
            expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
                2, dim=1)  # gate/up scale halves
            # Interleave scales to match the crossed weight layout above.
            crossed_expert_w13_scales = cross_weight_32(
                expert_1_scale.squeeze(),
                expert_3_scale.squeeze(),
                die_spc_num,
                dim=0,
                need_pad=False,
            )
            narrow_data = supa_w13_scales[expert_id]
            narrow_data.copy_(crossed_expert_w13_scales)

        layer.w13_weight_scale.data = supa_w13_scales

        # NOTE: w2_weight
        # after _load_w2, w2_weight is a colparallel weight, shape
        # [num_experts, intermediate_size_per_partition // 4, hidden_size] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT8
        wk = layer.intermediate_size_per_partition
        wn = layer.hidden_size
        align_size = 32
        wn_block = align_n(wn, align_size=align_size, spc_num=spc_num)

        supa_w2_weight_packed = torch_br._empty_ut_only(
            size=(die_spc_num * layer.local_num_experts, wk // die_num,
                  wn_block),
            dtype=torch.int8,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor",
            axis=0,
            sbp="SS" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_packed[expert_id]

            unpacked_expert_2 = unpack_from_int32(
                expert_w2, self.num_bits,
                torch.Size(
                    [layer.intermediate_size_per_partition,
                     layer.hidden_size]), 0)

            pad_expert_w2 = _convert_to_numa_tensor(
                unpacked_expert_2,
                align_size,
                'COLMAJOR',
                expert_w2.dtype,
                do_transpose=False,
                parallel_type="row_parallel")
            pad_expert_w2_shape = pad_expert_w2.shape
            hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
            narrow_data = supa_w2_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w2)

        layer.w2_weight_packed.data = supa_w2_weight_packed

        # NOTE: w2_scale
        # after _load_w2, w2_weight is a colparallel weight, shape
        # S8: [num_experts, 1, hidden_size]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, wn]
        supa_w2_scales = torch_br._empty_ut_only(size=(layer.local_num_experts,
                                                       wn),
                                                 dtype=torch.float32,
                                                 is_numa=False,
                                                 device=cur_device,
                                                 tensor_type="linear_bias")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_scale[expert_id]
            narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
            narrow_data.copy_(expert_w2)

        layer.w2_weight_scale.data = supa_w2_scales

    elif self.num_bits == 4:
        # NOTE: w13_weight
        # after _load_w13, w13_weight is a colparallel weight, shape
        # [num_experts, hidden_size // 8, 2 * intermediate_size_per_partition] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT32
        wk = layer.hidden_size // 8
        wn = layer.intermediate_size_per_partition * 2
        wn_block = align_n(wn, align_size=64, spc_num=spc_num)

        supa_w13_weight_packed = torch_br._empty_ut_only(
            size=(spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int32,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w13 = layer.w13_weight_packed[
                expert_id]  # hidden_size // 4, 2 * intermediate_size_per_partition
            # 4-bit path keeps data packed; split along dim 0 here
            # (the 8-bit path splits the unpacked layout along dim 1).
            expert_1, expert_3 = expert_w13.chunk(
                2, dim=0)  # each is a packed int4 weight

            pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
                                                             expert_3,
                                                             spc_num,
                                                             dim=1,
                                                             need_pad=True,
                                                             layout='COLMAJOR',
                                                             do_transpose=True)
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)

        layer.w13_weight_packed.data = supa_w13_weight_packed

        # NOTE: w13_scale
        # after _load_w13, w13_weight is a colparallel weight, shape
        # S4: [num_experts, hidden_size // 128, 2 * intermediate_size_per_partition]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, group_nums, wn]
        supa_w13_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.hidden_size // self.group_size, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w13_scales = layer.w13_weight_scale[expert_id]
            expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
                2, dim=0)  # gate/up scale halves
            crossed_expert_w13_scales = cross_weight_32(
                expert_1_scale,
                expert_3_scale,
                spc_num,
                dim=1,
                need_pad=False,
            )
            narrow_data = supa_w13_scales[expert_id]
            narrow_data.copy_(crossed_expert_w13_scales)

        layer.w13_weight_scale.data = supa_w13_scales

        # NOTE: w2_weight
        # after _load_w2, w2_weight is a colparallel weight, shape
        # [num_experts, intermediate_size_per_partition // 8, hidden_size] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT32
        wk = layer.intermediate_size_per_partition // 8
        wn = layer.hidden_size
        wn_block = align_n(wn, align_size=32, spc_num=spc_num)

        supa_w2_weight_packed = torch_br._empty_ut_only(
            size=(spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int32,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_packed[expert_id]
            pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
                                                    spc_num,
                                                    'COLMAJOR',
                                                    expert_w2.dtype,
                                                    do_transpose=True)
            pad_expert_w2_shape = pad_expert_w2.shape
            hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
            narrow_data = supa_w2_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w2)

        layer.w2_weight_packed.data = supa_w2_weight_packed

        # NOTE: w2_scale
        # after _load_w2, w2_weight is a colparallel weight, shape
        # S4: [num_experts, intermediate_size_per_partition // 128, hidden_size]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, group_nums, wn]
        supa_w2_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.intermediate_size_per_partition // self.group_size,
                  wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_scale[expert_id]
            narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
            narrow_data.copy_(expert_w2)

        layer.w2_weight_scale.data = supa_w2_scales

    else:
        raise ValueError(
            f"Unsupported num_bits: {self.num_bits}. Only 4 and 8 are supported."
        )

    # remove other CompressedTensorsWNA16MoEMethod registied buffer to
    # reduce memory usage
    layer.w13_weight_shape = None
    layer.w13_weight_g_idx = None
    layer.w13_g_idx_sort_indices = None

    layer.w2_weight_shape = None
    layer.w2_weight_g_idx = None
    layer.w2_g_idx_sort_indices = None
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
) -> torch.Tensor:
    """Run the fused quantized MoE forward pass.

    Dispatches on input length: sequences longer than
    ``VLLM_BR_STATIC_MOE_DECODER_MAX_LEN`` take the dynamic kernel,
    otherwise the static device kernel (which additionally receives
    TP rank/size info).

    NOTE(review): despite the ``torch.Tensor`` annotation (kept for
    upstream signature compatibility), ``router_logits`` is unpacked
    below as a 3-tuple of (gating weight, shared gate_up weight,
    shared down weight). ``global_num_experts``, ``expert_map``,
    ``apply_router_weight_on_input`` and ``activation`` are accepted
    for interface compatibility but not forwarded to the kernels.
    """
    b_seq = x.shape[0]
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
    if b_seq > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
        # Long (prefill-like) inputs: dynamic-shape kernel.
        return fused_moe_quant_dyn(
            x,
            shared_gate_up_weight,
            shared_down_weight,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            gating_weight,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size,
        )
    else:
        # Short (decode-like) inputs: static device kernel, which also
        # needs tensor-parallel rank/size information.
        return fused_moe_quant_device(
            x,
            shared_gate_up_weight,
            shared_down_weight,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            gating_weight,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            tp_rank=get_tp_group().rank_in_group,
            global_rank=get_tp_group().rank,
            tp_size=get_tensor_model_parallel_world_size(),
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size,
        )
|
||||
@@ -0,0 +1,267 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from compressed_tensors.compressors.quantized_compressors import (
|
||||
unpack_from_int32)
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.distributed import (get_pipeline_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm_br import envs
|
||||
from ...br_utils import _convert_to_numa_tensor
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def create_weights(self, layer: torch.nn.Module, input_size: int,
                   input_size_per_partition: int, output_size: int,
                   output_partition_sizes: list[int],
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    """Allocate and register WNA16 quantized-weight parameters on CPU.

    Registers ``weight_packed`` (int32-packed quantized weights),
    ``weight_scale``, ``weight_shape``, and — when asymmetric —
    ``weight_zero_point``; ``weight_g_idx`` is added only when activation
    reordering (``has_g_idx``) is enabled.  All buffers are created on
    ``device="cpu"``; conversion to the device NUMA layout happens later in
    ``process_weights_after_loading``.
    """
    self.output_size_per_partition = sum(output_partition_sizes)
    # group_size == -1 means per-channel quantization: one group spanning
    # the full input dimension.
    group_size = self.group_size if self.group_size != -1 else input_size
    row_parallel = (input_size != input_size_per_partition)
    # When scales are not repeated on all ranks, each rank only holds the
    # scales for its own input partition.
    partition_scales = not marlin_repeat_scales_on_all_ranks(
        self.has_g_idx, self.group_size, row_parallel)
    scales_and_zp_size = input_size // group_size
    if partition_scales:
        assert input_size_per_partition % group_size == 0
        scales_and_zp_size = input_size_per_partition // group_size
    # Packed along the input dim: pack_factor quantized values per int32.
    weight = PackedvLLMParameter(
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
        packed_factor=self.pack_factor,
        packed_dim=1,
        data=torch.empty(self.output_size_per_partition,
                         input_size_per_partition // self.pack_factor,
                         dtype=torch.int32,
                         device="cpu"))

    weight_scale_args = {
        "weight_loader":
        weight_loader,
        "data":
        torch.empty(
            self.output_size_per_partition,
            scales_and_zp_size,
            device="cpu",
            dtype=params_dtype,
        )
    }
    zeros_args = {
        "weight_loader":
        weight_loader,
        "data":
        torch.zeros(
            self.output_size_per_partition // self.pack_factor,
            scales_and_zp_size,
            device="cpu",
            dtype=torch.int32,
        )
    }
    if not partition_scales:
        # Per-channel scales (one group): channel-wise parameter class.
        weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                  **weight_scale_args)
        if not self.symmetric:
            qzeros = PackedColumnParameter(output_dim=0,
                                           packed_dim=0,
                                           packed_factor=self.pack_factor,
                                           **zeros_args)
    else:
        # Grouped scales sharded along the input dimension.
        weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                input_dim=1,
                                                **weight_scale_args)
        if not self.symmetric:
            qzeros = PackedvLLMParameter(input_dim=1,
                                         output_dim=0,
                                         packed_dim=0,
                                         packed_factor=self.pack_factor,
                                         **zeros_args)
    # A 2D array defining the original shape of the weights
    # before packing
    weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                      dtype=torch.int64,
                                                      device="cpu"),
                                     weight_loader=weight_loader)
    layer.register_parameter("weight_packed", weight)
    layer.register_parameter("weight_scale", weight_scale)
    layer.register_parameter("weight_shape", weight_shape)

    if not self.symmetric:
        layer.register_parameter("weight_zero_point", qzeros)
    # group index (for activation reordering)
    if self.has_g_idx:
        weight_g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
                                        input_dim=0,
                                        weight_loader=weight_loader)
        layer.register_parameter("weight_g_idx", weight_g_idx)
    # Saved for the unpack step in process_weights_after_loading.
    self.input_size_per_partition = input_size_per_partition
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def process_weights_after_loading(self: CompressedTensorsWNA16,
                                  layer: torch.nn.Module) -> None:
    """Unpack loaded WNA16 weights and convert them to the Biren NUMA layout.

    Steps:
      1. Unpack the int32-packed weights back to per-element values.
      2. Cast scales to float32 (the kernel's expected scale dtype).
      3. Convert weight / scale / bias tensors via ``_convert_to_numa_tensor``
         (layout semantics are defined by that helper in ``br_utils``).
    """
    # spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # cur_device = torch.supa.current_device()
    # pack_factor packs (32 // num_bits) values per int32 word.
    self.num_bits = 32 // self.pack_factor
    layer.weight_packed.data = unpack_from_int32(
        layer.weight_packed.data, self.num_bits,
        torch.Size(
            [self.output_size_per_partition, self.input_size_per_partition]),
        1)
    # Re-wrap as a plain Parameter (drops the custom vLLM parameter class,
    # e.g. for torch.compile friendliness).
    layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                            requires_grad=False)

    br_scales = layer.weight_scale.data.to(torch.float32)
    layer.weight_scale.data = br_scales

    do_transpose = True
    # Row-parallel layers need different NUMA sharding/padding than
    # column-parallel ones (see bias pad_zeros below).
    parallel_type = "col_parallel"
    match layer:
        case RowParallelLinear():
            parallel_type = "row_parallel"
        case _:
            pass

    # Only 2-D weights are converted; a 3-D weight is already in NUMA form.
    if hasattr(layer, 'weight_packed') and len(layer.weight_packed.shape) == 2:
        weight_packed = layer.weight_packed.data
        layer.weight_packed.data = _convert_to_numa_tensor(
            weight_packed,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            do_transpose=do_transpose,
            # wk/wn swap roles depending on whether the tensor is
            # transposed during conversion.
            wk=(weight_packed.shape[1]
                if do_transpose else weight_packed.shape[0]),
            wn=(weight_packed.shape[0]
                if do_transpose else weight_packed.shape[1]),
            parallel_type=parallel_type)  # noqa: SIM210

    if hasattr(layer, 'weight_scale') and layer.weight_scale is not None:
        pad_zeros = False
        layer.weight_scale.data = _convert_to_numa_tensor(
            layer.weight_scale.data.T,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # Row-parallel bias is zero-padded so that only one rank contributes
        # the bias after the reduction — TODO confirm against br_utils.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias.data,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def apply_weights(self: CompressedTensorsWNA16,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the WNA16 quantized linear layer on Biren hardware.

    Dispatch:
      * 3-D ``weight_packed`` (NUMA layout) -> fused Biren kernels, with
        SwiGLU fused in for merged gate/up projections and an optional fused
        linear+allreduce path for row-parallel layers.
      * 2-D ``weight_packed`` -> plain quantized matmul fallback.

    :param layer: linear module holding ``weight_packed`` / ``weight_scale``
    :param x: activation tensor
    :param bias: optional bias tensor
    :return: layer output tensor
    """
    # numa weight is 3-dims
    if len(layer.weight_packed.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # gate/up are merged into one weight: fuse SwiGLU and halve
            # the logical output width.
            act_mode = "act_swiglu"
            output_size //= 2

        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            # Fused linear+allreduce only for short decode sequences on
            # supported (SPC, TP) combinations; otherwise let the caller
            # reduce the partial results.
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)

            if layer.reduce_results:
                return torch_br.br_fused_mlp_infer(
                    x, [layer.weight_packed.data],
                    output_w=output_size,
                    scales=[layer.weight_scale.data]
                    if layer.weight_scale.data is not None else None,
                    bias=[bias] if bias is not None else None,
                    # BUGFIX: keyword was misspelled "activaion_mode";
                    # every other br_fused_mlp_infer call site in this file
                    # passes "activation_mode".
                    activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.weight_packed.data,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.weight_scale.data,
                    bias=bias,
                    act_mode=act_mode)
        else:
            return torch_br.br_fused_mlp_infer(
                x, [layer.weight_packed.data],
                output_w=output_size,
                scales=[layer.weight_scale.data]
                if layer.weight_scale.data is not None else None,
                bias=[bias] if bias is not None else None,
                activation_mode=act_mode)

    # 2-D (non-NUMA) fallback path.
    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.weight_packed.shape[1]
    # TODO, hard code to skip dry_run stage
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.weight_packed,
                                        layer.weight_scale,
                                        bias=bias)
|
||||
@@ -0,0 +1,34 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """Translate a compressed-tensors k/v cache scale param name to vLLM's.

    compressed-tensors stores KV cache scales as
    ``...{k,v}_proj.output_scale``; vLLM expects ``...attn.{k,v}_scale``.

    :param name: checkpoint parameter name
    :return: the vLLM-style name, or ``None`` when *name* is not a
             k/v cache scale parameter
    """
    if name.endswith(".output_scale"):
        for proj_tag, vllm_tag in ((".k_proj", ".attn.k_scale"),
                                   (".v_proj", ".attn.v_scale")):
            if proj_tag in name:
                return name.replace(proj_tag + ".output_scale", vllm_tag)
    # Not a KV cache scale parameter.
    return None
|
||||
244
vllm_br/model_executor/layers/quantization/gptq.py
Normal file
244
vllm_br/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,244 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.distributed import (get_pipeline_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
|
||||
GPTQLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm_br import envs
|
||||
from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
|
||||
from ..linear import (process_weights_MergedColumnParallelLinear,
|
||||
process_weights_QuantMergedColumnParallelLinear,
|
||||
process_weights_ReplicatedLinear)
|
||||
|
||||
|
||||
@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Activation dtypes supported by the Biren GPTQ kernels (bfloat16 only)."""
    supported = [torch.bfloat16]
    return supported
|
||||
|
||||
|
||||
@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Resolve the GPTQ linear quant method for *layer* via the shared helper."""
    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
|
||||
@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """Build a GPTQConfig from a raw quantization config dict.

    [PatchNote] add qkv_quantized param support
    """
    dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
    if dynamic is None:
        dynamic = {}

    # Keys are read in the same order as upstream; required keys raise
    # through cls.get_from_keys, optional ones take a default.
    ctor_kwargs = dict(
        weight_bits=cls.get_from_keys(config, ["bits"]),
        group_size=cls.get_from_keys(config, ["group_size"]),
        desc_act=cls.get_from_keys(config, ["desc_act"]),
        lm_head_quantized=cls.get_from_keys_or(config, ["lm_head"],
                                               default=False),
        autoround_version=cls.get_from_keys_or(config, ["autoround_version"],
                                               default=""),
        modules_in_block_to_quantize=cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None),
        qkv_quantized=cls.get_from_keys_or(config, ["qkv_quantized"],
                                           default=True),
        dynamic=dynamic,
    )
    return cls(**ctor_kwargs)
|
||||
|
||||
|
||||
def wrapper_GPTQConfig_init(fn):
    """Wrap ``GPTQConfig.__init__`` so it accepts an extra ``qkv_quantized``
    keyword (default ``True``), which is popped before delegating to the
    wrapped initializer and stored on the instance afterwards."""

    @wraps(fn)
    def _init(self, *args, **kwargs):
        # Remove the flag first: the upstream __init__ does not know it.
        flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = flag

    return _init
|
||||
|
||||
|
||||
# Install the wrapper module-wide so every GPTQConfig construction accepts
# the extra ``qkv_quantized`` keyword (popped before delegating upstream).
GPTQConfig.__init__ = wrapper_GPTQConfig_init(
    GPTQConfig.__init__)  # noqa: E501
|
||||
|
||||
|
||||
@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """Post-load conversion of GPTQ weights for Biren hardware.

    Per-layer-type preprocessing first (match block), then the int32 qweight
    is converted with ``_br_qweight_cvt``, scales are cast to float32, all
    tensors are re-wrapped as plain Parameters, and finally — when this layer
    still needs it and the method uses NUMA weights — weight/scale/bias are
    converted to the NUMA layout.
    """
    still_need_process = True
    merge_col_quant = False
    # NOTE: all process_weights func should done before process_weights_after_loading
    parallel_type = "col_parallel"
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # Only small replicated layers (64/256 outputs) continue to the
            # NUMA conversion below — TODO confirm why these sizes.
            still_need_process = layer.output_size == 64 or layer.output_size == 256
        case MergedColumnParallelLinear():
            if hasattr(layer, "qweight"):
                # Quantized merged layer: its merge step must run only after
                # the qweight has been converted (see merge_col_quant below).
                merge_col_quant = True
            else:
                process_weights_MergedColumnParallelLinear(layer)
                still_need_process = False
        case RowParallelLinear():
            parallel_type = "row_parallel"

        case _:
            pass

    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    if layer.qweight.dtype == torch.int32:
        input_size = layer.input_size_per_partition if hasattr(
            layer, 'input_size_per_partition') else layer.input_size
        output_size = layer.output_size_per_partition if hasattr(
            layer, 'output_size_per_partition') else layer.output_size
        br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
                                     input_size, output_size)
        layer.qweight.data = br_qweight
        if merge_col_quant:
            process_weights_QuantMergedColumnParallelLinear(layer)
            still_need_process = False

    # Kernels expect float32 scales.
    br_scales = layer.scales.to(torch.float32)
    layer.scales.data = br_scales

    # for torch.compile
    layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
    layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
    layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
    layer.scales = Parameter(layer.scales.data, requires_grad=False)

    if not still_need_process or self.weight_type != "NUMA":
        return

    # Only 2-D weights are converted; 3-D means already in NUMA layout.
    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            parallel_type=parallel_type)

    if hasattr(layer, 'scales') and layer.scales is not None:
        pad_zeros = False
        layer.scales.data = _convert_to_numa_tensor(
            layer.scales,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # Row-parallel bias is zero-padded — presumably so only one rank
        # contributes it after the reduction; TODO confirm in br_utils.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
|
||||
|
||||
|
||||
@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the GPTQ quantized linear layer on Biren hardware.

    Mirrors the WNA16 ``apply_weights`` dispatch: 3-D (NUMA) weights use the
    fused kernels (with fused SwiGLU for merged gate/up projections and an
    optional fused linear+allreduce for row-parallel layers); 2-D weights
    fall back to the plain quantized matmul.
    """
    # numa weight is 3-dims
    if len(layer.qweight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # gate/up merged: fuse SwiGLU and halve the output width.
            act_mode = "act_swiglu"
            output_size //= 2
        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            # Fused linear+allreduce only for short decode sequences on
            # supported (SPC, TP) combinations.
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if layer.reduce_results:
                return torch_br.br_fused_mlp_infer(
                    x, [layer.qweight],
                    output_w=output_size,
                    scales=[layer.scales]
                    if layer.scales is not None else None,
                    bias=[bias] if bias is not None else None,
                    activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.qweight,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.scales,
                    bias=bias,
                    act_mode=act_mode)
        else:
            return torch_br.br_fused_mlp_infer(
                x, [layer.qweight],
                output_w=output_size,
                scales=[layer.scales] if layer.scales is not None else None,
                bias=[bias] if bias is not None else None,
                activation_mode=act_mode)
    # 2-D (non-NUMA) fallback path.
    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.qweight.shape[1]
    # TODO, hard code to skip dry_run stage
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.qweight,
                                        layer.scales,
                                        bias=bias)
|
||||
924
vllm_br/model_executor/layers/rotary_embedding.py
Normal file
924
vllm_br/model_executor/layers/rotary_embedding.py
Normal file
@@ -0,0 +1,924 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import itertools
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.model_executor.layers.rotary_embedding
|
||||
import vllm.model_executor.models.chatglm
|
||||
import vllm.model_executor.models.deepseek_v2
|
||||
import vllm_br.envs as br_envs
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
_ROPE_DICT, DeepseekScalingRotaryEmbedding, DualChunkRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
|
||||
Llama3RotaryEmbedding, Llama4VisionRotaryEmbedding, MRotaryEmbedding,
|
||||
NTKScalingRotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding,
|
||||
RotaryEmbedding, YaRNScalingRotaryEmbedding)
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
rotate_gptj, rotate_neox, yarn_find_correction_range,
|
||||
yarn_linear_ramp_mask)
|
||||
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
|
||||
yarn_get_mscale)
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import (
|
||||
apply_interleaved_rope)
|
||||
|
||||
|
||||
@patch_to(RotaryEmbedding)
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    op_type: str = "Half",  # FIXME: other op type not supported yet
) -> None:
    """Patched constructor that prepares caches for the SUPA RoPE kernel.

    MRotaryEmbedding and DeepseekScalingRotaryEmbedding keep the upstream
    single ``cos_sin_cache`` buffer; every other subclass gets separate
    ``sin_cache``/``cos_cache`` float32 buffers consumed by
    ``torch_br.supa_rope_infer_v2`` in ``forward_oot``.
    """
    logger.info('[Patch] RotaryEmbedding use SUPA RoPE')
    super(RotaryEmbedding, self).__init__()  # type: ignore
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype
    self.op_type = op_type  # FIXME: other op type not supported yet

    if isinstance(self, MRotaryEmbedding):
        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        # NOTE(review): this branch queries torch.cuda while the Deepseek
        # branch below queries torch.supa — confirm the mix is intentional.
        device = torch.cuda.current_device()
        cache = cache.to(device)
        self.cos_sin_cache: torch.Tensor  # type: ignore
        self.register_buffer("cos_sin_cache", cache, persistent=False)
    elif isinstance(self, DeepseekScalingRotaryEmbedding):
        # NOTE(review): these attributes were already assigned above;
        # the re-assignment is redundant but harmless.
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        device = torch.supa.current_device()
        cache = cache.to(device)
        self.cos_sin_cache: torch.Tensor  # type: ignore
        self.register_buffer("cos_sin_cache", cache, persistent=False)
    else:
        # Default path: separate sin/cos caches in float32 for the fused
        # SUPA kernel.
        sin_cache, cos_cache = self._compute_cos_sin_cache()
        sin_cache = sin_cache.to(torch.float32)
        cos_cache = cos_cache.to(torch.float32)
        device = torch.cuda.current_device()
        sin_cache = sin_cache.to(device)
        cos_cache = cos_cache.to(device)
        self.register_buffer("sin_cache", sin_cache, persistent=False)
        self.register_buffer("cos_cache", cos_cache, persistent=False)
|
||||
|
||||
|
||||
@patch_to(RotaryEmbedding)
def _compute_cos_sin_cache(
        self) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Compute the cos and sin cache.

    Returns a single concatenated ``[cos | sin]`` cache for MRotaryEmbedding
    (upstream layout), otherwise a ``(sin, cos)`` pair shaped for the SUPA
    kernel.  All computation happens on CPU.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        if isinstance(self, MRotaryEmbedding):
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            return cache
        else:
            if self.op_type == "Half" or self.op_type == "TeleChat":
                # NeoX-style halves: duplicate the frequencies side by side.
                freqs = freqs.repeat(1, 2)
                cos = freqs.cos()
                sin = freqs.sin()
            else:
                # GPT-J-style interleaving: each frequency repeated twice,
                # with the sin argument sign-alternated per element.
                cos_freqs = freqs.repeat_interleave(2, dim=-1)
                cos = cos_freqs.cos()
                scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
                sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
                sin = sin_freqs.sin()
        return sin, cos
|
||||
|
||||
|
||||
@patch_to(RotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings via the fused Biren SUPA RoPE kernel.

    Uses the precomputed ``sin_cache``/``cos_cache`` buffers; ``offsets``
    is accepted for interface compatibility but not forwarded.
    """
    rotated = torch_br.supa_rope_infer_v2(query,
                                          key,
                                          self.sin_cache,
                                          self.cos_cache,
                                          positions,
                                          self.head_size,
                                          rope_type=self.op_type,
                                          rotary_size=self.rotary_dim)
    q_out, k_out = rotated
    return q_out, k_out
|
||||
|
||||
|
||||
@patch_to(RotaryEmbedding)
def enabled(cls) -> bool:
    # Signal that this patched RotaryEmbedding implementation is active.
    # NOTE(review): the parameter is named ``cls`` but patch_to is called
    # without cls_method=True (unlike other classmethod patches in this
    # package) — confirm the binding is as intended.
    return True
|
||||
|
||||
|
||||
class SupaDeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """YaRN-scaled rotary embedding for Deepseek on Biren SUPA hardware.

    Produces interleaved (GPT-J-style) ``(sin, cos)`` caches scaled by the
    YaRN magnitude factor ``mscale``, matching the layout expected by the
    patched ``RotaryEmbedding.forward_oot`` SUPA kernel.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        # YaRN attributes must be set before super().__init__ because the
        # base constructor calls _compute_cos_sin_cache, which reads them.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """YaRN inverse frequencies: blend interpolation and extrapolation."""
        with torch.device('cpu'):
            pos_freqs = self.base**(torch.arange(
                0, self.rotary_dim, 2, dtype=torch.float, device="cpu") /
                                    self.rotary_dim)
            inv_freq_extrapolation = 1.0 / pos_freqs
            inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

            low, high = yarn_find_correction_range(
                self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
                self.max_position_embeddings)
            # Get n-d rotational scaling corrected for extrapolation
            inv_freq_mask = (1 - yarn_linear_ramp_mask(
                low, high, self.rotary_dim // 2,
                dtype=torch.float)) * self.extrapolation_factor
            inv_freq = inv_freq_interpolation * (
                1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
            return inv_freq

    def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build interleaved (sin, cos) caches on CPU, scaled by mscale."""
        with torch.device('cpu'):
            inv_freq = self._compute_inv_freq(self.scaling_factor)
            # Positions cover the scaled context length.
            t = torch.arange(self.max_position_embeddings *
                             self.scaling_factor,
                             dtype=torch.float)

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            # GPT-J-style interleave: repeat each frequency, alternate the
            # sign of the sin argument per element.
            cos_freqs = freqs.repeat_interleave(2, dim=-1)
            cos = (cos_freqs.cos() * self.mscale)
            scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
            sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
            sin = (sin_freqs.sin() * self.mscale)
            return sin, cos
|
||||
|
||||
|
||||
@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Patched to run entirely on CPU.

    Computes YaRN inverse frequencies: a per-dimension blend of the
    interpolated (scaled) and extrapolated (unscaled) frequencies, with
    the blend mask derived from the beta_fast/beta_slow correction range.
    """
    with torch.device('cpu'):
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cpu") /
                                self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq
|
||||
|
||||
|
||||
@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Precompute the concatenated ``[cos | sin]`` cache on CPU.

    Both halves are scaled by the YaRN magnitude factor ``mscale``; the
    position range covers the scaled context length.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        positions = torch.arange(self.max_position_embeddings *
                                 self.scaling_factor,
                                 dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat((freqs.cos() * self.mscale,
                          freqs.sin() * self.mscale),
                         dim=-1)
|
||||
|
||||
|
||||
@patch_to(DeepseekScalingRotaryEmbedding)
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward().

    Patched for Biren: repeat_interleave (GPT-J style) and large-batch
    rotations are offloaded to CPU and copied back, presumably to work
    around device kernel limitations — TODO confirm the motivation.
    """
    assert key is not None
    self._match_cos_sin_cache_dtype(query)
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    if self.rotary_dim < self.head_size:
        # Tail beyond rotary_dim passes through unrotated.
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]

    cos_sin = self.cos_sin_cache[
        torch.add(positions, offsets) if offsets is not None else positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if self.is_neox_style:
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    else:
        # Interleave on CPU, then move back to the current SUPA device.
        device = torch.supa.current_device()
        cos = cos.to('cpu')
        sin = sin.to('cpu')
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        cos = cos.to(device)
        sin = sin.to(device)

    rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
    device = query_rot.device
    # Large batches are rotated on CPU (threshold 1024 — TODO confirm
    # why; presumably a device kernel/memory limitation).
    if query.shape[0] > 1024:
        query_rot = query_rot.to('cpu')
        key_rot = key_rot.to('cpu')
        cos = cos.to('cpu')
        sin = sin.to('cpu')
    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    key_rot = key_rot * cos + rotate_fn(key_rot) * sin
    if query.shape[0] > 1024:
        query_rot = query_rot.to(device)
        key_rot = key_rot.to(device)

    if self.rotary_dim < self.head_size:
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
    else:
        query = query_rot
        key = key_rot
    return query, key
|
||||
|
||||
|
||||
@patch_to(DeepseekScalingRotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Out-of-tree path: delegate to the patched native implementation."""
    return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
|
||||
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Compute YaRN inverse frequencies on CPU.

    Blends interpolated (scaled) and extrapolated (unscaled) inverse
    frequencies per dimension, weighted by the linear ramp mask over the
    beta_fast/beta_slow correction range.
    """
    with torch.device('cpu'):
        exponents = torch.arange(0, self.rotary_dim, 2,
                                 dtype=torch.float) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolation = 1.0 / pos_freqs
        interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Ramp mask weights extrapolation inside the correction range.
        ramp = yarn_linear_ramp_mask(low, high, self.rotary_dim // 2,
                                     dtype=torch.float)
        blend = (1 - ramp) * self.extrapolation_factor
        return interpolation * (1 - blend) + extrapolation * blend
|
||||
|
||||
|
||||
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the YaRN sin/cos tables on CPU.

    Returns:
        (sin, cos) tensors, each of shape
        [max_position_embeddings * scaling_factor, rotary_dim] and scaled
        by ``self.mscale``.

    Note: the original annotation claimed ``-> torch.Tensor`` but the
    function returns a 2-tuple; the annotation is corrected here.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        # Duplicate the half-dim angles so cos/sin span the full rotary_dim.
        freqs = freqs.repeat(1, 2)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
    return sin, cos
|
||||
|
||||
|
||||
def dtnamicNTK_compute_cos_sin_cache(
        self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the (sin, cos) cache for dynamic-NTK scaling, built on CPU."""
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        positions = torch.arange(self.max_position_embeddings,
                                 dtype=torch.float)

        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        if self.op_type in ("Half", "TeleChat"):
            # Neox-style layout: the half-dim angles are laid out as two
            # contiguous copies.
            doubled = angles.repeat(1, 2)
            cos = doubled.cos()
            sin = doubled.sin()
        else:
            # GPT-J-style layout: each angle appears twice (interleaved), and
            # every even slot's angle is sign-flipped before taking sin.
            interleaved = angles.repeat_interleave(2, dim=-1)
            cos = interleaved.cos()
            signs = torch.arange(interleaved.numel()) % 2 * 2 - 1
            sin = (interleaved * signs.reshape_as(interleaved)).sin()
    return sin, cos
|
||||
|
||||
|
||||
def dynamicNTKScaling_rope_forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply dynamic-NTK RoPE through the supa fused kernel.

    A query/key last-dim mismatch is routed through the "MRope" kernel
    variant; otherwise the layer's configured ``op_type`` is used. All other
    kernel arguments are identical on both paths.
    """
    kernel_variant = ("MRope"
                      if query.shape[-1] != key.shape[-1] else self.op_type)
    return torch_br.supa_rope_infer_v2(query,
                                       key,
                                       self.sin_cache,
                                       self.cos_cache,
                                       positions,
                                       self.head_size,
                                       rope_type=kernel_variant)
|
||||
|
||||
|
||||
# Install the fork's CPU cache builder and supa-kernel forward onto vLLM's
# DynamicNTKScalingRotaryEmbedding.
DynamicNTKScalingRotaryEmbedding._compute_cos_sin_cache = dtnamicNTK_compute_cos_sin_cache
DynamicNTKScalingRotaryEmbedding.forward = dynamicNTKScaling_rope_forward
|
||||
|
||||
|
||||
def _apply_rotary_emb_torch(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """Dispatch wrapper around the pure-PyTorch rotary kernel.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
|
||||
|
||||
|
||||
def forward_MRotaryEmbedding_0_9_2(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native MRoPE forward (vLLM 0.9.2 semantics).

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    assert positions.ndim in (1, 2)
    assert key is not None

    num_tokens = positions.shape[-1]
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
    if positions.ndim == 2:
        # Multimodal path: stitch the per-section (T/H/W) slices back into a
        # single table along the last dim.
        assert self.mrope_section
        cos = torch.cat([
            piece[idx] for idx, piece in enumerate(
                cos.split(self.mrope_section, dim=-1))
        ],
                        dim=-1)
        sin = torch.cat([
            piece[idx] for idx, piece in enumerate(
                sin.split(self.mrope_section, dim=-1))
        ],
                        dim=-1)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Rotate the first rotary_dim features of each head; the remainder
        # passes through untouched.
        original_shape = t.shape
        t = t.view(num_tokens, -1, self.head_size)
        rotated = _apply_rotary_emb(t[..., :self.rotary_dim], cos, sin,
                                    self.is_neox_style)
        return torch.cat((rotated, t[..., self.rotary_dim:]),
                         dim=-1).reshape(original_shape)

    return _rotate(query), _rotate(key)
|
||||
|
||||
|
||||
def forward_supa(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """MRoPE forward on the supa accelerator via torch_br's fused kernel.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    # Optional fallback to the 0.9.2-compatible pure-PyTorch path.
    if br_envs.VLLM_BR_USE_MROPE_0_9_2:
        return forward_MRotaryEmbedding_0_9_2(self, positions, query, key)

    assert positions.ndim == 1 or positions.ndim == 2
    # Device predicates: 'supa' is the Biren accelerator device string.
    data_in_supa = lambda t: str(t.device).startswith('supa')
    data_in_cpu = lambda t: t.device == torch.device('cpu')

    # NOTE(review): when positions.ndim == 1, no branch below assigns
    # cos/sin, so the kernel call would raise NameError. Presumably MRoPE
    # callers always pass [3, num_tokens] positions here — confirm.
    if positions.ndim == 2:
        # use bypass for decode stage
        if (positions.shape[1] == 1):
            # Single-token decode: all three T/H/W rows index the same
            # position, so one row of the cache suffices.
            cos_sin = self.cos_sin_cache[positions]
            cos, sin = cos_sin.chunk(2, dim=-1)
            cos = cos[0]
            sin = sin[0]
        else:
            cos_sin = self.cos_sin_cache[positions.to(torch.int64)]
            cos, sin = cos_sin.chunk(2, dim=-1)
            assert self.mrope_section

            if self.mrope_interleaved:
                cos = apply_interleaved_rope(cos, self.mrope_section)
                sin = apply_interleaved_rope(sin, self.mrope_section)
            else:
                # Reassemble the per-section (T/H/W) slices into one table.
                cos = torch.cat([
                    m[i] for i, m in enumerate(
                        cos.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
                sin = torch.cat([
                    m[i] for i, m in enumerate(
                        sin.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
    # Move CPU-resident tables/positions to the accelerator when q/k are
    # already there, so the fused kernel sees uniform device placement.
    if data_in_supa(query) and data_in_supa(key):
        sin = sin.supa() if data_in_cpu(sin) else sin
        cos = cos.supa() if data_in_cpu(cos) else cos
        positions = positions.supa() if data_in_cpu(positions) else positions

    query, key = torch_br.supa_rope_infer_v2(query,
                                             key,
                                             sin.to(torch.float32),
                                             cos.to(torch.float32),
                                             positions.to(torch.int32),
                                             self.head_size,
                                             rope_type="MRope")
    return query, key
|
||||
|
||||
|
||||
# Route all MRotaryEmbedding forwards through the supa-accelerated path.
MRotaryEmbedding.forward = forward_supa
|
||||
|
||||
|
||||
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    op_type: str = "Half",
) -> RotaryEmbedding:
    """Create (or fetch from cache) the rotary embedding for a model config.

    Dispatches on ``rope_scaling["rope_type"]`` to the matching
    RotaryEmbedding subclass; instances are memoized in ``_ROPE_DICT``.

    BUGFIX: ``op_type`` (added by this fork) is now part of the cache key.
    Previously two configs with identical geometry but different op_type
    would alias to a single cached instance.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None

    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    # Cache key: every argument that influences construction, including
    # the fork-specific op_type.
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype, op_type)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style, dtype,
                                              **extra_kwargs)
    elif not rope_scaling:
        # Plain RoPE is the only class that receives op_type directly.
        rotary_emb = RotaryEmbedding(head_size,
                                     rotary_dim,
                                     max_position,
                                     base,
                                     is_neox_style,
                                     dtype,
                                     op_type=op_type)
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                # Multimodal rotary embedding keeps its cache in float32.
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype=torch.float32,
                    mrope_section=rope_scaling["mrope_section"],
                    mrope_interleaved=rope_scaling.get("mrope_interleaved",
                                                       False),
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style,
                                                   scaling_factor, dtype,
                                                   mixed_b)
        elif scaling_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "deepseek_yarn_supa":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = SupaDeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
|
||||
|
||||
|
||||
def deepseek_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """get_rope wrapper installed into vLLM's deepseek_v2 module.

    Forwards every argument unchanged and pins ``op_type="DeepSeek"`` so the
    DeepSeek kernel variant is selected.
    """
    return get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                    rope_scaling, dtype, partial_rotary_factor,
                    dual_chunk_attention_config, "DeepSeek")
|
||||
|
||||
|
||||
def chatglm2_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """get_rope wrapper installed into vLLM's chatglm module.

    NOTE(review): this pins ``op_type="DeepSeek"``, identical to
    deepseek_get_rope — confirm ChatGLM is really meant to use the DeepSeek
    kernel variant rather than a ChatGLM-specific op_type.
    """
    return get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                    rope_scaling, dtype, partial_rotary_factor,
                    dual_chunk_attention_config, "DeepSeek")
|
||||
|
||||
|
||||
# Install the fork's get_rope into vLLM, plus the model-specific wrappers
# that pin an op_type.
vllm.model_executor.layers.rotary_embedding.get_rope = get_rope
vllm.model_executor.models.deepseek_v2.get_rope = deepseek_get_rope
vllm.model_executor.models.chatglm.get_rope = chatglm2_get_rope
|
||||
|
||||
|
||||
@patch_to(MRotaryEmbedding)
def _glm4v_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value for GLM4V.

    Returns a [3, seq_len - context_len] tensor of T/H/W positions and the
    mrope position delta (max position + 1 minus prompt length).
    """

    image_token_id = hf_config.image_token_id
    video_start_token_id = hf_config.video_start_token_id
    video_end_token_id = hf_config.video_end_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    llm_pos_ids_list: list = []

    if not (image_grid_thw is None and video_grid_thw is None):
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()

        # Classify each token: image placeholders between video start/end
        # markers count as "video" frames; everything else is text.
        input_token_type: list[str] = []
        video_check_flg = False
        for token in input_tokens:
            if token == video_start_token_id:
                video_check_flg = True
            elif token == video_end_token_id:
                video_check_flg = False

            if (token == image_token_id) and (video_check_flg is False):
                input_token_type.append("image")
            elif (token == image_token_id) and (video_check_flg is True):
                input_token_type.append("video")
            else:
                input_token_type.append("text")

        # Collapse consecutive tokens of the same modality into
        # (modality, start, end) runs.
        input_type_group: list[tuple[str, int, int]] = []
        for key, group_iter in itertools.groupby(enumerate(input_token_type),
                                                 lambda x: x[1]):
            group_list = list(group_iter)
            start_index = group_list[0][0]
            end_index = group_list[-1][0] + 1
            input_type_group.append((key, start_index, end_index))

        video_frame_num = 1
        mm_data_idx = 0
        for modality_type, start_idx, end_idx in input_type_group:
            # Positions of each run continue after the max position seen so
            # far (0 for the first run).
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            if modality_type == "image":
                t, h, w = (
                    image_grid_thw[mm_data_idx][0],
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                # H/W shrink by the vision tower's spatial merge factor.
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size

                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                    -1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                    llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                    llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1

            elif modality_type == "video":
                # NOTE(review): video runs also read their H/W from
                # image_grid_thw (not video_grid_thw), with the temporal size
                # taken from the running video_frame_num — confirm video
                # frames are delivered through image_grid_thw for GLM4V.
                t, h, w = (
                    video_frame_num,
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size

                for t_idx in range(llm_grid_t):
                    t_index = torch.tensor(t_idx).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        1, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        1, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)

                mm_data_idx += 1
                video_frame_num += 1

            else:
                # Text run: T, H and W advance together.
                text_len = end_idx - start_idx
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                # A text run resets the running video frame counter.
                video_frame_num = 1

    else:
        # Text-only prompt: plain sequential positions on all three rows.
        text_len = len(input_tokens)
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1))

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:seq_len]
    mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
    return llm_positions, mrope_position_delta
|
||||
|
||||
|
||||
@patch_to(MRotaryEmbedding)
def get_input_positions_tensor_for_glm(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Dispatch mrope position computation by model family.

    Omni ("thinker") configs go to the omni helper, GLM4V models to the
    GLM4V helper above, everything else to the generic VL helper.
    """
    # Local import to avoid a module-level import cycle with vLLM.
    from vllm.transformers_utils.config import thinker_uses_mrope
    if thinker_uses_mrope(hf_config):
        return cls._omni_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )
    elif "glm4v" in hf_config.model_type:
        # NOTE(review): only this branch passes `cls` explicitly; the other
        # two rely on binding. Whether that is correct depends on how
        # patch_to binds _glm4v_get_input_positions_tensor — confirm `cls`
        # is not passed twice at runtime.
        return cls._glm4v_get_input_positions_tensor(
            cls,
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            context_len=context_len,
            seq_len=seq_len,
        )
    else:
        return cls._vl_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
        )
|
||||
65
vllm_br/model_executor/layers/utils.py
Normal file
65
vllm_br/model_executor/layers/utils.py
Normal file
@@ -0,0 +1,65 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
|
||||
|
||||
|
||||
def apply_penalties_fit(logits: torch.Tensor,
                        prompt_tokens_tensor: torch.Tensor,
                        output_tokens_tensor: torch.Tensor,
                        presence_penalties: torch.Tensor,
                        frequency_penalties: torch.Tensor,
                        repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition/frequency/presence penalties in place to ``logits``.

    logits: [num_seqs, vocab_size]; modified in place and also returned.
    prompt_tokens_tensor: prompt tokens, padded with ``vocab_size`` (an
        out-of-vocabulary value) up to the batch's max prompt length.
    output_tokens_tensor: generated tokens so far.
    presence_penalties / frequency_penalties / repetition_penalties:
        per-sequence tensors of shape (num_seqs,).
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Repetition penalty applies only to tokens already seen in the prompt
    # or the output; unseen tokens get the neutral factor 1.0.
    seen = prompt_mask | output_mask
    per_token_penalty = repetition_penalties.unsqueeze(dim=1).repeat(
        1, vocab_size)
    effective = torch.where(seen, per_token_penalty, 1.0)
    # Positive logits shrink (divide), negative logits grow more negative
    # (multiply) — the standard repetition-penalty convention.
    logits *= torch.where(logits > 0, 1.0 / effective, effective)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits
|
||||
|
||||
|
||||
# Replace vLLM's apply_penalties with the in-place variant defined above.
vllm.model_executor.layers.utils.apply_penalties = apply_penalties_fit
|
||||
139
vllm_br/model_executor/layers/vocab_parallel_embedding.py
Normal file
139
vllm_br/model_executor/layers/vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,139 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_br
|
||||
import torch_br.supa._debug as supa_debug
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
@patch_to(UnquantizedEmbeddingMethod)
def process_weights_after_loading(self, module):
    """When VLLM_BR_EMBEDDING_S0B is set, re-materialize the embedding
    weight as a supa 'SB'-sharded colmajor tensor, preserving its values."""
    if not envs.VLLM_BR_EMBEDDING_S0B:
        return
    # Keep a host copy so the values survive the storage swap.
    host_copy = module.weight.data.cpu()
    module.weight.data = torch_br._empty_ut_only(module.weight.shape,
                                                 "colmajor",
                                                 False,
                                                 0,
                                                 dtype=module.weight.dtype,
                                                 sbp='SB')
    module.weight.data.copy_(host_copy)
|
||||
|
||||
|
||||
@patch_to(UnquantizedEmbeddingMethod)
def embedding(self, layer: torch.nn.Module,
              input_: torch.Tensor) -> torch.Tensor:
    """Embedding lookup.

    Uses the supa out_embedding kernel when the S0B-sharded weight layout
    is enabled; otherwise falls back to plain F.embedding.
    """
    if not envs.VLLM_BR_EMBEDDING_S0B:
        return F.embedding(input_, layer.weight)
    # Kernel output buffer: a leading batch axis of 1 that is squeezed off
    # before returning.
    out = torch_br._empty_ut_only(
        [1, input_.shape[0], layer.weight.shape[-1]],
        is_numa=False,
        dtype=layer.weight.dtype,
        sbp='BB',
        tensor_type="colmajor",
    )
    torch_br.out_embedding(out, layer.weight.data, input_, -1, -1)
    out.squeeze_(0)
    return out
|
||||
|
||||
|
||||
@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map global token ids to this shard's local ids.

    Returns (local_ids, invalid_mask): ids outside both vocab ranges are
    mapped to 0 and flagged in the mask so callers can zero their rows.
    NOTE(review): reading ``envs`` inside a @torch.jit.script function
    relies on TorchScript resolving that module attribute at script time —
    confirm this compiles on the deployed torch version.
    """
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    if spc_num > 16:
        # Pointwise formulation: shift original-vocab and added-vocab ids
        # into one contiguous local index space.
        org_vocab_mask = (input_ >= org_vocab_start_index) & (
            input_ < org_vocab_end_index)
        added_vocab_mask = (input_ >= added_vocab_start_index) & (
            input_ < added_vocab_end_index)
        added_offset = added_vocab_start_index - (
            org_vocab_end_index -
            org_vocab_start_index) - num_org_vocab_padding
        valid_offset = (org_vocab_start_index *
                        org_vocab_mask) + (added_offset * added_vocab_mask)
        vocab_mask = org_vocab_mask | added_vocab_mask
        input_ = vocab_mask * (input_ - valid_offset)
        return input_, ~vocab_mask
    else:
        # Small-SPC devices use the dedicated fused kernel instead.
        input_, inv_vocab_mask = torch_br.supa_embedding_mask_infer(
            input_, org_vocab_start_index, org_vocab_end_index,
            num_org_vocab_padding, added_vocab_start_index,
            added_vocab_end_index)
        return input_, inv_vocab_mask
|
||||
|
||||
|
||||
# Replace vLLM's helper with the supa-aware version defined above.
vllm.model_executor.layers.vocab_parallel_embedding.get_masked_input_and_mask = get_masked_input_and_mask
|
||||
|
||||
|
||||
def vocab_parallel_embedding_forward(self, input_) -> torch.Tensor:
    """Tensor-parallel embedding forward.

    With TP > 1: map global ids to shard-local ids, look up, zero the rows
    for ids that live on other shards, then all-reduce so every rank ends
    up with the full embedding.
    """
    if self.tp_size > 1:
        # Build the mask of ids that belong to this shard.
        masked_input, input_mask = get_masked_input_and_mask(
            input_,
            self.shard_indices.org_vocab_start_index,
            self.shard_indices.org_vocab_end_index,
            self.shard_indices.num_org_vocab_padding,
            self.shard_indices.added_vocab_start_index,
            self.shard_indices.added_vocab_end_index,
        )
    else:
        masked_input, input_mask = input_, None
    # Shard-local lookup.
    output_parallel = self.quant_method.embedding(self, masked_input.long())
    if self.tp_size > 1:
        # Zero rows for tokens owned by other shards before reducing.
        output_parallel.masked_fill_(input_mask.unsqueeze(-1),
                                     0)  # type: ignore
    # Reduce across all the model parallel GPUs.
    return tensor_model_parallel_all_reduce(output_parallel)
|
||||
|
||||
|
||||
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Linear apply for the (unquantized) embedding head.

    A 3-D weight indicates the NUMA-split layout handled by the fused-MLP
    kernel; otherwise fall back to F.linear with the sublas API enabled
    only for the duration of the call.
    """
    # numa weight is 3-dims
    if len(layer.weight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        return torch_br.br_fused_mlp_infer(
            x, [layer.weight],
            output_w=output_size,
            bias=[bias] if bias is not None else None)
    supa_debug.set_enable_sublas_api(True)
    try:
        return F.linear(x, layer.weight, bias)
    finally:
        # BUGFIX: restore the global sublas flag even if F.linear raises;
        # previously an exception left the API permanently enabled.
        supa_debug.set_enable_sublas_api(False)
|
||||
|
||||
|
||||
# Wire the patched apply/forward into vLLM's embedding classes.
UnquantizedEmbeddingMethod.apply = apply

VocabParallelEmbedding.forward = vocab_parallel_embedding_forward
|
||||
Reference in New Issue
Block a user