first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,25 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import vllm_br.model_executor.layers.activation
import vllm_br.model_executor.layers.fused_moe
import vllm_br.model_executor.layers.layernorm
import vllm_br.model_executor.layers.linear
import vllm_br.model_executor.layers.logits_processor
import vllm_br.model_executor.layers.quantization
import vllm_br.model_executor.layers.rotary_embedding
import vllm_br.model_executor.layers.utils
import vllm_br.model_executor.layers.vocab_parallel_embedding # noqa: F401

View File

@@ -0,0 +1,31 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import torch_br
from fastcore.basics import patch_to
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
@patch_to(SiluAndMul)
def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
    """SUPA out-of-tree SiluAndMul: split the last dim in half and run
    the fused torch_br silu-mul kernel on the two halves."""
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return torch_br.supa_silumul(gate, up)  # type: ignore
@patch_to(QuickGELU)
def quick_gelu_forward_oot(self, x: torch.Tensor) -> torch.Tensor:  # noqa:F811
    # No dedicated SUPA kernel for QuickGELU: delegate to vLLM's
    # reference (native) implementation.
    return self.forward_native(x)

View File

@@ -0,0 +1,619 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import torch_br
import torch_br.supa._debug as supa_debug
from vllm_br import envs
def align_n(n, align_size, spc_num=envs.VLLM_BR_DEVICE_SPC_NUM) -> int:
    """Split `n` across `spc_num` partitions and round the per-partition
    width up to a multiple of `align_size`."""
    per_spc = -(-n // spc_num)  # ceil(n / spc_num)
    return -(-per_spc // align_size) * align_size
def _br_qweight_cvt(quant_method,
qweight,
qzeros,
size_k,
size_n,
override_group_size=None):
group_size = override_group_size or quant_method.quant_config.group_size
curr_dev = qweight.device
group_num = size_k // group_size if group_size > 0 else 1
qweight = qweight.cpu().view(torch.int8).reshape(
size_k // 4, size_n,
4).permute(0, 2, 1).contiguous().reshape(group_num,
size_k // group_num, size_n)
if qzeros is not None and not torch.all(qzeros == 0):
qzeros = qzeros.cpu().view(torch.int8).to(torch.int32) + 1
qweight = (qweight.to(torch.int32) - qzeros.unsqueeze(1)).to(
torch.int8)
qwei_int8 = qweight.reshape(size_k, size_n).to(curr_dev)
return qwei_int8
def _numa_scales_cvt(scales, wn, spc_num):
align_size = 32
wn_block = (wn + spc_num - 1) // spc_num
wn_block = (wn_block + align_size - 1) // align_size * align_size
cvt_scales = torch.nn.functional.pad(scales, (0, spc_num * wn_block - wn),
mode='constant',
value=0)
cvt_scales = cvt_scales.reshape(spc_num, wn_block).contiguous()
return cvt_scales
def cross_weight_32(t1, t2, spc_num, dim=1, need_pad=True):
    """Interleave `t1` and `t2` in 32-wide chunks along `dim`.

    Both inputs are first zero-padded (along the last dim) to a 32-aligned
    width — on dual-die parts (spc_num > 16) each half is aligned
    separately — then their 32-wide chunks are alternated.  With
    `need_pad`, the interleaved result is further zero-padded so it
    divides evenly across the SPCs (per die on dual-die parts).
    """

    def _ceil32(v):
        return (v + 31) // 32 * 32

    def _pad_last(t, amount):
        return torch.nn.functional.pad(t, (0, amount), "constant", 0)

    width = t1.shape[dim]
    if spc_num > 16:
        # NOTE: br166 must ensure dual-dies width are 32-aligned
        assert width % 2 == 0
        half = width // 2
        gap = _ceil32(half) - half
        if gap > 0:
            t1 = torch.cat(
                [_pad_last(h, gap) for h in torch.chunk(t1, 2, dim=-1)],
                dim=-1)
            t2 = torch.cat(
                [_pad_last(h, gap) for h in torch.chunk(t2, 2, dim=-1)],
                dim=-1)
        width = _ceil32(half) * 2
    else:
        aligned = _ceil32(width)
        t1 = _pad_last(t1, aligned - width)
        t2 = _pad_last(t2, aligned - width)
        width = aligned
    cnt = width // 32
    interleaved = [
        piece
        for pair in zip(torch.chunk(t1, cnt, dim), torch.chunk(t2, cnt, dim))
        for piece in pair
    ]
    no_pad = torch.cat(interleaved, dim=dim)
    if not need_pad:
        return no_pad
    if spc_num > 16:
        # Pad each die's half independently.
        align = (spc_num // 2) * 32 * 2
        pad_size = (width + align - 1) // align * align - width
        halves = torch.chunk(no_pad, 2, dim=-1)
        return torch.cat([_pad_last(h, pad_size) for h in halves], dim=-1)
    align = spc_num * 32 * 2  # 768
    pad_size = (width * 2 + align - 1) // align * align - width * 2
    return torch.nn.functional.pad(no_pad, (0, pad_size), "constant", 0)
# # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
def _convert_to_uma_tensor(tensor,
                           align_size,
                           layout,
                           dtype,
                           do_transpose=False,
                           wk=None,
                           wn=None,
                           parallel_type="col_parallel"):
    """Copy `tensor` into a freshly allocated UMA device tensor.

    Two layouts are supported:
      * "colmajor"    - a 2-D weight; with `do_transpose` the axes are
        swapped before the copy and the allocated shape flipped to match.
      * "linear_bias" - a 1-D bias, or a 2-D/3-D tensor squeezed down to
        (wk, wn) / (wn,), sharded along the chosen axis.

    NOTE(review): `align_size` is accepted but never used in this
    function; it looks like it is only kept for signature parity with the
    NUMA converters — confirm before removing.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    layout = layout.lower()
    if layout == "colmajor":
        # Infer extents from the tensor when not supplied by the caller.
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[0]
        d_shape = (wn, wk)
        if do_transpose:
            data = tensor.cpu().permute(1, 0).contiguous()
            d_shape = (wk, wn)
        else:
            data = tensor.cpu().contiguous()
        # Column-parallel weights shard on axis 0; row-parallel on axis 1.
        if parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                sbp="SB",
                axis=0)
        else:
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=1,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    elif layout == "linear_bias":
        axis = 0
        wn = wn or tensor.shape[-1]
        wk = 1
        data = tensor
        # Squeeze leading/middle singleton dims; otherwise keep a 2-D
        # (wk, wn) bias and shard it along axis 1.
        if len(data.shape) == 2 and data.shape[0] == 1:
            data = tensor.cpu().reshape(-1).contiguous()
        elif len(data.shape) == 2:
            axis = 1
            wk = data.shape[0]
        elif len(data.shape) == 3 and data.shape[1] == 1:
            data = tensor.cpu().reshape(
                (data.shape[0], data.shape[2])).contiguous()
            axis = 1
            wk = data.shape[0]
        d_shape = (wn, ) if axis == 0 else (wk, wn)
        if parallel_type == "row_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout)
        elif parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=axis,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    else:
        raise ValueError("uma tensor only support colmajor and linear_bias")
    return uma_tensor
def _convert_to_numa_tensor_vit(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """ViT variant: copy `tensor` into a per-SPC-blocked NUMA tensor.

    "colmajor" weights are zero-padded along N so every SPC owns an
    `align_size`-aligned block and laid out as [spc, wk, wn_block]; on
    dual-die parts (VLLM_BR_DEVICE_SPC_NUM > 16) the weight is split
    across the two dies along N (col_parallel) or K (row_parallel).
    "linear_bias" tensors are padded/blocked the same way and stored as
    float32 rows per SPC in the dual-die branches.

    The force-UMA debug flag is disabled for the duration so the
    allocation really lands in NUMA memory, and restored on exit.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Save and disable the debug "force UMA" switch around the allocation.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: treat each die as half of the SPCs.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Per-SPC block width along N, rounded up to align_size.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            # [wk, spc*wn_block] -> [spc, wk, wn_block]
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Split N across the two dies, then block per SPC.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # row_parallel: split K across the dies, block N per SPC.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        group_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        if die_num == 1:
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num * group_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            data = torch.nn.functional.pad(tensor.cpu(),
                                           (0, spc_num * wn_block - wn),
                                           mode='constant',
                                           value=0)
            if group_num > 1:
                # Interleave groups so each SPC's rows are contiguous.
                data = data.type(dtype).reshape(
                    group_num, spc_num,
                    wn_block).permute(1, 0, 2).contiguous().reshape(
                        spc_num * group_num, wn_block)
            else:
                data = data.type(dtype).reshape(spc_num, wn_block).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = tensor.cpu().reshape(die_num, wn // die_num)
                bias = torch.nn.functional.pad(
                    bias, (0, spc_num * wn_block - wn // die_num, 0, 0),
                    mode='constant',
                    value=0)
                bias = bias.type(torch.float32).reshape(die_spc_num,
                                                        wn_block).contiguous()
                numa_tensor.copy_(bias)
            else:
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                bias = bias.type(torch.float32).reshape(spc_num,
                                                        wn_block).contiguous()
                # Second die gets zeros (pad_zeros) or a duplicate of the
                # bias, doubling the first axis to die_spc_num.
                if pad_zeros:
                    bias_zeros_die2 = torch.zeros((spc_num, wn_block),
                                                  dtype=bias.dtype)
                    bias = torch.concat([bias, bias_zeros_die2], dim=0)
                else:
                    bias = torch.concat([bias, bias], dim=0)
                numa_tensor.copy_(bias)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    # Restore the debug flag saved on entry.
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
def _convert_to_numa_tensor(tensor,
                            align_size,
                            layout,
                            dtype,
                            do_transpose=False,
                            wk=None,
                            wn=None,
                            parallel_type="col_parallel",
                            pad_zeros=False):
    """Copy `tensor` into the per-SPC NUMA layout used by the LLM path.

    "colmajor" handling matches `_convert_to_numa_tensor_vit` except that
    the dual-die allocations carry explicit sharding hints
    (axis=0, sbp="SS").  "linear_bias" tensors are NOT blocked per SPC
    here: they stay UMA (is_numa=False), either sharded ("SB") for
    col_parallel or broadcast ("BB") otherwise.

    NOTE(review): `pad_zeros` and, in the linear_bias branch,
    `align_size` are unused in this variant — confirm they are kept only
    for signature parity with the ViT converter.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Save and disable the debug "force UMA" switch around the allocation.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        # Dual-die part: treat each die as half of the SPCs.
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    if layout == "colmajor":
        wk = wk or tensor.shape[0]
        wn = wn or tensor.shape[1]
        if die_num == 1:
            # Per-SPC block width along N, rounded up to align_size.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(spc_num, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                data = tensor.cpu().permute(1, 0).contiguous()
            else:
                data = tensor.cpu().contiguous()
            data = torch.nn.functional.pad(data,
                                           (0, spc_num * wn_block - wn, 0, 0),
                                           mode='constant',
                                           value=0)
            # [wk, spc*wn_block] -> [spc, wk, wn_block]
            data = data.reshape(wk, spc_num, wn_block).permute(1, 0,
                                                               2).contiguous()
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(torch.supa.current_device()))
        else:
            if parallel_type == "col_parallel":
                # Split N across the two dies, then block per SPC.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # row_parallel: split K across the dies, block N per SPC.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
    elif layout == "linear_bias":
        # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
        # NOTE: index -1 for both scales and bias
        wn = tensor.shape[-1] if wn is None else wn
        expert_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
        bias_shape = (expert_num, wn) if expert_num > 1 else (wn, )
        if die_num == 1:
            numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                  dtype=dtype,
                                                  is_numa=False,
                                                  device=tensor.device,
                                                  tensor_type=layout)
            data = tensor.cpu().type(dtype)
            if expert_num > 1:
                data = data.reshape(expert_num, wn)
            else:
                data = data.reshape(wn).type(dtype)
            torch.supa.synchronize()
            numa_tensor.copy_(data.to(tensor.device))
        else:
            if parallel_type == "col_parallel":
                # Shard per-expert biases along the width axis.
                axis = 1 if expert_num > 1 else 0
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type="buffer_any",
                                                      axis=axis,
                                                      sbp="SB")
                if expert_num == 1:
                    tensor = tensor.reshape(-1)
                numa_tensor.copy_(tensor.to(torch.supa.current_device()))
            else:
                # row_parallel: broadcast the bias to both dies ("BB").
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="linear_bias",
                    sbp="BB")
                bias = tensor.reshape(expert_num, wn).cpu().type(dtype)
                if expert_num == 1:
                    bias = bias.reshape(-1)
                numa_tensor.copy_(bias.to(torch.supa.current_device()))
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    # Restore the debug flag saved on entry.
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
def _convert_to_crossed_numa_tensor(t1,
                                    t2,
                                    spc_num,
                                    dim=1,
                                    need_pad=True,
                                    layout="colmajor",
                                    do_transpose=False):
    """Cross-interleave two weights, then lay the result out as a NUMA tensor.

    Equals to V0: cross_weight_32 + numa_weight_convert/_numa_weight_cvt
    """
    crossed = cross_weight_32(t1, t2, spc_num, dim, need_pad)
    return _convert_to_numa_tensor(crossed, 32, layout, crossed.dtype,
                                   do_transpose)
def _convert_to_numa_tensor_moe(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wb=None,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """MoE variant: block a batched [wb, wk, wn] expert weight per SPC.

    Dual-die only (asserts die_num == 2).  col_parallel splits N across
    the dies; row_parallel splits K.  Returns the NUMA tensor together
    with its per-expert logical shape (die_spc_num, wk, wn_block).

    NOTE(review): `pad_zeros` is unused here — presumably kept for
    signature parity with the other converters; confirm.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    # Save and disable the debug "force UMA" switch around the allocation.
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    layout = layout.lower()
    die_num = 1
    if spc_num > 16:
        spc_num = spc_num // 2
        die_num = 2
    die_spc_num = die_num * spc_num
    # This converter only supports the dual-die parts.
    assert die_num == 2
    if layout == "colmajor":
        wb = wb or tensor.shape[0]
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[2]
        if parallel_type == "col_parallel":
            # Split N across the dies, then block per SPC.
            wn_block = (wn // die_num + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            else:
                weight = tensor.cpu().contiguous().reshape(
                    wb, wk, die_num, wn // die_num)
            weight = torch.nn.functional.pad(
                weight,
                (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0, 0, 0),
                mode='constant',
                value=0)
            # SPC-major over the expert batch: [spc, wb, wk, wn_block].
            weight = weight.reshape(wb, wk, die_spc_num, wn_block).permute(
                2, 0, 1, 3).reshape(wb * die_spc_num, wk,
                                    wn_block).contiguous()
            numa_tensor.copy_(weight)
        elif parallel_type == "row_parallel":
            # Split K across the dies, block N per SPC.
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk // die_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            if do_transpose:
                weight = tensor.cpu().permute(0, 2, 1).contiguous()
            else:
                weight = tensor.cpu().contiguous()
            weight = torch.nn.functional.pad(
                weight, (0, spc_num * wn_block - wn, 0, 0, 0, 0),
                mode='constant',
                value=0)
            weight = weight.reshape(wb, die_num, wk // die_num, spc_num,
                                    wn_block).permute(1, 3, 0, 2,
                                                      4).contiguous().reshape(
                                                          die_spc_num * wb,
                                                          wk // die_num,
                                                          wn_block)
            numa_tensor.copy_(weight)
    else:
        raise ValueError(f"Unsupported tensor_type: {layout}")
    torch.supa.synchronize()
    # Restore the debug flag saved on entry.
    supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor, (die_spc_num, wk, wn_block)
def is_br166_device():
    """Return True when the current SUPA device reports 17..32 compute
    units (the dual-die br166 part)."""
    units = torch_br.supa.get_device_properties(
        torch.device("supa")).max_compute_units
    return 16 < units <= 32

View File

@@ -0,0 +1,23 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import layer, supa_moe # noqa: E402
from .layer import * # noqa: E402
# Public surface of the SUPA fused-MoE package: the patched layer module
# and the supa_moe kernel module.
__all__ = [
    "layer",
    "supa_moe",
]

View File

@@ -0,0 +1,413 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Callable, Optional
import torch
import torch_br
from fastcore.basics import patch_to
from torch_br.utils.tensor_methods import Sbp
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.utils import set_weight_attrs
from vllm_br import envs
from ..br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, align_n, cross_weight_32)
from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn,
fused_oss_moe_dyn)
@patch_to(UnquantizedFusedMoEMethod)
def forward_oot(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    """Forward for UnquantizedFusedMoEMethod with SUPA out-of-tree support.

    Dispatches to one of three SUPA kernels:
      * "swigluoai" activation        -> fused_oss_moe_dyn
      * long inputs (prefill length)  -> fused_moe_quant_dyn
      * short inputs (decode length)  -> fused_moe_quant_device
    """
    # Top-k routing options shared by every kernel below.
    topk_kwargs = dict(
        renormalize=renormalize,
        inplace=True,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )
    if activation == "swigluoai":
        return fused_oss_moe_dyn(
            x,
            layer.w13_weight,
            layer.w13_bias,
            layer.w2_weight,
            layer.w2_bias,
            router_logits,
            top_k,
            layer.intermediate_size_per_partition,
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size,
            **topk_kwargs)
    # On this path router_logits carries the gating weight plus the
    # shared-expert projection weights (see the patched FusedMoE.forward).
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
    # Sequence length decides prefill (dyn) vs. decode (device) kernel.
    is_prefill = x.shape[0] > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
    kernel = fused_moe_quant_dyn if is_prefill else fused_moe_quant_device
    return kernel(
        x,
        shared_gate_up_weight,
        shared_down_weight,
        layer.w13_weight,
        layer.w2_weight,
        None,
        None,
        gating_weight,
        top_k,
        layer.intermediate_size_per_partition,
        tp_rank=get_tp_group().rank_in_group,
        global_rank=get_tp_group().rank,
        tp_size=get_tensor_model_parallel_world_size(),
        ep_rank=layer.ep_rank,
        ep_size=layer.ep_size,
        **topk_kwargs)
@patch_to(UnquantizedFusedMoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate CPU-resident unquantized MoE weights on `layer`.

    Registers w13_weight (fused gate/up proj, column parallel) and
    w2_weight (down proj, row parallel), plus matching zero-initialized
    biases when the MoE config has them.
    """

    def _register(name, shape, zero_init):
        # Weights stay uninitialized; biases start at zero.
        factory = torch.zeros if zero_init else torch.empty
        param = torch.nn.Parameter(factory(*shape,
                                           device="cpu",
                                           dtype=params_dtype),
                                   requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    # Fused gate_up_proj (column parallel)
    _register("w13_weight",
              (num_experts, 2 * intermediate_size_per_partition, hidden_size),
              zero_init=False)
    if self.moe.has_bias:
        _register("w13_bias",
                  (num_experts, 2 * intermediate_size_per_partition),
                  zero_init=True)
    # down_proj (row parallel)
    _register("w2_weight",
              (num_experts, hidden_size, intermediate_size_per_partition),
              zero_init=False)
    if self.moe.has_bias:
        _register("w2_bias", (num_experts, hidden_size), zero_init=True)
@patch_to(UnquantizedFusedMoEMethod)
def process_weights_after_loading(self: UnquantizedFusedMoEMethod,
                                  layer: FusedMoE) -> None:
    """Repack the CPU-loaded MoE weights into SUPA NUMA layouts, in place.

    For each local expert, w13 is either copied straight into a per-SPC
    colmajor block (swigluoai) or gate/up cross-interleaved first; w2 is
    repacked row-parallel.  Biases stay UMA (w13) or plain device tensors
    (w2).  The repacked tensors replace the parameters' `.data`.
    """
    cur_device = torch.supa.current_device()
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    die_num = 1 if die_spc_num <= 16 else 2
    spc_num = die_spc_num // die_num
    # swigluoai skips the gate/up interleave and uses a tighter alignment.
    align_size = 32 if layer.activation == "swigluoai" else 64
    is_dual_die = (die_spc_num > 16)
    # NOTE: w13_weight
    # after _load_w13, w13_weight is a colparallel weight, shape
    # [num_experts, 2 * intermediate_size_per_partition, hidden_size]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block] (wn = aligned(2 * intermediate_size_per_partition, align_size=64))
    wk = layer.hidden_size
    wn_block = align_n((layer.intermediate_size_per_partition * 2) // die_num,
                       align_size=align_size,
                       spc_num=spc_num)
    supa_w13_weight = torch_br._empty_ut_only(
        size=(die_spc_num * layer.local_num_experts, wk, wn_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=cur_device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if is_dual_die else None)
    for expert_id in range(layer.local_num_experts):
        expert_w13 = layer.w13_weight[expert_id].transpose(0, 1).contiguous()
        # swigluoai activation, no need do interweave
        if layer.activation and layer.activation == "swigluoai":
            pad_expert_w13 = _convert_to_numa_tensor(expert_w13, align_size,
                                                     'COLMAJOR',
                                                     expert_w13.dtype)
            pad_expert_w13_shape = pad_expert_w13.shape
            # Offset of this expert's block in the flat NUMA buffer.
            hw_size = pad_expert_w13_shape[-2] * pad_expert_w13_shape[-1]
            narrow_data = supa_w13_weight.view_as_usharp(
                "COLMAJOR", pad_expert_w13_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)
        else:
            # Cross-interleave gate (w1) and up (w3) halves in 32-wide
            # chunks before the NUMA repack.
            expert_1, expert_3 = expert_w13.chunk(2, dim=1)
            pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
                                                             expert_3,
                                                             die_spc_num,
                                                             dim=1,
                                                             need_pad=True,
                                                             layout='COLMAJOR')
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)
    layer.w13_weight.data = supa_w13_weight
    # NOTE: w13_bias
    if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
        wn = layer.intermediate_size_per_partition * 2
        supa_w13_bias = torch_br._empty_ut_only(
            size=(layer.local_num_experts, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="linear_bias",
            sbp="BB" if is_dual_die else None)
        for expert_id in range(layer.local_num_experts):
            expert_w13_bias = layer.w13_bias[expert_id]
            # swigluoai activation, no need do interweave
            if layer.activation and layer.activation == "swigluoai":
                narrow_data = supa_w13_bias[expert_id]
                narrow_data.copy_(expert_w13_bias)
            else:
                # Bias must follow the same gate/up interleave as w13.
                expert_1_bias, expert_3_bias = expert_w13_bias.chunk(2, dim=-1)
                crossed_expert_w13_bias = cross_weight_32(
                    expert_1_bias,
                    expert_3_bias,
                    die_spc_num,
                    dim=0,
                    need_pad=False,
                )
                narrow_data = supa_w13_bias[expert_id]
                narrow_data.copy_(crossed_expert_w13_bias)
        layer.w13_bias.data = supa_w13_bias
    # NOTE: w2_weight
    # after _load_w2, w2_weight is a rowparallel weight, shape
    # [num_experts, hidden_size, intermediate_size_per_partition]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block]
    align_size = 32
    wk = layer.intermediate_size_per_partition
    wn_block = align_n(layer.hidden_size,
                       align_size=align_size,
                       spc_num=spc_num)
    supa_w2_weight = torch_br._empty_ut_only(
        size=(die_spc_num * layer.local_num_experts, wk // die_num, wn_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=cur_device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if is_dual_die else None)
    for expert_id in range(layer.local_num_experts):
        expert_w2 = layer.w2_weight[expert_id].transpose(0, 1).contiguous()
        pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
                                                align_size,
                                                'COLMAJOR',
                                                expert_w2.dtype,
                                                parallel_type="row_parallel")
        pad_expert_w2_shape = pad_expert_w2.shape
        # Offset of this expert's block in the flat NUMA buffer.
        hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
        narrow_data = supa_w2_weight.view_as_usharp("COLMAJOR",
                                                    pad_expert_w2_shape,
                                                    Sbp.ss(0),
                                                    expert_id * hw_size)
        narrow_data.copy_(pad_expert_w2)
    layer.w2_weight.data = supa_w2_weight
    # NOTE: w2_bias
    if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
        wn = layer.hidden_size
        # w2 bias needs no special layout: a plain float32 device tensor.
        supa_w2_bias = torch.zeros((layer.local_num_experts, wn),
                                   dtype=torch.float32,
                                   device=cur_device)
        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_bias[expert_id]
            narrow_data = supa_w2_bias[expert_id]
            narrow_data.copy_(expert_w2)
        layer.w2_bias.data = supa_w2_bias
@patch_to(FusedMoE)
def forward(self: FusedMoE, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """
    ! router_logits is a tuple of gate, shared_experts.gate_up_proj,
    shared_experts.down_proj weights.
    """
    assert self.quant_method is not None
    assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes'
    final_hidden_states = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
    )
    # NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the final_hidden_states
    supported_combos = ((16, 4), (16, 8), (32, 2), (32, 4))
    tp_size = get_tensor_model_parallel_world_size()
    uses_ccl_kernel = (
        hidden_states.shape[0] <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
        and envs.VLLM_BR_QUANT_METHOD != "INT4"
        and envs.VLLM_BR_USE_FUSED_ALLREDUCE
        and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in supported_combos)
    if uses_ccl_kernel:
        final_hidden_states.all_reduced = True
    return final_hidden_states
@patch_to(FusedMoE)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str,
              loaded_weight: torch.Tensor, tp_rank: int):
    """Copy one gate/up shard into the fused w13 expert weight.

    gate_up_proj is "MergedColumnParallel": tp sharding happens on the
    output dim, and w13 holds gate (w1) in its first half and up (w3) in
    its second half along `shard_dim`.
    """
    half = expert_data.shape[shard_dim] // 2
    # Select this TP rank's slice of the checkpoint tensor.
    shard = loaded_weight.narrow(shard_dim, half * tp_rank, half)
    if shard_id == "w1":
        # gate_proj lands in the first logical half of w13.
        offset = 0
    else:
        assert shard_id == "w3"
        # up_proj lands in the second logical half of w13.
        offset = half
    expert_data.narrow(shard_dim, offset, half).copy_(shard.cpu())
@patch_to(FusedMoE)
def _load_w2(self,
             expert_data: torch.Tensor,
             shard_dim: int,
             loaded_weight: torch.Tensor,
             tp_rank: int,
             load_full: bool = False):
    """Copy a down_proj shard into the w2 expert weight.

    down_proj is "RowParallel": tp sharding happens on the input dim, so
    unless `load_full` is set we narrow the checkpoint tensor to this
    rank's slice before copying.
    """
    shard_size = expert_data.shape[shard_dim]
    source = (loaded_weight if load_full else loaded_weight.narrow(
        shard_dim, shard_size * tp_rank, shard_size))
    # w2 has a single logical weight — copy the whole (possibly sliced) shard.
    expert_data.copy_(source.cpu())
def wrapper_FusedMoE_init(fn):
    """Wrap FusedMoE.__init__ so the routing correction bias ends up fp32."""

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        bias = self.e_score_correction_bias
        if bias is not None:
            # SUPA router kernels consume an fp32 gating bias.
            bias.data = bias.float()

    return wrapper
# Install the wrapper above so every FusedMoE instance gets its
# e_score_correction_bias promoted to fp32 right after construction.
FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__) # noqa: E501

View File

@@ -0,0 +1,518 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from torch_br.utils.tensor_methods import Sbp
from vllm_br import envs
# gpt-oss moe forward version
def fused_oss_moe_dyn(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w13_bias: torch.Tensor,
    w2: torch.Tensor,
    w2_bias: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """Dynamic-shape MoE forward for gpt-oss style experts on Biren SUPA.

    Routes tokens with ``supa_moe_router_v2_infer``, permutes them into
    per-local-expert batches, runs the biased w13 FFN ("act_swiglu_oai")
    followed by the biased w2 FFN, then un-permutes and probability-weights
    the results.

    ``renormalize``, ``inplace``, ``use_grouped_topk``, ``num_expert_group``,
    ``topk_group``, ``custom_routing_function`` and ``scoring_func`` are
    accepted for interface compatibility but are not used in this body.

    Returns the combined expert output with an extra leading dim
    (``output.unsqueeze(0)``).
    """
    total_expert_num = gating_weight.shape[-2]
    probs_supa, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
        hidden_states,
        gating_weight,
        topk,
        ep_size,
        ep_rank,
        gating_bias=e_score_correction_bias)
    cur_device = hidden_states.device
    # Transpose the router outputs (round-tripping through CPU) before the
    # permutation kernels consume them.
    probs_supa = probs_supa.cpu().permute(1, 0).contiguous().to(cur_device)
    indices_supa = indices_supa.cpu().permute(1, 0).contiguous().to(cur_device)
    indice_per_expert = indice_per_expert.cpu().permute(
        1, 0).contiguous().to(cur_device)
    prob_per_expert = prob_per_expert.cpu().permute(
        1, 0).contiguous().to(cur_device)
    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    local_expert_num = total_expert_num // ep_size # type: ignore
    b_seq = hidden_states.shape[0]
    # Scratch index buffer, sequence length rounded up to a multiple of 64.
    indices_trans_supa = torch_br._empty_ut_only(
        size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
        dtype=torch.int32,
        is_numa=False,
        device=hidden_states.device,
        tensor_type="colmajor",
        sbp="BB" if is_dual_die else None)
    # Token count per global expert; slice out this EP rank's local experts.
    tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
                                            minlength=total_expert_num)
    tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
    )[ep_rank * local_expert_num:(ep_rank + 1) * # type: ignore
      local_expert_num]
    # Number of local experts that received at least one token.
    topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)
    if topk_per_expert > 0:
        expert_tokens = torch_br.supa_permutation_infer(
            global_hidden_states=hidden_states,
            indices=indice_per_expert,
            tokens_per_expert=tokens_per_expert_list,
            indices_trans=indices_trans_supa)
        assert len(
            expert_tokens) == local_expert_num, "Number of experts mismatch"
        gate_up_outputs = []
        down_outputs = []
        cur_device = expert_tokens[0].device
        hidden_size = expert_tokens[0].shape[-1]
        # Pre-allocate per-expert output buffers; idle experts get empty
        # placeholders so list positions still line up with expert ids.
        for i in range(local_expert_num):
            if tokens_per_expert_list[i] == 0:
                gate_up_outputs.append(
                    torch.empty(size=(0, intermediate_size),
                                dtype=torch.bfloat16,
                                device=cur_device))
                down_outputs.append(
                    torch.empty(size=(0, hidden_size),
                                dtype=torch.float32,
                                device=cur_device))
                continue
            gate_up_output = torch_br._empty_ut_only(
                size=(tokens_per_expert_list[i], intermediate_size),
                dtype=torch.bfloat16,
                is_numa=False,
                device=cur_device,
                tensor_type="colmajor",
                sbp="SB" if is_dual_die else None,
                axis=1)
            gate_up_outputs.append(gate_up_output)
            down_output = torch_br._empty_ut_only(
                size=(tokens_per_expert_list[i], hidden_size),
                dtype=torch.float32,
                is_numa=False,
                device=cur_device,
                tensor_type="colmajor",
                sbp="BB" if is_dual_die else None)
            down_outputs.append(down_output)
        # w13 pass with gpt-oss swiglu, then plain w2 pass.
        torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
                                              expert_tokens,
                                              w13,
                                              tokens_per_expert_list,
                                              max(tokens_per_expert_list),
                                              bias=w13_bias,
                                              act_mode="act_swiglu_oai")
        torch_br.supa_moe_fused_ffn_dyn_infer(down_outputs,
                                              gate_up_outputs,
                                              w2,
                                              tokens_per_expert_list,
                                              max(tokens_per_expert_list),
                                              bias=w2_bias,
                                              act_mode="act_default")
        output = torch_br.supa_unpermutation_infer(
            input_list=down_outputs,
            indices=indices_trans_supa,
            probs=prob_per_expert,
            tokens_per_expert=tokens_per_expert_list)
    else:
        # No token routed to any local expert on this rank.
        output = torch.zeros_like(hidden_states)
    return output.unsqueeze(0)
def fused_moe_quant_dyn(
    hidden_states: torch.Tensor,
    shared_gate_up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    tp_rank: Optional[int] = None,
    global_rank: Optional[int] = None,
    tp_size: Optional[int] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """Dynamic-shape (prefill) quantized MoE forward on Biren SUPA.

    Two routing paths:
    - ``use_grouped_topk``: fused shared-expert + grouped-topk router
      (``supa_fused_shared_router_prefill_v2_infer``), which also produces
      the shared-expert output; ``shared_gate_up_weight``/``down_weight``
      are the shared expert's weights.
    - otherwise: plain ``supa_moe_router_v2_infer`` (shared weights must
      be ``None``).

    Tokens are then permuted per local expert, pushed through the
    (optionally scaled) w13/w2 expert FFNs, un-permuted, weighted by the
    routing probabilities, and added to the shared-expert output if any.

    ``renormalize``, ``inplace``, ``custom_routing_function``, ``tp_rank``,
    ``global_rank`` and ``tp_size`` are accepted for interface
    compatibility but not used in this body.

    Returns:
    - torch.Tensor: the MoE output with an extra leading dim
      (``output.unsqueeze(0)``).
    """
    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    total_expert_num = gating_weight.shape[-1]
    cur_device = hidden_states.device
    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        shared_output, _, indices_supa, indice_per_expert, prob_per_expert = torch_br.supa_fused_shared_router_prefill_v2_infer(
            hidden_states,
            shared_gate_up_weight,
            down_weight,
            gating_weight,
            intermediate_size,
            topk,
            num_expert_group,
            topk_group,
            ep_size,
            ep_rank,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)
        if is_dual_die:
            # Re-materialize the shared output into a col-major UT buffer.
            shared_tmp = torch_br._empty_ut_only(size=shared_output.shape,
                                                 dtype=shared_output.dtype,
                                                 is_numa=False,
                                                 device=shared_output.device,
                                                 tensor_type="colmajor")
            shared_tmp.copy_(shared_output)
            shared_output = shared_tmp
    else:
        assert topk_group is None, "Only support non group topk router"
        assert shared_gate_up_weight is None and down_weight is None
        _, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
            hidden_states,
            gating_weight.permute(1, 0).contiguous(),
            topk,
            ep_size,
            ep_rank,
            gating_bias=e_score_correction_bias)
        shared_output = None
    # Router outputs come back token-major; transpose to expert-major.
    indices_supa = indices_supa.permute(1, 0).contiguous()
    indice_per_expert = indice_per_expert.permute(1, 0).contiguous()
    prob_per_expert = prob_per_expert.permute(1, 0).contiguous()
    local_expert_num = total_expert_num // ep_size # type: ignore
    b_seq = hidden_states.shape[0]
    # Scratch index buffer, sequence length rounded up to a multiple of 64.
    indices_trans_supa = torch_br._empty_ut_only(
        size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
        dtype=torch.int32,
        is_numa=False,
        device=hidden_states.device,
        tensor_type="colmajor",
        sbp="BB" if is_dual_die else None)
    # Token count per global expert; slice out this EP rank's local experts.
    tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
                                            minlength=total_expert_num)
    tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
    )[ep_rank * local_expert_num:(ep_rank + 1) * # type: ignore
      local_expert_num]
    # Number of local experts that received at least one token.
    topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)
    if topk_per_expert > 0:
        expert_tokens = torch_br.supa_permutation_infer(
            global_hidden_states=hidden_states,
            indices=indice_per_expert,
            tokens_per_expert=tokens_per_expert_list,
            indices_trans=indices_trans_supa)
        assert len(
            expert_tokens) == local_expert_num, "Number of experts mismatch"
        supa_device = torch.supa.current_device()
        spc_num = torch_br.supa.get_device_properties(
            supa_device).max_compute_units
        out_expert_tokens = []
        use_moe_fused_ffn_dyn = True
        # With use_moe_fused_ffn_dyn hard-wired True, the per-expert linear
        # path below is only taken for 128-expert models.
        if not use_moe_fused_ffn_dyn or total_expert_num == 128:
            w13_hw = w13.shape[-2] * w13.shape[-1]
            w2_hw = w2.shape[-2] * w2.shape[-1]
            for i in range(local_expert_num):
                expert_token = expert_tokens[i]
                if tokens_per_expert_list[i] == 0:
                    out_expert_tokens.append(expert_token)
                    continue
                expert_gate_up_weight = w13.view_as_usharp(
                    "COLMAJOR", (spc_num, w13.shape[-2], w13.shape[-1]),
                    Sbp.ss(0), i * w13_hw)
                # NOTE(review): this rebinding shadows the `down_weight`
                # parameter (already consumed by the router above) — looks
                # intentional but worth confirming.
                down_weight = w2.view_as_usharp(
                    "COLMAJOR", (spc_num, w2.shape[-2], w2.shape[-1]),
                    Sbp.ss(0), i * w2_hw)
                expert_gate_up_scale = w13_scale[
                    i] if w13_scale is not None else None
                down_scale = w2_scale[i] if w2_scale is not None else None
                gate_up_output = torch_br._empty_ut_only(
                    size=(expert_token.shape[0], intermediate_size),
                    dtype=torch.bfloat16,
                    is_numa=False,
                    device=expert_token.device,
                    tensor_type="colmajor",
                    sbp="SB" if is_dual_die else None,
                    axis=1)
                torch_br.supa_fused_linear_infer(gate_up_output,
                                                 expert_token,
                                                 expert_gate_up_weight,
                                                 expert_gate_up_scale,
                                                 act_mode="act_swiglu")
                down_output = torch_br._empty_ut_only(
                    size=expert_token.shape,
                    dtype=torch.float32,
                    is_numa=False,
                    device=gate_up_output.device,
                    tensor_type="colmajor",
                    sbp="BB" if is_dual_die else None)
                torch_br.supa_fused_linear_infer(down_output, gate_up_output,
                                                 down_weight, down_scale)
                out_expert_tokens.append(down_output)
        else:
            gate_up_outputs = []
            cur_device = expert_tokens[0].device
            hidden_size = expert_tokens[0].shape[-1]
            # Pre-allocate per-expert buffers; idle experts get empty
            # placeholders so list positions line up with expert ids.
            for i in range(local_expert_num):
                if tokens_per_expert_list[i] == 0:
                    gate_up_outputs.append(
                        torch.empty(size=(0, intermediate_size),
                                    dtype=torch.bfloat16,
                                    device=cur_device))
                    out_expert_tokens.append(
                        torch.empty(size=(0, hidden_size),
                                    dtype=torch.float32,
                                    device=cur_device))
                    continue
                gate_up_output = torch_br._empty_ut_only(
                    size=(tokens_per_expert_list[i], intermediate_size),
                    dtype=torch.bfloat16,
                    is_numa=False,
                    device=cur_device,
                    tensor_type="colmajor",
                    sbp="SB" if is_dual_die else None,
                    axis=1)
                gate_up_outputs.append(gate_up_output)
                down_output = torch_br._empty_ut_only(
                    size=(tokens_per_expert_list[i], hidden_size),
                    dtype=torch.float32,
                    is_numa=False,
                    device=cur_device,
                    tensor_type="colmajor",
                    sbp="BB" if is_dual_die else None)
                out_expert_tokens.append(down_output)
            # Batched w13 (swiglu) then w2 passes over all local experts.
            torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
                                                  expert_tokens,
                                                  w13,
                                                  tokens_per_expert_list,
                                                  max(tokens_per_expert_list),
                                                  scales=w13_scale,
                                                  act_mode="act_swiglu")
            torch_br.supa_moe_fused_ffn_dyn_infer(out_expert_tokens,
                                                  gate_up_outputs,
                                                  w2,
                                                  tokens_per_expert_list,
                                                  max(tokens_per_expert_list),
                                                  scales=w2_scale,
                                                  act_mode="act_default")
        out_states = torch_br.supa_unpermutation_infer(
            input_list=out_expert_tokens,
            indices=indices_trans_supa,
            probs=prob_per_expert,
            tokens_per_expert=tokens_per_expert_list)
        output = out_states if shared_output is None else out_states + shared_output
    else:
        # No token routed to any local expert: only the shared expert (if any)
        # contributes.
        output = torch.zeros_like(
            hidden_states) if shared_output is None else shared_output
    return output.unsqueeze(0)
def fused_moe_quant_device(
    hidden_states: torch.Tensor,
    shared_gate_up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    tp_rank: Optional[int] = None,
    global_rank: Optional[int] = None,
    tp_size: Optional[int] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """Static-shape (decode) quantized MoE forward on Biren SUPA.

    Routing: plain decoder router when ``topk_group is None`` (no shared
    expert allowed), otherwise a fused shared-expert + grouped-topk router
    (the v2 variant when ``ep_size > 1``). The expert FFN then accumulates
    in-place into ``shared_output`` via one of three fused kernels:
    int4 (``w13.dtype == torch.int32``), ffn+allreduce (when env flags and
    the (spc_num, tp_size) pair allow it), or the plain fused ffn.

    ``renormalize``, ``inplace``, ``use_grouped_topk`` (beyond the assert),
    ``custom_routing_function`` and ``topk`` routing extras not listed
    above are accepted for interface compatibility.

    Returns:
    - torch.Tensor: the MoE output with an extra leading dim
      (``shared_output.unsqueeze(0)``).
    """
    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    expert_num = gating_weight.shape[-1]
    b_seq = hidden_states.shape[-2]
    if topk_group is None:
        assert shared_gate_up_weight is None and down_weight is None
        shared_output, masked_probs, hitted_experts = torch_br.supa_moe_router_decoder_infer(
            hidden_states, gating_weight, topk, ep_size, ep_rank)
    else:
        assert use_grouped_topk is True, "Only support group topk router"
        assert num_expert_group is not None and topk_group is not None
        if ep_size > 1: # type: ignore
            # Kernel requires a bias tensor; substitute an empty fp32 one
            # when the model has no correction bias.
            shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_v2_infer(
                hidden_states,
                shared_gate_up_weight,
                down_weight,
                gating_weight,
                intermediate_size,
                topk,
                num_expert_group,
                topk_group,
                ep_size,
                ep_rank,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias
                if e_score_correction_bias is not None else torch.empty(
                    (expert_num),
                    dtype=torch.float32,
                    device=hidden_states.device))
        else:
            shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_infer(
                hidden_states,
                shared_gate_up_weight,
                down_weight,
                gating_weight,
                intermediate_size,
                topk,
                num_expert_group,
                topk_group,
                scoring_func,
                e_score_correction_bias=e_score_correction_bias
                if e_score_correction_bias is not None else torch.empty(
                    (expert_num),
                    dtype=torch.float32,
                    device=hidden_states.device))
    if is_dual_die:
        shared_output = shared_output.view_as_usharp(
            "COLMAJOR", shared_output.shape, Sbp.bb())
    if w13.dtype == torch.int32:
        # INT4-packed expert weights take the dedicated s4 kernel.
        torch_br.supa_moe_fused_ffn_s4_infer(shared_output, hidden_states, w13,
                                             w2, hitted_experts, masked_probs,
                                             w13_scale, w2_scale)
    else:
        support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
        if envs.VLLM_BR_USE_FUSED_ALLREDUCE and b_seq <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and (
                envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types:
            # ffn+allreduce only support tp 4|8 and 16spc
            torch_br.supa_moe_fused_ffn_allreduce(shared_output, hidden_states,
                                                  w13, w2, hitted_experts,
                                                  masked_probs, tp_rank,
                                                  tp_size, global_rank, 0,
                                                  w13_scale, w2_scale)
        else:
            torch_br.supa_moe_fused_ffn_infer(shared_output, hidden_states,
                                              w13, w2, hitted_experts,
                                              masked_probs, w13_scale,
                                              w2_scale)
    return shared_output.unsqueeze(0)

View File

@@ -0,0 +1,67 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
from typing import Optional, Tuple, Union
import torch
import torch_br
from fastcore.basics import patch_to
from torch import Tensor, nn
from vllm.model_executor.layers.layernorm import RMSNorm
@patch_to(RMSNorm)
def forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """RMSNorm forward using the Biren SUPA fused kernels.

    With a residual, returns (normed, fused_add) from the fused
    add+rmsnorm kernel; otherwise returns just the normed tensor.
    """
    # The kernels take an fp32 gamma; promote a bf16 weight once, in place.
    if self.weight.data.dtype == torch.bfloat16:
        self.weight.data = self.weight.data.to(torch.float32)
    if residual is None:
        # Normalize the rank for the kernel: lift 2-D inputs, drop a
        # leading singleton on 4-D inputs.
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
        elif len(x.shape) == 4:
            x = x.squeeze(0)
        return torch_br.supa_rmsnorm_infer(
            x,
            self.weight.data,
            self.variance_epsilon  # type: ignore
        )
    normed, fused_add = torch_br.supa_add_rmsnorm_infer(  # type: ignore
        x, residual, self.weight.data, self.variance_epsilon)
    return normed, fused_add
@patch_to(RMSNorm)
def enabled(cls) -> bool:
    """Report that this out-of-tree RMSNorm override is always active."""
    # The SUPA kernel path is unconditionally available on this platform.
    return True
@patch_to(nn.LayerNorm)
def forward(self, input: Tensor) -> Tensor:
    """LayerNorm forward with an optional Biren fused kernel.

    The fused path is opted into via the USE_BR_FUSED_LAYERNORM env var
    (anything other than unset/'false'/'0'/'' enables it).
    """
    flag = os.environ.get("USE_BR_FUSED_LAYERNORM", 'False').lower()
    if flag in {'false', '0', ''}:
        # Default: stock PyTorch layer norm.
        return nn.functional.layer_norm(input, self.normalized_shape,
                                        self.weight, self.bias, self.eps)
    return torch_br.fused_layernorm(input, self.weight, self.bias, self.eps)

View File

@@ -0,0 +1,767 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Literal, Optional, Union
import torch
import torch.nn.functional as F
import torch_br
import torch_br.supa._debug as supa_debug
from fastcore.basics import patch_to
from torch.nn.parameter import Parameter
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import logger
from vllm.model_executor.layers.linear import (adjust_bitsandbytes_4bit_shard,
adjust_marlin_shard,
adjust_scalar_to_fused_array)
from vllm_br import envs
from vllm_br.utils import get_grandparent_pid
from .br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, _convert_to_numa_tensor_vit,
is_br166_device)
from vllm.model_executor.layers.linear import ( # isort:skip
LinearBase, MergedColumnParallelLinear, QuantizationConfig,
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod,
QKVParallelLinear)
def _should_skip_linear_post_process(layer, use_ds_mla, use_ds_mla_sparse):
"""NOTE: SUPA: for MLA linears, we do process in MLA.process_weights_after_loading """
# TODO: Hard code for native dsa op
if use_ds_mla_sparse:
MLA_LINEAR_NAMES = [
"kv_b_proj",
]
else:
MLA_LINEAR_NAMES = [
"q_a_proj",
"q_b_proj",
# "q_proj",
"kv_a_proj_with_mqa",
"kv_b_proj",
# "o_proj",
]
if use_ds_mla and not use_ds_mla_sparse:
MLA_LINEAR_NAMES.append("o_proj")
skip = any(k in layer.prefix for k in MLA_LINEAR_NAMES)
if skip:
logger.debug(
f'[SUPA] skip {layer.prefix} UnquantizedLinearMethod.process_weights_after_loading' # noqa: G004
)
return skip
# NOTE: ReplicatedLinear, usually used in MoE as a gate module.
# In DeepseekV3, it needs to be transposed.
def process_weights_ReplicatedLinear(
        layer: ReplicatedLinear) -> Literal[True, False]:
    """Transpose the replicated (gate) weight in place; always returns True."""
    transposed_weight = layer.weight.data.transpose(1, 0)
    layer.weight.data = transposed_weight.contiguous()
    return True
def process_share_expert_weight(layer: MergedColumnParallelLinear):
    """Repack a shared-expert gate_up weight into the SUPA SPC-partitioned
    layout expected by the fused shared-expert kernels.

    The gate and up halves are interleaved in align_size-sized blocks,
    padded to an SPC-aligned width, laid out per SPC region for the SPCs
    reserved for the shared expert (zero "invalid" regions fill the SPCs
    reserved for routed experts), and copied into a NUMA UT tensor that
    replaces ``layer.weight.data``. Dual-die (>16 SPC) devices repack each
    die half (d0/d1) separately and concatenate them.
    """
    gate_up_weight = layer.weight.transpose(1, 0).contiguous()
    gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1)
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    is_br166 = die_spc_num > 16
    # On dual-die parts the per-die SPC count is half the reported total.
    spc_num = die_spc_num // 2 if is_br166 else die_spc_num
    if is_br166:
        # 2&2 for 4spc, 4&8 for 12spc, 8&8 for 16spc
        spc_for_shared = 2 if spc_num == 4 else 8
        spc_for_router = spc_num - spc_for_shared
        align_size = 32
        weight_dtype = gate_weight.dtype
        hidden_size = gate_weight.shape[0]
        # Split gate/up again into per-die halves (d0/d1).
        gate_d0, gate_d1 = torch.chunk(gate_weight, 2, dim=-1)
        up_d0, up_d1 = torch.chunk(up_weight, 2, dim=-1)
        im_size = gate_d0.shape[-1]
        n_align_size = (align_size * 2) * spc_for_shared
        # Pad the combined gate+up width up to a multiple of n_align_size.
        swiglu_w_aligned = ((
            (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size
        region_size = swiglu_w_aligned // spc_for_shared
        block_nums = (region_size // (align_size * 2)) * spc_for_shared
        gate_d0_align = torch.nn.functional.pad(
            gate_d0, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        gate_d1_align = torch.nn.functional.pad(
            gate_d1, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        up_d0_align = torch.nn.functional.pad(
            up_d0, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        up_d1_align = torch.nn.functional.pad(
            up_d1, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        gate_weight_d0_reshape = gate_d0_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        gate_weight_d1_reshape = gate_d1_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        up_weight_d0_reshape = up_d0_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        up_weight_d1_reshape = up_d1_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        # Interleave gate/up per align_size block: [gate | up] in each
        # 2*align_size slot.
        gate_up_weight_d0 = torch.zeros(
            [hidden_size, block_nums, align_size * 2],
            dtype=weight_dtype,
            device='supa')
        gate_up_weight_d0[:, :, 0:0 +
                          align_size] = gate_weight_d0_reshape[:, :,
                                                               0:align_size]
        gate_up_weight_d0[:, :, align_size:align_size *
                          2] = up_weight_d0_reshape[:, :, 0:align_size]
        gate_up_weight_d0 = gate_up_weight_d0.reshape(
            hidden_size, spc_for_shared,
            region_size).permute(1, 0, 2).contiguous().to(weight_dtype)
        gate_up_d0_invalid = torch.zeros(
            [spc_for_router, hidden_size, region_size],
            dtype=weight_dtype,
            device='supa') # invalid regions
        gate_up_weight_d0_whole = torch.cat(
            [gate_up_weight_d0, gate_up_d0_invalid], dim=0)
        gate_up_weight_d1 = torch.zeros(
            [hidden_size, block_nums, align_size * 2],
            dtype=weight_dtype,
            device='supa')
        gate_up_weight_d1[:, :, 0:0 +
                          align_size] = gate_weight_d1_reshape[:, :,
                                                               0:align_size]
        gate_up_weight_d1[:, :, align_size:align_size *
                          2] = up_weight_d1_reshape[:, :, 0:align_size]
        gate_up_weight_d1 = gate_up_weight_d1.reshape(
            hidden_size, spc_for_shared,
            region_size).permute(1, 0, 2).contiguous().to(weight_dtype)
        gate_up_d1_invalid = torch.zeros(
            [spc_for_router, hidden_size, region_size],
            dtype=weight_dtype,
            device='supa') # invalid regions
        gate_up_weight_d1_whole = torch.cat(
            [gate_up_weight_d1, gate_up_d1_invalid], dim=0)
        # Stack the two die halves along the SPC axis.
        gate_up_weight_whole = torch.cat(
            [gate_up_weight_d0_whole, gate_up_weight_d1_whole], dim=0)
        gate_up_weight_supa = torch_br._empty_ut_only(
            size=gate_up_weight_whole.shape,
            dtype=gate_weight.dtype,
            is_numa=True,
            device="supa",
            tensor_type="colmajor")
        gate_up_weight_supa.copy_(gate_up_weight_whole)
        layer.weight.data = gate_up_weight_supa
    else:
        # Single-die layout: same interleave+pad scheme, no d0/d1 split.
        # 2&2 for 4spc, 4&8 for 12spc, 8&8 for 16spc
        spc_for_shared = 2 if spc_num == 4 else 8
        spc_for_router = spc_num - spc_for_shared
        align_size = 32
        weight_dtype = gate_weight.dtype
        hidden_size = gate_weight.shape[0]
        im_size = gate_weight.shape[-1]
        n_align_size = (align_size * 2) * spc_for_shared
        swiglu_w_aligned = ((
            (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size
        region_size = swiglu_w_aligned // spc_for_shared
        block_nums = (region_size // (align_size * 2)) * spc_for_shared
        gate_golden_align = torch.nn.functional.pad(
            gate_weight, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        up_golden_align = torch.nn.functional.pad(
            up_weight, (0, (swiglu_w_aligned // 2) - im_size, 0, 0),
            mode='constant',
            value=0)
        gate_weight_golden_reshape = gate_golden_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        up_weight_golden_reshape = up_golden_align.reshape(
            hidden_size, block_nums, align_size).contiguous().to(weight_dtype)
        gate_up_weight_golden = torch.zeros(
            [hidden_size, block_nums, align_size * 2],
            dtype=weight_dtype,
            device='supa')
        gate_up_weight_golden[:, :, 0:0 +
                              align_size] = gate_weight_golden_reshape[:, :, 0:
                                                                       align_size]
        gate_up_weight_golden[:, :, align_size:align_size *
                              2] = up_weight_golden_reshape[:, :, 0:align_size]
        gate_up_weight_golden = gate_up_weight_golden.reshape(
            hidden_size, spc_for_shared,
            region_size).permute(1, 0, 2).contiguous().to(weight_dtype)
        gate_up_invalid = torch.zeros(
            [spc_for_router, hidden_size, region_size],
            dtype=weight_dtype,
            device='supa') # invalid regions
        gate_up_weight_whole = torch.cat(
            [gate_up_weight_golden, gate_up_invalid], dim=0)
        gate_up_weight_supa = torch_br._empty_ut_only(
            size=gate_up_weight_whole.shape,
            dtype=gate_weight.dtype,
            is_numa=True,
            device="supa",
            tensor_type="colmajor")
        gate_up_weight_supa.copy_(gate_up_weight_whole)
        layer.weight.data = gate_up_weight_supa
# NOTE: MergedColumnParallelLinear, usually used in MergedGateUpMLPSiluL2
def process_weights_QuantMergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Convert a quantized merged gate_up linear to SUPA NUMA layout.

    Quantized layers (those with ``qweight``) get their gate/up halves
    cross-interleaved; unquantized fallbacks are transposed and converted
    directly. Scales and bias, when present, are interleaved the same way
    with the "linear_bias" layout. Shared-expert layers instead go through
    process_share_expert_weight.
    """
    if 'shared_experts' not in layer.prefix:
        #NOTE: normal MLP gate_up, after load weight, convert to supa numa tensor
        if hasattr(layer, "qweight"):
            gate_weight, up_weight = torch.chunk(layer.qweight, 2, dim=-1)
            gate_up_weight_numa = _convert_to_crossed_numa_tensor(
                gate_weight,
                up_weight,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=True,
                do_transpose=False)
            layer.qweight.data = gate_up_weight_numa
        else:
            # No qweight: treat as a plain bf16/fp weight.
            gate_up_weight = layer.weight.permute(1, 0).contiguous()
            gate_up_weight_numa = _convert_to_numa_tensor(
                gate_up_weight,
                32,
                "colmajor",
                gate_up_weight.dtype,
                False,
                parallel_type="col_parallel")
            layer.weight.data = gate_up_weight_numa
        if hasattr(layer, "scales") and layer.scales is not None:
            gate_scales, up_scales = torch.chunk(layer.scales, 2, dim=-1)
            gate_up_scales_internleaved_numa = _convert_to_crossed_numa_tensor(
                gate_scales,
                up_scales,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=False,
                layout="linear_bias",
                do_transpose=False)
            layer.scales.data = gate_up_scales_internleaved_numa
        if hasattr(layer, "bias") and layer.bias is not None:
            gate_bias, up_bias = torch.chunk(layer.bias, 2, dim=-1)
            gate_up_bias_internleaved_numa = _convert_to_crossed_numa_tensor(
                gate_bias,
                up_bias,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=False,
                layout="linear_bias",
                do_transpose=False)
            layer.bias.data = gate_up_bias_internleaved_numa
    else:
        process_share_expert_weight(layer)
def process_weights_MergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Convert an unquantized merged gate_up linear to SUPA NUMA layout.

    Default path cross-interleaves the gate/up halves for the fused swiglu
    kernel; layers flagged with ``no_need_cross`` are converted as a single
    contiguous weight instead. Bias, when present, is always interleaved
    with the "linear_bias" layout. Shared-expert layers go through
    process_share_expert_weight.
    """
    if 'shared_experts' not in layer.prefix:
        gate_up_weight = layer.weight.permute(1, 0).contiguous()
        if not (hasattr(layer, "no_need_cross") and layer.no_need_cross):
            gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1)
            gate_up_weight_internleaved_numa = _convert_to_crossed_numa_tensor(
                gate_weight,
                up_weight,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=True,
                do_transpose=False)
            layer.weight.data = gate_up_weight_internleaved_numa
        else:
            gate_up_weight_numa = _convert_to_numa_tensor(
                gate_up_weight,
                align_size=32,
                layout="colmajor",
                dtype=gate_up_weight.dtype,
                do_transpose=False)
            layer.weight.data = gate_up_weight_numa
        if hasattr(layer, "bias") and layer.bias is not None:
            gate_bias, up_bias = torch.chunk(layer.bias, 2, dim=-1)
            gate_up_bias_internleaved_numa = _convert_to_crossed_numa_tensor(
                gate_bias,
                up_bias,
                envs.VLLM_BR_DEVICE_SPC_NUM,
                dim=-1,
                need_pad=False,
                layout="linear_bias",
                do_transpose=False)
            layer.bias.data = gate_up_bias_internleaved_numa
    else:
        #NOTE: by default, gate module and shared_expert(1) module will be involved into calculation in 1 kernel
        process_share_expert_weight(layer)
@patch_to(UnquantizedLinearMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-load hook: convert a linear layer's weight/bias to the SUPA
    NUMA tensor layout.

    MLA-owned linears and non-NUMA weight types are skipped entirely.
    Replicated and merged-column layers are dispatched to their dedicated
    converters first; layers those converters fully handled (and small
    gate-sized replicated weights) return early. The remainder get a
    generic NUMA conversion, with a ViT-specific variant for vision layers
    on BR166 devices.
    """
    if _should_skip_linear_post_process(
            layer, self.use_ds_mla,
            self.use_ds_mla_sparse) or self.weight_type != "NUMA":
        return
    still_need_process = True
    do_transpose = True
    parallel_type = "col_parallel"
    # NOTE: all process_weights func should done before process_weights_after_loading
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # Small non-indexer gate weights (64/128/160/256 outputs) are
            # already in their final form after the transpose above.
            still_need_process = not ("indexer" not in layer.prefix and (
                layer.output_size == 64 or layer.output_size == 160 # Glm4-Moe
                or layer.output_size == 128 or layer.output_size == 256))
            do_transpose = False
        case MergedColumnParallelLinear():
            process_weights_MergedColumnParallelLinear(layer)
            still_need_process = False
            do_transpose = False
        case RowParallelLinear():
            parallel_type = "row_parallel"
        case _:
            pass
    if not still_need_process or self.weight_type != "NUMA":
        return
    # process numa weight and bias
    if hasattr(layer, "weight") and len(layer.weight.shape) == 2:
        if 'vision' in layer.prefix and is_br166_device():
            layer.weight.data = _convert_to_numa_tensor_vit(
                layer.weight,
                envs.VLLM_BR_DEVICE_WARP_SIZE,
                "colmajor",
                torch.bfloat16,
                do_transpose=do_transpose,
                wk=(layer.weight.data.shape[1]
                    if do_transpose else layer.weight.data.shape[0]),
                wn=(layer.weight.data.shape[0]
                    if do_transpose else layer.weight.data.shape[1]),
                parallel_type=parallel_type) # noqa: SIM210
        else:
            layer.weight.data = _convert_to_numa_tensor(
                layer.weight,
                envs.VLLM_BR_DEVICE_WARP_SIZE,
                "colmajor",
                torch.bfloat16,
                do_transpose=do_transpose,
                wk=(layer.weight.data.shape[1]
                    if do_transpose else layer.weight.data.shape[0]),
                wn=(layer.weight.data.shape[0]
                    if do_transpose else layer.weight.data.shape[1]),
                parallel_type=parallel_type) # noqa: SIM210
    if hasattr(layer, "bias") and layer.bias is not None:
        # Row-parallel biases are zero-padded (only one rank adds the bias).
        pad_zeros = (parallel_type == "row_parallel")
        if 'vision' in layer.prefix and is_br166_device():
            if (pad_zeros and layer.reduce_results):
                return
            layer.bias.data = _convert_to_numa_tensor_vit(
                layer.bias,
                envs.VLLM_BR_DEVICE_WARP_SIZE,
                "linear_bias",
                torch.float32,
                parallel_type=parallel_type,
                pad_zeros=pad_zeros)
        else:
            layer.bias.data = _convert_to_numa_tensor(
                layer.bias,
                envs.VLLM_BR_DEVICE_WARP_SIZE,
                "linear_bias",
                torch.float32,
                parallel_type=parallel_type,
                pad_zeros=pad_zeros)
@patch_to(UnquantizedLinearMethod)
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute the (unquantized) linear projection for ``layer``.

    Dispatch rule: a 3-D ``layer.weight`` indicates the weight was already
    rewritten into the device NUMA layout, so the ``torch_br`` fused kernels
    are used; a 2-D weight falls back to ``F.linear`` bracketed by the
    sublas-API toggle. Vision layers on BR166 devices take a separate,
    simpler fused path that never rewrites ``layer.reduce_results``.
    """
    # numa weight is 3-dims
    if 'vision' in layer.prefix and is_br166_device():
        if len(layer.weight.shape) == 3:
            is_row = isinstance(layer, RowParallelLinear)
            # Prefer the per-TP-partition width when the layer is sharded.
            output_size = (layer.output_size_per_partition if hasattr(
                layer, "output_size_per_partition") else layer.output_size)
            act_mode = "act_default"
            # Merged gate/up projections fuse the SwiGLU activation; the
            # fused kernel halves the effective output width accordingly.
            if isinstance(layer, MergedColumnParallelLinear) and not (hasattr(
                    layer, "no_need_cross") and layer.no_need_cross):
                act_mode = "act_swiglu"
                output_size //= 2
            if bias is None or (is_row and layer.reduce_results):
                # NOTE(review): when bias is not None on this branch it is
                # not passed to the kernel — presumably added after the
                # all-reduce elsewhere; confirm against callers.
                # return torch_br.br_matmul_infer(
                #     x,
                #     layer.weight,
                #     bias=None,
                #     output_w=output_size,
                # )
                return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                                   output_w=output_size,
                                                   activation_mode=act_mode)
            else:
                return torch_br.br_matmul_infer(x, layer.weight, bias,
                                                output_size)
        # 2-D weight: plain linear, with the sublas API enabled only for
        # the duration of this call.
        supa_debug.set_enable_sublas_api(True)
        output = F.linear(x, layer.weight, bias)
        supa_debug.set_enable_sublas_api(False)
        return output
    if len(layer.weight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear) and not (hasattr(
                layer, "no_need_cross") and layer.no_need_cross):
            act_mode = "act_swiglu"
            output_size //= 2
        # The fused kernels take the bias as a list (or None).
        bias = [bias] if bias is not None else None
        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # TODO(CaoJun): This is WA, delete (16, 8) so that the test_vllm_model_accu_qwen25_72b_instruct can run through
            support_types = ((16, 4), (32, 2), (32, 4))
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pp_group().world_size
            all_rank = tp_size * pp_size
            # Use the fused linear+allreduce kernel only for short decode
            # sequences on supported (spc_num, tp_size) combinations;
            # otherwise keep the regular reduce path.
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if layer.reduce_results:
                return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                                   output_w=output_size,
                                                   bias=bias,
                                                   activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                # Sanity check: TP rank must match position within the node.
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x, layer.weight, output_size, tp_rank, tp_size,
                    global_rank, 0)
        else:
            return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                               output_w=output_size,
                                               bias=bias,
                                               activation_mode=act_mode)
    supa_debug.set_enable_sublas_api(True)
    output = F.linear(x, layer.weight, bias)
    supa_debug.set_enable_sublas_api(False)
    return output
@patch_to(LinearBase)
def __init__(
    self,
    input_size: int,
    output_size: int,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
):
    """Initialize the linear base layer (patched to record ``prefix``).

    Mirrors upstream vLLM, but additionally stores ``prefix`` on the
    instance so later weight processing can identify e.g. vision layers.
    """
    # Explicit two-arg super(): this function lives outside the class body
    # (fastcore patch_to), so zero-argument super() is not available.
    super(LinearBase, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.skip_bias_add = skip_bias_add
    self.params_dtype = (torch.get_default_dtype()
                         if params_dtype is None else params_dtype)
    # Without a quantization config, fall back to the plain linear method.
    if quant_config is not None:
        self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
    else:
        self.quant_method = UnquantizedLinearMethod()
    self.return_bias = return_bias
    self.prefix = prefix
    # disable_tp forces single-rank behaviour regardless of the TP group.
    if disable_tp:
        self.tp_rank = 0
        self.tp_size = 1
    else:
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
@patch_to(RowParallelLinear)
def forward(
    self, input_
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
    """Row-parallel linear forward with a Biren CPU/PCIe all-reduce path.

    Returns the output tensor alone when ``return_bias`` is False,
    otherwise ``(output, output_bias)`` — matching upstream vLLM.
    """
    # Cache the grandparent PID once; the PCIe all-reduce kernel below uses
    # it to identify the process group.
    if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
            self, "grandparent_pid"):
        self.grandparent_pid = get_grandparent_pid()
    if self.input_is_parallel:
        input_parallel = input_
    else:
        # Input is replicated: take this rank's slice of the last dim.
        tp_rank = get_tensor_model_parallel_rank()
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size)
        input_parallel = splitted_input[tp_rank].contiguous()
    # Matrix multiply.
    assert self.quant_method is not None
    # Only fuse bias add into GEMM for rank 0 (this ensures that
    # bias will not get added more than once in TP>1 case)
    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
    output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
    if self.reduce_results and self.tp_size > 1:
        # CPU all reduce will be applied: only for small outputs
        # (dim-1 <= 32) at tp_size >= 4, and only when enabled by env var.
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and self.tp_size >= 4 and output_parallel.shape[
                1] <= 32:
            tp_rank = get_tensor_model_parallel_rank()
            output = torch_br.supa_allreduce_pcie_infer(
                output_parallel, tp_rank, self.tp_size, self.grandparent_pid)
        else:
            output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel
    output_bias = self.bias if self.skip_bias_add else None
    if not self.return_bias:
        return output
    return output, output_bias
@patch_to(QKVParallelLinear)
def weight_loader(self,
                  param: Parameter,
                  loaded_weight: torch.Tensor,
                  loaded_shard_id: Optional[str] = None):
    """Load a checkpoint tensor into the fused QKV parameter.

    ``loaded_shard_id`` is ``"q"``/``"k"``/``"v"`` for per-projection
    checkpoints, or ``None`` when the checkpoint already stores the fused
    qkv tensor (in which case this method recurses once per shard).

    Patched vs. upstream: when ``envs.VLLM_BR_DEVICE_SPC_NUM > 16``
    (dual-die devices) the destination parameter is treated as two halves
    along ``output_dim`` and each shard is written half into each die's
    region.
    """
    # Special case for GGUF
    # initialize GGUF param after we know the quantize type
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        idx_map = {"q": 0, "k": 1, "v": 2}
        if loaded_shard_id is not None:
            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
        else:
            param.shard_weight_type = {
                k: loaded_weight.item()
                for k in idx_map
            }
        return

    if is_gguf_weight:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        output_dim = getattr(param, "output_dim", None)
        shard_size = loaded_weight.size(output_dim) // tp_size
        start_idx = tp_rank * shard_size

        if loaded_shard_id is not None:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # GGUF weights are stashed in a side container and materialized
        # later, once the quantization type is known.
        param.shard_id.append(loaded_shard_id)
        param.shard_id_map[loaded_shard_id] = len(param.data_container)
        param.data_container.append(loaded_weight)
        return

    param_data = param.data
    output_dim = getattr(param, "output_dim", None)
    # Special case for AQLM codebooks.
    is_metadata = getattr(param, "is_metadata", False)

    # Special case for per-tensor scales in fused case.
    needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv).
        # (e.g., Phi-3's qkv_proj).
        if output_dim is None:
            if needs_scalar_to_array:
                param_data, loaded_weight = adjust_scalar_to_fused_array(
                    param_data, loaded_weight, 0)

            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        packed_dim = getattr(param, "packed_dim", None)
        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.total_num_heads * self.head_size),
                    "k": (self.total_num_heads * self.head_size,
                          self.total_num_kv_heads * self.head_size),
                    "v":
                    ((self.total_num_heads + self.total_num_kv_heads) *
                     self.head_size, self.total_num_kv_heads * self.head_size),
                    "total":
                    ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                     self.head_size, 0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, shard_id)

            loaded_weight_shard = loaded_weight.narrow(output_dim,
                                                       shard_offset,
                                                       shard_size)
            # Recurse once per projection with an explicit shard id.
            self.weight_loader(param, loaded_weight_shard, shard_id)
        return

    tp_rank = get_tensor_model_parallel_rank()
    assert loaded_shard_id in ["q", "k", "v"]

    # If output dim is defined, use the default loading process.
    if output_dim is not None:
        if loaded_shard_id == "q":
            shard_offset = 0
            shard_size = self.num_heads * self.head_size
        elif loaded_shard_id == "k":
            shard_offset = self.num_heads * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        elif loaded_shard_id == "v":
            shard_offset = (self.num_heads +
                            self.num_kv_heads) * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        # Special case for Quantized Weights.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        packed_dim = getattr(param, "packed_dim", None)
        if packed_dim == output_dim:
            shard_size = shard_size // param.pack_factor
            shard_offset = shard_offset // param.pack_factor

            # Special case for Marlin.
            shard_size, shard_offset = adjust_marlin_shard(
                param, shard_size, shard_offset)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
        if use_bitsandbytes_4bit:
            orig_qkv_offsets = {
                "q": (0, self.num_heads * self.head_size),
                "k": (self.num_heads * self.head_size,
                      self.num_kv_heads * self.head_size),
                "v": ((self.num_heads + self.num_kv_heads) * self.head_size,
                      self.num_kv_heads * self.head_size),
                "total":
                ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 0)
            }
            shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                param, orig_qkv_offsets, loaded_shard_id)

        # [Patch] Dual-die devices: the parameter holds two die-local
        # halves along output_dim; write half of the shard into each.
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            half_w = param_data.shape[output_dim] // 2
            param_data = (param_data.narrow(output_dim, shard_offset // 2,
                                            shard_size // 2),
                          param_data.narrow(output_dim,
                                            shard_offset // 2 + half_w,
                                            shard_size // 2))
        else:
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
        # k/v heads may be replicated across TP ranks; map the TP rank to
        # the checkpoint shard index accordingly.
        if loaded_shard_id == "q":
            shard_id = tp_rank
        else:
            shard_id = tp_rank // self.num_kv_head_replicas
        start_idx = shard_id * shard_size

        if not is_sharded_weight:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
    # Special case for AQLM codebooks.
    elif is_metadata:
        # metadata indicates fixed size concatenated along dim 0
        shard_size = loaded_weight.shape[0]
        shard_index = ["q", "k", "v"].index(loaded_shard_id)
        param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
    # Special case for per-tensor scales in fused case.
    elif needs_scalar_to_array:
        param_data, loaded_weight = adjust_scalar_to_fused_array(
            param_data, loaded_weight, loaded_shard_id)
    else:
        ignore_warning = getattr(param, "ignore_warning", False)
        if not ignore_warning:
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions.")
    # Final copy: tuple form only occurs on the dual-die path above.
    if isinstance(param_data, tuple):
        half_w = loaded_weight.shape[output_dim] // 2
        param_data[0].copy_(loaded_weight.narrow(output_dim, 0, half_w))
        param_data[1].copy_(loaded_weight.narrow(output_dim, half_w, half_w))
    else:
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

View File

@@ -0,0 +1,72 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
import torch_br
from fastcore.basics import patch_to
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm_br import envs
# TODO(shouqing): need to open this patch when fix hang in mtp
@patch_to(LogitsProcessor)
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Compute full-vocabulary logits across tensor-parallel ranks.

    Instead of upstream's gather, each rank scatters its vocab shard into
    a zero-initialized full-width buffer and a sum all-reduce reconstructs
    the complete logits on every rank.
    """
    # Get the logits for the next tokens.
    logits = lm_head.quant_method.apply(lm_head,
                                        hidden_states,
                                        bias=embedding_bias)
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    if spc_num > 16:
        # Dual-die devices: stage the logits through a freshly allocated
        # colmajor buffer before the all-reduce.
        bb_input = torch_br._empty_ut_only(size=logits.shape,
                                           dtype=logits.dtype,
                                           is_numa=False,
                                           device=logits.device,
                                           tensor_type="colmajor")
        # work around the hang in s1b copy to bb
        bb_input.copy_(logits)
        logits = bb_input
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()
    # Place this rank's shard at its vocab offset; other ranks' slots stay
    # zero so the sum all-reduce yields the full row.
    logits_ = torch.zeros((logits.shape[0], logits.shape[-1] * tp_size),
                          dtype=logits.dtype,
                          device=logits.device)
    start = logits.shape[-1] * tp_rank
    end = start + logits.shape[-1]
    logits_[:, start:end].copy_(logits)
    logits = tensor_model_parallel_all_reduce(logits_)
    # Remove paddings in vocab (if any).
    if logits is not None:
        logits = logits[..., :self.org_vocab_size]
    return logits

View File

@@ -0,0 +1,19 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import compressed_tensors, gptq
__all__ = ["gptq", 'compressed_tensors']

View File

@@ -0,0 +1,18 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from .compressed_tensors import *
from .compressed_tensors_moe import *
from .compressed_tensors_wNa16 import *

View File

@@ -0,0 +1,64 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Any, cast
from fastcore.basics import patch_to
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig)
@patch_to(CompressedTensorsConfig, cls_method=True)
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
    """
    [PatchNote] add qkv_quantized param support

    Builds the config exactly as upstream does, but also forwards the
    optional ``qkv_quantized`` key (default True) to the wrapped __init__.
    """
    scheme_map = cls._quantization_scheme_map_from_config(config=config)
    sparsity_map, sparsity_ignore = cls._parse_sparsity_config(config=config)
    return cls(
        target_scheme_map=scheme_map,
        ignore=cast(list[str], config.get("ignore", [])),
        quant_format=cast(str, config.get("format")),
        sparsity_scheme_map=sparsity_map,
        sparsity_ignore_list=sparsity_ignore,
        config=config,
        transform_config=config.get("transform_config"),
        qkv_quantized=cls.get_from_keys_or(config, ["qkv_quantized"],
                                           default=True),
    )
def wrapper_CompressedTensorsConfig_init(fn):
    """Wrap ``CompressedTensorsConfig.__init__`` so it accepts and stores
    an extra ``qkv_quantized`` keyword (default: ``True``)."""

    @wraps(fn)
    def patched_init(self, *args, **kwargs):
        # Strip the extra kwarg before delegating — the original __init__
        # does not know about it — then record it on the instance.
        qkv_flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = qkv_flag

    return patched_init
CompressedTensorsConfig.__init__ = wrapper_CompressedTensorsConfig_init(
CompressedTensorsConfig.__init__) # noqa: E501

View File

@@ -0,0 +1,594 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from compressed_tensors.compressors.quantized_compressors import (
unpack_from_int32)
from fastcore.basics import patch_to
from torch_br.utils.tensor_methods import Sbp
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
WNA16_SUPPORTED_BITS, CompressedTensorsMoEMethod,
CompressedTensorsWNA16MoEMethod, CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs
from vllm_br import envs
from ...br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, align_n, cross_weight_32)
from ...fused_moe.supa_moe import fused_moe_quant_device, fused_moe_quant_dyn
@patch_to(CompressedTensorsMoEMethod)
def get_moe_method(
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    layer: torch.nn.Module,
) -> "CompressedTensorsMoEMethod":
    """NOTE:
    1. SUPA only supports CompressedTensorsWNA16MoEMethod without Marlin
    2. Only Linear targets are supported for MoE layers
    """
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    scheme_map = quant_config.target_scheme_map
    keys = list(scheme_map.keys())
    assert len(keys) > 0, ("No valid quant key!!!")
    # [Patch]: MoE layers only support the "Linear" target; for temporary
    # compatibility, remap the first configured target onto that key.
    scheme_map["Linear"] = scheme_map.pop(keys[0])
    linear_scheme = scheme_map["Linear"]
    weight_quant = linear_scheme.get("weights")
    input_quant = linear_scheme.get("input_activations")
    if not quant_config._is_wNa16_group_channel(weight_quant, input_quant):
        raise RuntimeError(
            f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
    logger.info_once("Using CompressedTensorsWNA16MoEMethod")
    return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
@patch_to(CompressedTensorsWNA16MoEMethod)
def __init__(
    self,
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    moe: FusedMoEConfig,
):
    """Configure the WNA16 MoE method from the parsed quantization config.

    Caches ``num_bits`` / packing factor / strategy / group size from the
    "Linear" weight scheme. Only symmetric, pack-quantized checkpoints
    with a supported bit width are accepted.
    """
    # Explicit two-arg super(): defined outside the class body (patch_to).
    super(CompressedTensorsWNA16MoEMethod, self).__init__(moe)
    self.quant_config = quant_config
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    weight_cfg = self.quant_config.target_scheme_map["Linear"].get("weights")
    self.num_bits = weight_cfg.num_bits
    self.packed_factor = 32 // weight_cfg.num_bits
    self.strategy = weight_cfg.strategy
    # [Patch]: SUPA uses this method for both "channel" and "group"
    # strategies, so the upstream strategy/actorder asserts are dropped.
    self.group_size = weight_cfg.group_size
    assert weight_cfg.symmetric, (
        "Only symmetric quantization is supported for MoE")
    wrong_format = (self.quant_config.quant_format
                    != CompressionFormat.pack_quantized.value)
    if wrong_format or self.num_bits not in WNA16_SUPPORTED_BITS:
        raise ValueError("For Fused MoE layers, only ",
                         f"{CompressionFormat.pack_quantized.value} ",
                         "is supported for the following bits: ",
                         f"{WNA16_SUPPORTED_BITS}")
@patch_to(CompressedTensorsWNA16MoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register CPU-resident packed weights, scales and index buffers for a
    WNA16 fused-MoE layer.

    All buffers are created on "cpu" and later rewritten into the device
    NUMA layout by ``process_weights_after_loading`` (which also releases
    the shape/g_idx buffers again).
    """
    # Will transpose the loaded weight along the
    # intermediate and hidden dim sizes. Will
    # shard for TP along the transposed dims
    extra_weight_attrs.update({
        "is_transposed": True,
        "quant_method": self.strategy
    })
    # Fused gate+up projection; packed_factor int values per int32 word.
    w13_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        hidden_size // self.packed_factor,
        2 * intermediate_size_per_partition,
        dtype=torch.int32,
        device="cpu",
    ),
                                    requires_grad=False)
    layer.register_parameter("w13_weight_packed", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    # Down projection.
    w2_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        intermediate_size_per_partition // self.packed_factor,
        hidden_size,
        dtype=torch.int32,
        device="cpu"),
                                   requires_grad=False)
    layer.register_parameter("w2_weight_packed", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    w2_scales_size = intermediate_size_per_partition

    # "channel" strategy keeps a single scale group; group_size=-1 marks
    # channelwise for the rest of this method/class.
    if self.strategy == "channel":
        num_groups_w2 = num_groups_w13 = 1
        self.group_size = -1
    else:
        num_groups_w2 = w2_scales_size // self.group_size
        num_groups_w13 = hidden_size // self.group_size

    w13_scale = torch.nn.Parameter(torch.ones(
        num_experts,
        num_groups_w13,
        2 * intermediate_size_per_partition,
        dtype=params_dtype,
        device="cpu",
    ),
                                   requires_grad=False)
    layer.register_parameter("w13_weight_scale", w13_scale)
    set_weight_attrs(w13_scale, extra_weight_attrs)

    w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                             num_groups_w2,
                                             hidden_size,
                                             dtype=params_dtype,
                                             device="cpu"),
                                  requires_grad=False)
    layer.register_parameter("w2_weight_scale", w2_scale)
    set_weight_attrs(w2_scale, extra_weight_attrs)
    set_weight_attrs(w2_scale, {"load_full_w2": False})

    # Per-expert (num_experts, 2) buffers — presumably record each
    # expert's original 2-D weight shape at load time; verify in loader.
    w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts,
                                                     2,
                                                     device="cpu"),
                                         requires_grad=False)
    layer.register_parameter("w2_weight_shape", w2_weight_shape)
    set_weight_attrs(w2_weight_shape, extra_weight_attrs)
    w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts,
                                                      2,
                                                      device="cpu"),
                                          requires_grad=False)
    layer.register_parameter("w13_weight_shape", w13_weight_shape)
    set_weight_attrs(w13_weight_shape, extra_weight_attrs)

    # g_idx / sort-index buffers (activation reordering); released after
    # loading by process_weights_after_loading.
    w13_g_idx = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight_g_idx", w13_g_idx)
    set_weight_attrs(w13_g_idx, extra_weight_attrs)

    w2_g_idx = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight_g_idx", w2_g_idx)
    set_weight_attrs(w2_g_idx, extra_weight_attrs)

    w13_g_idx_sort_indices = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
    set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)

    w2_g_idx_sort_indices = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
    set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

    # Weight-only quantization: no activation scales.
    layer.a13_scale = None
    layer.a2_scale = None
@patch_to(CompressedTensorsWNA16MoEMethod)
def process_weights_after_loading(self: CompressedTensorsWNA16MoEMethod,
                                  layer: FusedMoE) -> None:
    """Repack the loaded MoE weights and scales into the SUPA NUMA layout.

    Handles INT8 (``num_bits == 8``) and INT4 (``num_bits == 4``) packed
    checkpoints; anything else raises ``ValueError``. In both cases the
    per-expert w13 (gate+up) and w2 (down) weights are unpacked/crossed
    and copied expert-by-expert into freshly allocated device buffers,
    then the unused loader-time buffers are released.
    """
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # More than 16 SPCs means a dual-die part; weights are then sharded
    # ("SS") across the two dies.
    die_num = 1 if die_spc_num <= 16 else 2
    spc_num = die_spc_num // die_num
    cur_device = torch.supa.current_device()
    is_dual_die = (die_spc_num > 16)
    if self.num_bits == 8:
        # NOTE: w13_weight
        # after _load_w13, w13_weight is a colparallel weight, shape
        # [num_experts, hidden_size // 4, 2 * intermediate_size_per_partition] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT8
        wk = layer.hidden_size
        wn = layer.intermediate_size_per_partition * 2
        align_size = 64
        wn_block = align_n(wn // die_num,
                           align_size=align_size,
                           spc_num=spc_num)
        supa_w13_weight_packed = torch_br._empty_ut_only(
            size=(die_spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int8,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor",
            axis=0,
            sbp="SS" if is_dual_die else None)
        for expert_id in range(layer.local_num_experts):
            expert_w13 = layer.w13_weight_packed[
                expert_id]  # hidden_size // 4, 2 * intermediate_size_per_partition
            # Split the fused tensor into the gate (w1) and up (w3) halves.
            expert_1, expert_3 = expert_w13.chunk(2, dim=1)
            unpacked_expert_1 = unpack_from_int32(
                expert_1, self.num_bits,
                torch.Size(
                    [layer.hidden_size,
                     layer.intermediate_size_per_partition]), 0)
            unpacked_expert_3 = unpack_from_int32(
                expert_3, self.num_bits,
                torch.Size(
                    [layer.hidden_size,
                     layer.intermediate_size_per_partition]), 0)
            pad_expert_w13 = _convert_to_crossed_numa_tensor(
                unpacked_expert_1,
                unpacked_expert_3,
                die_spc_num,
                dim=1,
                need_pad=True,
                layout='COLMAJOR',
                do_transpose=False)
            # Copy this expert's crossed weight into its slot of the big
            # device buffer via a usharp view at the expert offset.
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)
        layer.w13_weight_packed.data = supa_w13_weight_packed
        # NOTE: w13_scale
        # after _load_w13, w13_weight is a colparallel weight, shape
        # S8: [num_experts, 1, 2 * intermediate_size_per_partition]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, wn]
        supa_w13_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="linear_bias",
            sbp="BB" if is_dual_die else None)
        for expert_id in range(layer.local_num_experts):
            expert_w13_scales = layer.w13_weight_scale[expert_id]
            # Gate (w1) and up (w3) scale halves, crossed like the weight.
            expert_1_scale, expert_3_scale = expert_w13_scales.chunk(2, dim=1)
            crossed_expert_w13_scales = cross_weight_32(
                expert_1_scale.squeeze(),
                expert_3_scale.squeeze(),
                die_spc_num,
                dim=0,
                need_pad=False,
            )
            narrow_data = supa_w13_scales[expert_id]
            narrow_data.copy_(crossed_expert_w13_scales)
        layer.w13_weight_scale.data = supa_w13_scales
        # NOTE: w2_weight
        # after _load_w2, w2_weight is a colparallel weight, shape
        # [num_experts, intermediate_size_per_partition // 4, hidden_size] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT8
        wk = layer.intermediate_size_per_partition
        wn = layer.hidden_size
        align_size = 32
        wn_block = align_n(wn, align_size=align_size, spc_num=spc_num)
        supa_w2_weight_packed = torch_br._empty_ut_only(
            size=(die_spc_num * layer.local_num_experts, wk // die_num,
                  wn_block),
            dtype=torch.int8,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor",
            axis=0,
            sbp="SS" if is_dual_die else None)
        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_packed[expert_id]
            unpacked_expert_2 = unpack_from_int32(
                expert_w2, self.num_bits,
                torch.Size(
                    [layer.intermediate_size_per_partition,
                     layer.hidden_size]), 0)
            pad_expert_w2 = _convert_to_numa_tensor(
                unpacked_expert_2,
                align_size,
                'COLMAJOR',
                expert_w2.dtype,
                do_transpose=False,
                parallel_type="row_parallel")
            pad_expert_w2_shape = pad_expert_w2.shape
            hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
            narrow_data = supa_w2_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w2)
        layer.w2_weight_packed.data = supa_w2_weight_packed
        # NOTE: w2_scale
        # after _load_w2, w2_weight is a colparallel weight, shape
        # S8: [num_experts, 1, hidden_size]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, wn]
        supa_w2_scales = torch_br._empty_ut_only(size=(layer.local_num_experts,
                                                       wn),
                                                 dtype=torch.float32,
                                                 is_numa=False,
                                                 device=cur_device,
                                                 tensor_type="linear_bias")
        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_scale[expert_id]
            # NOTE(review): strided slice selects a single (1, wn) row for
            # each expert_id here — equivalent to [expert_id:expert_id+1]
            # for a first dim of local_num_experts; verify intent.
            narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
            narrow_data.copy_(expert_w2)
        layer.w2_weight_scale.data = supa_w2_scales
    elif self.num_bits == 4:
        # NOTE: w13_weight
        # after _load_w13, w13_weight is a colparallel weight, shape
        # [num_experts, hidden_size // 8, 2 * intermediate_size_per_partition] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT32
        wk = layer.hidden_size // 8
        wn = layer.intermediate_size_per_partition * 2
        wn_block = align_n(wn, align_size=64, spc_num=spc_num)
        supa_w13_weight_packed = torch_br._empty_ut_only(
            size=(spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int32,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor")
        for expert_id in range(layer.local_num_experts):
            expert_w13 = layer.w13_weight_packed[
                expert_id]  # hidden_size // 4, 2 * intermediate_size_per_partition
            expert_1, expert_3 = expert_w13.chunk(
                2, dim=0)  # each is a packed int4 weight
            pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
                                                             expert_3,
                                                             spc_num,
                                                             dim=1,
                                                             need_pad=True,
                                                             layout='COLMAJOR',
                                                             do_transpose=True)
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)
        layer.w13_weight_packed.data = supa_w13_weight_packed
        # NOTE: w13_scale
        # after _load_w13, w13_weight is a colparallel weight, shape
        # S4: [num_experts, hidden_size // 128, 2 * intermediate_size_per_partition]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, group_nums, wn]
        supa_w13_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.hidden_size // self.group_size, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="colmajor")
        for expert_id in range(layer.local_num_experts):
            expert_w13_scales = layer.w13_weight_scale[expert_id]
            expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
                2, dim=0)  # gate / up scale halves
            crossed_expert_w13_scales = cross_weight_32(
                expert_1_scale,
                expert_3_scale,
                spc_num,
                dim=1,
                need_pad=False,
            )
            narrow_data = supa_w13_scales[expert_id]
            narrow_data.copy_(crossed_expert_w13_scales)
        layer.w13_weight_scale.data = supa_w13_scales
        # NOTE: w2_weight
        # after _load_w2, w2_weight is a colparallel weight, shape
        # [num_experts, intermediate_size_per_partition // 8, hidden_size] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT32
        wk = layer.intermediate_size_per_partition // 8
        wn = layer.hidden_size
        wn_block = align_n(wn, align_size=32, spc_num=spc_num)
        supa_w2_weight_packed = torch_br._empty_ut_only(
            size=(spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int32,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor")
        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_packed[expert_id]
            pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
                                                    spc_num,
                                                    'COLMAJOR',
                                                    expert_w2.dtype,
                                                    do_transpose=True)
            pad_expert_w2_shape = pad_expert_w2.shape
            hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
            narrow_data = supa_w2_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w2)
        layer.w2_weight_packed.data = supa_w2_weight_packed
        # NOTE: w2_scale
        # after _load_w2, w2_weight is a colparallel weight, shape
        # S4: [num_experts, intermediate_size_per_partition // 128, hidden_size]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, group_nums, wn]
        supa_w2_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.intermediate_size_per_partition // self.group_size,
                  wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="colmajor")
        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_scale[expert_id]
            # NOTE(review): strided slice — see the matching note in the
            # int8 branch; verify intent.
            narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
            narrow_data.copy_(expert_w2)
        layer.w2_weight_scale.data = supa_w2_scales
    else:
        raise ValueError(
            f"Unsupported num_bits: {self.num_bits}. Only 4 and 8 are supported."
        )
    # remove other CompressedTensorsWNA16MoEMethod registered buffers to
    # reduce memory usage
    layer.w13_weight_shape = None
    layer.w13_weight_g_idx = None
    layer.w13_g_idx_sort_indices = None
    layer.w2_weight_shape = None
    layer.w2_weight_g_idx = None
    layer.w2_g_idx_sort_indices = None
@patch_to(CompressedTensorsWNA16MoEMethod)
def apply(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    """Run the quantized fused-MoE forward pass.

    ``router_logits`` is a 3-tuple of (gating weight, shared-expert
    gate/up weight, shared-expert down weight).  Batches longer than
    ``VLLM_BR_STATIC_MOE_DECODER_MAX_LEN`` use the dynamic-shape kernel;
    shorter (decode-sized) batches use the static device kernel, which
    additionally needs the tensor-parallel topology.
    """
    num_tokens = x.shape[0]
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits

    # Keyword arguments shared by both kernel entry points.
    routing_kwargs = dict(
        renormalize=renormalize,
        inplace=True,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        ep_rank=layer.ep_rank,
        ep_size=layer.ep_size,
    )

    if num_tokens > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
        return fused_moe_quant_dyn(
            x,
            shared_gate_up_weight,
            shared_down_weight,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            gating_weight,
            top_k,
            layer.intermediate_size_per_partition,
            **routing_kwargs,
        )

    tp_group = get_tp_group()
    return fused_moe_quant_device(
        x,
        shared_gate_up_weight,
        shared_down_weight,
        layer.w13_weight_packed,
        layer.w2_weight_packed,
        layer.w13_weight_scale,
        layer.w2_weight_scale,
        gating_weight,
        top_k,
        layer.intermediate_size_per_partition,
        tp_rank=tp_group.rank_in_group,
        global_rank=tp_group.rank,
        tp_size=get_tensor_model_parallel_world_size(),
        **routing_kwargs,
    )

View File

@@ -0,0 +1,267 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from compressed_tensors.compressors.quantized_compressors import (
unpack_from_int32)
from fastcore.basics import patch_to
from vllm.distributed import (get_pipeline_model_parallel_group,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
# yapf: enable
from vllm_br import envs
from ...br_utils import _convert_to_numa_tensor
@patch_to(CompressedTensorsWNA16)
def create_weights(self, layer: torch.nn.Module, input_size: int,
                   input_size_per_partition: int, output_size: int,
                   output_partition_sizes: list[int],
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    """Register the packed weight, scale, shape (and optional zero-point /
    g_idx) parameters for a WNA16 compressed-tensors linear layer.

    All parameters are allocated on CPU; they are converted to the Biren
    device/NUMA layouts later in process_weights_after_loading.
    """
    self.output_size_per_partition = sum(output_partition_sizes)
    # group_size == -1 means one scale group spanning the whole input dim.
    group_size = self.group_size if self.group_size != -1 else input_size

    # When the weight is row-parallel and scales are not repeated on every
    # rank, each rank only holds the scales of its own input slice.
    row_parallel = (input_size != input_size_per_partition)
    partition_scales = not marlin_repeat_scales_on_all_ranks(
        self.has_g_idx, self.group_size, row_parallel)

    scales_and_zp_size = input_size // group_size

    if partition_scales:
        assert input_size_per_partition % group_size == 0
        scales_and_zp_size = input_size_per_partition // group_size

    # Quantized weight, packed along the input dim by pack_factor.
    weight = PackedvLLMParameter(
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
        packed_factor=self.pack_factor,
        packed_dim=1,
        data=torch.empty(self.output_size_per_partition,
                         input_size_per_partition // self.pack_factor,
                         dtype=torch.int32,
                         device="cpu"))

    weight_scale_args = {
        "weight_loader": weight_loader,
        "data": torch.empty(
            self.output_size_per_partition,
            scales_and_zp_size,
            device="cpu",
            dtype=params_dtype,
        )
    }

    zeros_args = {
        "weight_loader": weight_loader,
        "data": torch.zeros(
            self.output_size_per_partition // self.pack_factor,
            scales_and_zp_size,
            device="cpu",
            dtype=torch.int32,
        )
    }

    if not partition_scales:
        # Channel-wise scales (replicated across ranks).
        weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                  **weight_scale_args)

        if not self.symmetric:
            qzeros = PackedColumnParameter(output_dim=0,
                                           packed_dim=0,
                                           packed_factor=self.pack_factor,
                                           **zeros_args)
    else:
        # Group-wise scales, sharded along the input dim.
        weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                input_dim=1,
                                                **weight_scale_args)
        if not self.symmetric:
            qzeros = PackedvLLMParameter(input_dim=1,
                                         output_dim=0,
                                         packed_dim=0,
                                         packed_factor=self.pack_factor,
                                         **zeros_args)

    # A 2D array defining the original shape of the weights
    # before packing
    weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                      dtype=torch.int64,
                                                      device="cpu"),
                                     weight_loader=weight_loader)

    layer.register_parameter("weight_packed", weight)
    layer.register_parameter("weight_scale", weight_scale)
    layer.register_parameter("weight_shape", weight_shape)

    if not self.symmetric:
        layer.register_parameter("weight_zero_point", qzeros)

    # group index (for activation reordering)
    if self.has_g_idx:
        weight_g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
                                        input_dim=0,
                                        weight_loader=weight_loader)
        layer.register_parameter("weight_g_idx", weight_g_idx)

    # Cached for the unpack step in process_weights_after_loading.
    self.input_size_per_partition = input_size_per_partition
@patch_to(CompressedTensorsWNA16)
def process_weights_after_loading(self: CompressedTensorsWNA16,
                                  layer: torch.nn.Module) -> None:
    """Unpack the int32-packed weight and convert weight / scales / bias
    into the Biren NUMA layouts consumed by the fused linear kernels."""
    # spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # cur_device = torch.supa.current_device()
    self.num_bits = 32 // self.pack_factor

    # Unpack back to one quantized element per entry.
    unpacked_shape = torch.Size(
        [self.output_size_per_partition, self.input_size_per_partition])
    layer.weight_packed.data = unpack_from_int32(layer.weight_packed.data,
                                                 self.num_bits,
                                                 unpacked_shape, 1)

    # Re-wrap the scales as a plain Parameter and promote to fp32.
    layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                            requires_grad=False)
    layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)

    do_transpose = True
    parallel_type = ("row_parallel" if isinstance(layer, RowParallelLinear)
                     else "col_parallel")

    if hasattr(layer, 'weight_packed') and len(layer.weight_packed.shape) == 2:
        packed = layer.weight_packed.data
        # Transposed conversion: wk comes from the input dim, wn from the
        # output dim of the unpacked weight.
        if do_transpose:
            wk, wn = packed.shape[1], packed.shape[0]
        else:
            wk, wn = packed.shape[0], packed.shape[1]
        layer.weight_packed.data = _convert_to_numa_tensor(
            packed,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            do_transpose=do_transpose,
            wk=wk,
            wn=wn,
            parallel_type=parallel_type)

    if hasattr(layer, 'weight_scale') and layer.weight_scale is not None:
        layer.weight_scale.data = _convert_to_numa_tensor(
            layer.weight_scale.data.T,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=False)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # Row-parallel biases are zero-padded by the conversion helper.
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias.data,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=(parallel_type == "row_parallel"))
@patch_to(CompressedTensorsWNA16)
def apply_weights(self: CompressedTensorsWNA16,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Forward pass for a WNA16 compressed-tensors linear layer.

    A 3-D ``weight_packed`` means process_weights_after_loading converted
    it to the Biren NUMA layout, so the fused kernels are used; a 2-D
    weight falls back to the generic quantized matmul.

    Fix: the ``br_fused_mlp_infer`` call on the row-parallel reduce path
    passed the misspelled keyword ``activaion_mode``; every other call
    site in this backend uses ``activation_mode``.
    """
    # numa weight is 3-dims
    if len(layer.weight_packed.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # Fused gate/up projection: SwiGLU halves the logical output.
            act_mode = "act_swiglu"
            output_size //= 2
        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if layer.reduce_results:
                return torch_br.br_fused_mlp_infer(
                    x, [layer.weight_packed.data],
                    output_w=output_size,
                    scales=[layer.weight_scale.data]
                    if layer.weight_scale.data is not None else None,
                    bias=[bias] if bias is not None else None,
                    activation_mode=act_mode)  # was: "activaion_mode" (typo)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.weight_packed.data,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.weight_scale.data,
                    bias=bias,
                    act_mode=act_mode)
        else:
            return torch_br.br_fused_mlp_infer(
                x, [layer.weight_packed.data],
                output_w=output_size,
                scales=[layer.weight_scale.data]
                if layer.weight_scale.data is not None else None,
                bias=[bias] if bias is not None else None,
                activation_mode=act_mode)

    # Generic (2-D weight) path.
    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.weight_packed.shape[1]
    # TODO, hard code to skip dry_run stage
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.weight_packed,
                                        layer.weight_scale,
                                        bias=bias)

View File

@@ -0,0 +1,34 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """Map a compressed-tensors k/v-cache output-scale param name to the
    equivalent vLLM attention-scale param name.

    :param name: param name
    :return: matching param name for KV cache scale in vLLM, or None when
        the name is not a k/v cache scale
    """
    if not name.endswith(".output_scale"):
        return None
    if ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    return None

View File

@@ -0,0 +1,244 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Any, Dict, Optional
import torch
import torch_br
from fastcore.basics import patch_to
from torch.nn.parameter import Parameter
from vllm.distributed import (get_pipeline_model_parallel_group,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
GPTQLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm_br import envs
from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
from ..linear import (process_weights_MergedColumnParallelLinear,
process_weights_QuantMergedColumnParallelLinear,
process_weights_ReplicatedLinear)
@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Only bfloat16 activations are supported by the Biren GPTQ path."""
    return [torch.bfloat16]
@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Resolve the GPTQ linear quant method for ``layer`` (or None when
    the layer is not quantized under this config)."""
    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """Build a GPTQConfig from a checkpoint quantization config dict.

    [PatchNote] add qkv_quantized param support
    """
    # "dynamic" may be explicitly null in the config; normalize to {}.
    dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
    if dynamic is None:
        dynamic = {}
    return cls(
        weight_bits=cls.get_from_keys(config, ["bits"]),
        group_size=cls.get_from_keys(config, ["group_size"]),
        desc_act=cls.get_from_keys(config, ["desc_act"]),
        lm_head_quantized=cls.get_from_keys_or(config, ["lm_head"],
                                               default=False),
        dynamic=dynamic,
        autoround_version=cls.get_from_keys_or(config, ["autoround_version"],
                                               default=""),
        modules_in_block_to_quantize=cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None),
        qkv_quantized=cls.get_from_keys_or(config, ["qkv_quantized"],
                                           default=True))
def wrapper_GPTQConfig_init(fn):
    """Wrap ``fn`` (GPTQConfig.__init__) so it accepts and stores an
    extra ``qkv_quantized`` keyword argument (default True)."""

    @wraps(fn)
    def _init_with_qkv_quantized(self, *args, **kwargs):
        # Pop the extra kwarg before delegating to the original __init__.
        flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = flag

    return _init_with_qkv_quantized


GPTQConfig.__init__ = wrapper_GPTQConfig_init(GPTQConfig.__init__)  # noqa: E501
@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """Convert loaded GPTQ weights into the Biren kernel layouts.

    Order matters: layer-type-specific preprocessing runs first, the
    int32-packed qweight is converted next (the quant merged-column
    processing below needs the converted qweight), everything is then
    re-wrapped as Parameters for torch.compile, and finally the generic
    NUMA conversion runs when still required.
    """
    # Whether the generic NUMA conversion at the bottom still has to run.
    still_need_process = True
    # Set when a MergedColumnParallelLinear carries a quantized weight and
    # must be processed only after the qweight conversion below.
    merge_col_quant = False
    # NOTE: all process_weights func should done before process_weights_after_loading
    parallel_type = "col_parallel"
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # NOTE(review): only 64/256-output replicated layers keep the
            # NUMA conversion — confirm these magic sizes are intended.
            still_need_process = layer.output_size == 64 or layer.output_size == 256
        case MergedColumnParallelLinear():
            if hasattr(layer, "qweight"):
                merge_col_quant = True
            else:
                process_weights_MergedColumnParallelLinear(layer)
                still_need_process = False
        case RowParallelLinear():
            parallel_type = "row_parallel"
        case _:
            pass
    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    if layer.qweight.dtype == torch.int32:
        # Still int32-packed: convert qweight/qzeros to the Biren format.
        input_size = layer.input_size_per_partition if hasattr(
            layer, 'input_size_per_partition') else layer.input_size
        output_size = layer.output_size_per_partition if hasattr(
            layer, 'output_size_per_partition') else layer.output_size
        br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
                                     input_size, output_size)
        layer.qweight.data = br_qweight
    if merge_col_quant:
        # Deliberately after the qweight conversion above.
        process_weights_QuantMergedColumnParallelLinear(layer)
        still_need_process = False
    # Kernels consume fp32 scales.
    br_scales = layer.scales.to(torch.float32)
    layer.scales.data = br_scales
    # for torch.compile
    layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
    layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
    layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
    layer.scales = Parameter(layer.scales.data, requires_grad=False)
    if not still_need_process or self.weight_type != "NUMA":
        return
    # Generic NUMA conversion for qweight / scales / bias.
    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            parallel_type=parallel_type)
    if hasattr(layer, 'scales') and layer.scales is not None:
        pad_zeros = False
        layer.scales.data = _convert_to_numa_tensor(
            layer.scales,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
    if hasattr(layer, 'bias') and layer.bias is not None:
        # Row-parallel biases are zero-padded — presumably so the bias is
        # only contributed once across ranks; confirm in the helper.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Forward pass for a GPTQ-quantized linear layer.

    A 3-D ``qweight`` indicates the NUMA layout produced by
    process_weights_after_loading and is served by the Biren fused
    kernels; a 2-D ``qweight`` falls back to the generic quantized matmul.
    """
    # Generic (non-NUMA, 2-D weight) path first.
    if len(layer.qweight.shape) != 3:
        xn, xh = x.shape[0], x.shape[1]
        ww = layer.qweight.shape[1]
        # TODO, hard code to skip dry_run stage
        if xh >= 4096:
            return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
        return torch_br.sudnn_qmatmul_infer(x,
                                            layer.qweight,
                                            layer.scales,
                                            bias=bias)

    output_size = (layer.output_size_per_partition if hasattr(
        layer, "output_size_per_partition") else layer.output_size)
    act_mode = "act_default"
    if isinstance(layer, MergedColumnParallelLinear):
        # Fused gate/up projection: SwiGLU halves the logical output.
        act_mode = "act_swiglu"
        output_size //= 2

    scale_list = [layer.scales] if layer.scales is not None else None
    bias_list = [bias] if bias is not None else None

    if not isinstance(layer, RowParallelLinear):
        return torch_br.br_fused_mlp_infer(x, [layer.qweight],
                                           output_w=output_size,
                                           scales=scale_list,
                                           bias=bias_list,
                                           activation_mode=act_mode)

    seq_len = x.shape[-2]
    tp_size = get_tensor_model_parallel_world_size()
    # bypass tp8 and tp4pp2 allreduce
    pp_size = get_pipeline_model_parallel_group().world_size
    all_rank = tp_size * pp_size
    support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
    layer.reduce_results = not (
        all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
        and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
        and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
    if layer.reduce_results:
        return torch_br.br_fused_mlp_infer(x, [layer.qweight],
                                           output_w=output_size,
                                           scales=scale_list,
                                           bias=bias_list,
                                           activation_mode=act_mode)

    # Fused linear + allreduce path.
    tp_rank = get_tp_group().rank_in_group
    global_rank = get_tp_group().rank
    assert global_rank % tp_size == tp_rank
    return torch_br.supa_fused_linear_allreduce_opt(x,
                                                    layer.qweight,
                                                    output_size,
                                                    tp_rank,
                                                    tp_size,
                                                    global_rank,
                                                    0,
                                                    scales=layer.scales,
                                                    bias=bias,
                                                    act_mode=act_mode)

View File

@@ -0,0 +1,924 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import itertools
from typing import Any, Optional, Tuple, Union
import torch
import torch_br
from fastcore.basics import patch_to
from transformers import PretrainedConfig
import vllm.model_executor.layers.rotary_embedding
import vllm.model_executor.models.chatglm
import vllm.model_executor.models.deepseek_v2
import vllm_br.envs as br_envs
from vllm.logger import logger
from vllm.model_executor.layers.rotary_embedding import (
_ROPE_DICT, DeepseekScalingRotaryEmbedding, DualChunkRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
Llama3RotaryEmbedding, Llama4VisionRotaryEmbedding, MRotaryEmbedding,
NTKScalingRotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding,
RotaryEmbedding, YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import (
rotate_gptj, rotate_neox, yarn_find_correction_range,
yarn_linear_ramp_mask)
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
yarn_get_mscale)
from vllm.model_executor.layers.rotary_embedding.mrope import (
apply_interleaved_rope)
@patch_to(RotaryEmbedding)
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    op_type: str = "Half",  # FIXME: other op type not supported yet
) -> None:
    """Initialize rotary embedding and precompute device-resident caches.

    MRotaryEmbedding and DeepseekScalingRotaryEmbedding keep upstream's
    single ``cos_sin_cache`` buffer; every other subclass stores separate
    fp32 ``sin_cache`` / ``cos_cache`` buffers for the SUPA RoPE kernel.

    Fix: the DeepseekScalingRotaryEmbedding branch redundantly re-assigned
    head_size/rotary_dim/max_position_embeddings/base/is_neox_style/dtype
    with the exact values already set above; the duplicates are removed.
    """
    logger.info('[Patch] RotaryEmbedding use SUPA RoPE')
    super(RotaryEmbedding, self).__init__()  # type: ignore
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype
    self.op_type = op_type  # FIXME: other op type not supported yet
    if isinstance(self, MRotaryEmbedding):
        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        # NOTE(review): this branch uses torch.cuda.current_device() while
        # the Deepseek branch uses torch.supa.current_device() — confirm
        # torch.cuda is aliased to the SUPA device here.
        device = torch.cuda.current_device()
        cache = cache.to(device)
        self.cos_sin_cache: torch.Tensor  # type: ignore
        self.register_buffer("cos_sin_cache", cache, persistent=False)
    elif isinstance(self, DeepseekScalingRotaryEmbedding):
        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        device = torch.supa.current_device()
        cache = cache.to(device)
        self.cos_sin_cache: torch.Tensor  # type: ignore
        self.register_buffer("cos_sin_cache", cache, persistent=False)
    else:
        # Split fp32 sin/cos caches for torch_br.supa_rope_infer_v2.
        sin_cache, cos_cache = self._compute_cos_sin_cache()
        sin_cache = sin_cache.to(torch.float32)
        cos_cache = cos_cache.to(torch.float32)
        device = torch.cuda.current_device()
        sin_cache = sin_cache.to(device)
        cos_cache = cos_cache.to(device)
        self.register_buffer("sin_cache", sin_cache, persistent=False)
        self.register_buffer("cos_cache", cos_cache, persistent=False)
@patch_to(RotaryEmbedding)
def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the cos and sin cache.

    Returns (sin, cos) for the generic path, but a single concatenated
    [cos | sin] cache for MRotaryEmbedding — callers branch accordingly
    (the Tuple annotation only describes the generic path).
    """
    # Built on CPU; __init__ moves the result to the device afterwards.
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        if isinstance(self, MRotaryEmbedding):
            # MRotary keeps upstream's single [cos | sin] layout.
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            return cache
        else:
            if self.op_type == "Half" or self.op_type == "TeleChat":
                # Contiguous-halves layout: duplicate freqs across halves.
                freqs = freqs.repeat(1, 2)
                cos = freqs.cos()
                sin = freqs.sin()
            else:
                # Interleaved layout: each freq appears twice; sin carries
                # an alternating -1/+1 sign — presumably matching the SUPA
                # kernel's expected layout (confirm against the kernel).
                cos_freqs = freqs.repeat_interleave(2, dim=-1)
                cos = cos_freqs.cos()
                scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
                sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
                sin = sin_freqs.sin()
            return sin, cos
@patch_to(RotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply RoPE to query/key through the fused SUPA kernel.

    ``offsets`` is accepted for interface compatibility but is not
    forwarded to the kernel.
    """
    rotated_q, rotated_k = torch_br.supa_rope_infer_v2(
        query,
        key,
        self.sin_cache,
        self.cos_cache,
        positions,
        self.head_size,
        rope_type=self.op_type,
        rotary_size=self.rotary_dim)
    return rotated_q, rotated_k
@patch_to(RotaryEmbedding)
def enabled(cls) -> bool:
    # NOTE(review): patched without cls_method=True even though the first
    # parameter is named `cls`, so this attaches as a regular instance
    # method — confirm callers invoke it on an instance, or add
    # cls_method=True.
    return True
class SupaDeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """DeepSeek (YaRN-style) scaled rotary embedding for the SUPA backend.

    Blends interpolated and extrapolated inverse frequencies and applies
    the YaRN magnitude correction ``mscale``; caches are produced in the
    interleaved (sin, cos) layout expected by the SUPA RoPE kernel.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        # Scaling attributes must be set before super().__init__, which
        # triggers _compute_cos_sin_cache.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated/extrapolated inverse frequencies with the
        YaRN linear ramp between beta_fast and beta_slow (on CPU)."""
        with torch.device('cpu'):
            pos_freqs = self.base**(torch.arange(
                0, self.rotary_dim, 2, dtype=torch.float, device="cpu") /
                                    self.rotary_dim)
            inv_freq_extrapolation = 1.0 / pos_freqs
            inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

            low, high = yarn_find_correction_range(
                self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
                self.max_position_embeddings)
            # Get n-d rotational scaling corrected for extrapolation
            inv_freq_mask = (1 - yarn_linear_ramp_mask(
                low, high, self.rotary_dim // 2,
                dtype=torch.float)) * self.extrapolation_factor
            inv_freq = inv_freq_interpolation * (
                1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
            return inv_freq

    def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build (sin, cos) caches in the interleaved layout, each value
        repeated per lane pair with sin carrying alternating -1/+1 signs,
        both scaled by mscale (on CPU)."""
        with torch.device('cpu'):
            inv_freq = self._compute_inv_freq(self.scaling_factor)
            t = torch.arange(self.max_position_embeddings *
                             self.scaling_factor,
                             dtype=torch.float)
            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos_freqs = freqs.repeat_interleave(2, dim=-1)
            cos = (cos_freqs.cos() * self.mscale)
            scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
            sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
            sin = (sin_freqs.sin() * self.mscale)
            return sin, cos
@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Blend interpolated and extrapolated inverse frequencies (YaRN).

    Frequencies inside the [low, high) correction range ramp between
    1/(scaling_factor * f) and 1/f, weighted by extrapolation_factor;
    computed on CPU.
    """
    with torch.device('cpu'):
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cpu") /
                                self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq
@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the concatenated [cos | sin] cache on CPU, both halves
    scaled by the YaRN mscale."""
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        positions = torch.arange(
            self.max_position_embeddings * self.scaling_factor,
            dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat((freqs.cos() * self.mscale,
                          freqs.sin() * self.mscale),
                         dim=-1)
@patch_to(DeepseekScalingRotaryEmbedding)
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward()."""
    assert key is not None
    self._match_cos_sin_cache_dtype(query)
    # Only the first rotary_dim lanes are rotated; the rest pass through.
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]
    cos_sin = self.cos_sin_cache[
        torch.add(positions, offsets) if offsets is not None else positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if self.is_neox_style:
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    else:
        # repeat_interleave is done on CPU and the result moved back —
        # presumably a workaround for a slow/unsupported device op
        # (confirm against the SUPA backend).
        device = torch.supa.current_device()
        cos = cos.to('cpu')
        sin = sin.to('cpu')
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        cos = cos.to(device)
        sin = sin.to(device)
    rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
    device = query_rot.device
    # Batches of more than 1024 tokens compute the rotation on CPU and
    # copy back — NOTE(review): looks like a device-memory/perf
    # workaround; confirm the threshold.
    if query.shape[0] > 1024:
        query_rot = query_rot.to('cpu')
        key_rot = key_rot.to('cpu')
        cos = cos.to('cpu')
        sin = sin.to('cpu')
    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    key_rot = key_rot * cos + rotate_fn(key_rot) * sin
    if query.shape[0] > 1024:
        query_rot = query_rot.to(device)
        key_rot = key_rot.to(device)
    if self.rotary_dim < self.head_size:
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
    else:
        query = query_rot
        key = key_rot
    return query, key
@patch_to(DeepseekScalingRotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Out-of-tree path: delegate to the PyTorch-native implementation."""
    return self.forward_native(positions, query, key, offsets)
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Blend interpolated and extrapolated inverse frequencies (YaRN).

    Same blending as the Deepseek variant: frequencies inside the
    [low, high) correction range ramp between 1/(scaling_factor * f)
    and 1/f, weighted by extrapolation_factor; computed on CPU.
    """
    with torch.device('cpu'):
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build split (sin, cos) caches in the contiguous-halves layout
    ([f | f] repeated), each scaled by the YaRN mscale; computed on CPU.

    Note: returns (sin, cos), matching the patched RotaryEmbedding
    __init__, not upstream's single concatenated cache.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        # Duplicate freqs so cos/sin cover the full rotary dim.
        freqs = freqs.repeat(1, 2)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
        return sin, cos
def dynamicNTK_compute_cos_sin_cache(
        self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the cos and sin cache (dynamic-NTK scaling variant).

    Same cache layouts as the patched RotaryEmbedding:
    "Half" / "TeleChat" op types use the contiguous-halves layout; other
    op types use the interleaved layout with alternating sin signs.

    Fix: function was previously named ``dtnamicNTK_compute_cos_sin_cache``
    (typo); renamed here together with its single use below.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        if self.op_type == "Half" or self.op_type == "TeleChat":
            freqs = freqs.repeat(1, 2)
            cos = freqs.cos()
            sin = freqs.sin()
        else:
            cos_freqs = freqs.repeat_interleave(2, dim=-1)
            cos = cos_freqs.cos()
            # Alternating -1/+1 signs fold the rotation sign into sin.
            scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
            sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
            sin = sin_freqs.sin()
        return sin, cos


def dynamicNTKScaling_rope_forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply RoPE through the SUPA kernel; mismatched query/key head
    dims fall back to the "MRope" kernel mode."""
    if query.shape[-1] != key.shape[-1]:
        query_, key_ = torch_br.supa_rope_infer_v2(query,
                                                   key,
                                                   self.sin_cache,
                                                   self.cos_cache,
                                                   positions,
                                                   self.head_size,
                                                   rope_type="MRope")
    else:
        query_, key_ = torch_br.supa_rope_infer_v2(query,
                                                   key,
                                                   self.sin_cache,
                                                   self.cos_cache,
                                                   positions,
                                                   self.head_size,
                                                   rope_type=self.op_type)
    return query_, key_


DynamicNTKScalingRotaryEmbedding._compute_cos_sin_cache = dynamicNTK_compute_cos_sin_cache
DynamicNTKScalingRotaryEmbedding.forward = dynamicNTKScaling_rope_forward
def _apply_rotary_emb_torch(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """Thin dispatcher to the pure-PyTorch rotary implementation.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Neox-style (halved) vs GPT-J-style (interleaved)
            rotary positional embeddings.
    """
    return _apply_rotary_emb_torch(x=x,
                                   cos=cos,
                                   sin=sin,
                                   is_neox_style=is_neox_style)
def forward_MRotaryEmbedding_0_9_2(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native M-RoPE forward (vLLM 0.9.2 semantics).

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    assert positions.ndim in (1, 2)
    assert key is not None
    num_tokens = positions.shape[-1]
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
    if positions.ndim == 2:
        assert self.mrope_section

        def _merge_sections(table: torch.Tensor) -> torch.Tensor:
            # Take row i (t/h/w) of section i and glue the sections back
            # together along the channel dimension.
            parts = table.split(self.mrope_section, dim=-1)
            return torch.cat(
                [part[idx] for idx, part in enumerate(parts)], dim=-1)

        cos = _merge_sections(cos)
        sin = _merge_sections(sin)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        # Rotate only the first rotary_dim channels of each head; pass the
        # remaining channels through untouched.
        orig_shape = states.shape
        states = states.view(num_tokens, -1, self.head_size)
        rotated = _apply_rotary_emb(states[..., :self.rotary_dim], cos, sin,
                                    self.is_neox_style)
        return torch.cat((rotated, states[..., self.rotary_dim:]),
                         dim=-1).reshape(orig_shape)

    return _rotate(query), _rotate(key)
def forward_supa(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    M-RoPE forward routed through the Biren `supa_rope_infer_v2` kernel.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    # Env-gated fallback to the pure-PyTorch vLLM 0.9.2 implementation.
    if br_envs.VLLM_BR_USE_MROPE_0_9_2:
        return forward_MRotaryEmbedding_0_9_2(self, positions, query, key)
    assert positions.ndim == 1 or positions.ndim == 2
    # Device predicates: 'supa' is the Biren accelerator device string.
    data_in_supa = lambda t: str(t.device).startswith('supa')
    data_in_cpu = lambda t: t.device == torch.device('cpu')
    if positions.ndim == 2:
        # use bypass for decode stage
        if (positions.shape[1] == 1):
            # Decode: one position per sequence; row 0 of the 3-row table
            # is used directly (no per-section merge needed).
            cos_sin = self.cos_sin_cache[positions]
            cos, sin = cos_sin.chunk(2, dim=-1)
            cos = cos[0]
            sin = sin[0]
        else:
            # Prefill with multimodal T/H/W positions: gather per-row
            # tables, then recombine them section-by-section.
            cos_sin = self.cos_sin_cache[positions.to(torch.int64)]
            cos, sin = cos_sin.chunk(2, dim=-1)
            assert self.mrope_section
            if self.mrope_interleaved:
                cos = apply_interleaved_rope(cos, self.mrope_section)
                sin = apply_interleaved_rope(sin, self.mrope_section)
            else:
                # Row i of section i (t/h/w), concatenated over channels.
                cos = torch.cat([
                    m[i] for i, m in enumerate(
                        cos.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
                sin = torch.cat([
                    m[i] for i, m in enumerate(
                        sin.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
    if data_in_supa(query) and data_in_supa(key):
        # Move CPU-resident tables/positions to the device before the
        # fused kernel call; already-on-device tensors are left as-is.
        sin = sin.supa() if data_in_cpu(sin) else sin
        cos = cos.supa() if data_in_cpu(cos) else cos
        positions = positions.supa() if data_in_cpu(positions) else positions
    # NOTE(review): for 1-D positions, `cos`/`sin` are never assigned in
    # this function — the kernel presumably ignores them for non-MRope
    # inputs via self.sin_cache/self.cos_cache; confirm against the
    # 1-D call sites before relying on this path.
    query, key = torch_br.supa_rope_infer_v2(query,
                                             key,
                                             sin.to(torch.float32),
                                             cos.to(torch.float32),
                                             positions.to(torch.int32),
                                             self.head_size,
                                             rope_type="MRope")
    return query, key
# Route all MRotaryEmbedding forwards through the supa-accelerated path.
MRotaryEmbedding.forward = forward_supa
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    op_type: str = "Half",
) -> RotaryEmbedding:
    """Create (or fetch from cache) a rotary embedding for the given config.

    Mirrors vLLM's `get_rope` factory, with the Biren-specific `op_type`
    threaded through to the plain `RotaryEmbedding` and included in the
    cache key.

    Raises:
        ValueError: if `rope_scaling["rope_type"]` is not recognized.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    # BUGFIX: `op_type` must be part of the cache key. Without it, a
    # "DeepSeek"/"TeleChat" request with otherwise-identical arguments would
    # return a cached embedding built with op_type="Half" (or vice versa)
    # and silently use the wrong kernel variant.
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype, op_type)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]
    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style, dtype,
                                              **extra_kwargs)
    elif not rope_scaling:
        # op_type only affects the plain (unscaled) rotary embedding here.
        rotary_emb = RotaryEmbedding(head_size,
                                     rotary_dim,
                                     max_position,
                                     base,
                                     is_neox_style,
                                     dtype,
                                     op_type=op_type)
    else:
        scaling_type = rope_scaling["rope_type"]
        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype=torch.float32,
                    mrope_section=rope_scaling["mrope_section"],
                    mrope_interleaved=rope_scaling.get("mrope_interleaved",
                                                       False),
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style,
                                                   scaling_factor, dtype,
                                                   mixed_b)
        elif scaling_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "deepseek_yarn_supa":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = SupaDeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
def deepseek_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """get_rope wrapper that pins op_type to "DeepSeek" for DeepSeek models."""
    return get_rope(head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    rope_scaling,
                    dtype,
                    partial_rotary_factor,
                    dual_chunk_attention_config,
                    op_type="DeepSeek")
def chatglm2_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """get_rope wrapper used by the ChatGLM model path.

    NOTE(review): this forwards op_type="DeepSeek", identical to
    deepseek_get_rope — confirm ChatGLM intentionally shares the DeepSeek
    rope kernel variant.
    """
    return get_rope(head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    rope_scaling,
                    dtype,
                    partial_rotary_factor,
                    dual_chunk_attention_config,
                    op_type="DeepSeek")
# Replace vLLM's rope factory globally, and give the DeepSeek/ChatGLM model
# modules wrappers that pass their model-specific op_type.
vllm.model_executor.layers.rotary_embedding.get_rope = get_rope
vllm.model_executor.models.deepseek_v2.get_rope = deepseek_get_rope
vllm.model_executor.models.chatglm.get_rope = chatglm2_get_rope
@patch_to(MRotaryEmbedding)
def _glm4v_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value for GLM4V.

    Returns a [3, num_positions] tensor of T/H/W position ids sliced to
    [context_len:seq_len], plus the mrope position delta
    (max position + 1 - number of input tokens).
    """
    image_token_id = hf_config.image_token_id
    video_start_token_id = hf_config.video_start_token_id
    video_end_token_id = hf_config.video_end_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    llm_pos_ids_list: list = []
    if not (image_grid_thw is None and video_grid_thw is None):
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()
        # Classify every token as image/video/text. Image placeholder tokens
        # that fall between video start/end markers count as video.
        input_token_type: list[str] = []
        video_check_flg = False
        for token in input_tokens:
            if token == video_start_token_id:
                video_check_flg = True
            elif token == video_end_token_id:
                video_check_flg = False
            if (token == image_token_id) and (video_check_flg is False):
                input_token_type.append("image")
            elif (token == image_token_id) and (video_check_flg is True):
                input_token_type.append("video")
            else:
                input_token_type.append("text")
        # Collapse consecutive same-type tokens into (type, start, end) runs.
        input_type_group: list[tuple[str, int, int]] = []
        for key, group_iter in itertools.groupby(enumerate(input_token_type),
                                                 lambda x: x[1]):
            group_list = list(group_iter)
            start_index = group_list[0][0]
            end_index = group_list[-1][0] + 1
            input_type_group.append((key, start_index, end_index))
        video_frame_num = 1
        mm_data_idx = 0
        for modality_type, start_idx, end_idx in input_type_group:
            # Each run continues numbering from one past the max position
            # assigned so far.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            if modality_type == "image":
                t, h, w = (
                    image_grid_thw[mm_data_idx][0],
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size
                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                    -1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                    llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                    llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1
            elif modality_type == "video":
                # NOTE(review): this branch reads image_grid_thw, and
                # video_grid_thw is never used in this function — confirm
                # GLM4V really packs video frame grids into image_grid_thw.
                t, h, w = (
                    video_frame_num,
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size
                for t_idx in range(llm_grid_t):
                    t_index = torch.tensor(t_idx).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        1, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        1, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1
                video_frame_num += 1
            else:
                # Text run: positions advance linearly, identical on all
                # three (T/H/W) rows.
                text_len = end_idx - start_idx
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                video_frame_num = 1
    else:
        # Text-only prompt: plain 0..N-1 positions replicated on 3 rows.
        text_len = len(input_tokens)
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1))
    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:seq_len]
    mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
    return llm_positions, mrope_position_delta
@patch_to(MRotaryEmbedding)
def get_input_positions_tensor_for_glm(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Dispatch mrope position computation to the model-specific helper:
    omni (thinker) models, GLM4V, or the generic VL path."""
    from vllm.transformers_utils.config import thinker_uses_mrope
    if thinker_uses_mrope(hf_config):
        return cls._omni_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )
    if "glm4v" in hf_config.model_type:
        return cls._glm4v_get_input_positions_tensor(
            cls,
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            context_len=context_len,
            seq_len=seq_len,
        )
    return cls._vl_get_input_positions_tensor(
        input_tokens=input_tokens,
        hf_config=hf_config,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,
        context_len=context_len,
        seq_len=seq_len,
    )

View File

@@ -0,0 +1,65 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import vllm
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
def apply_penalties_fit(logits: torch.Tensor,
                        prompt_tokens_tensor: torch.Tensor,
                        output_tokens_tensor: torch.Tensor,
                        presence_penalties: torch.Tensor,
                        frequency_penalties: torch.Tensor,
                        repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Apply repetition/frequency/presence penalties in place to `logits`.

    Args:
        logits: [num_seqs, vocab_size] tensor, modified in place.
        prompt_tokens_tensor: prompt token ids, padded with `vocab_size`
            (an id no valid token can have) to a common length.
        output_tokens_tensor: generated token ids so far.
        presence_penalties: shape (num_seqs,)
        frequency_penalties: shape (num_seqs,)
        repetition_penalties: shape (num_seqs,)
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)
    # Repetition penalty applies only to tokens already seen in the prompt
    # or the output; everything else keeps a no-op factor of 1.0.
    seen_mask = prompt_mask | output_mask
    per_token_rep = repetition_penalties.unsqueeze(dim=1).repeat(
        1, vocab_size)
    penalties = torch.where(seen_mask, per_token_rep, 1.0)
    # Positive logits are divided by the penalty, negative ones multiplied.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling
    # Frequency/presence penalties follow the OpenAI API definition:
    # https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits
# Swap vLLM's penalty application for the variant defined above.
vllm.model_executor.layers.utils.apply_penalties = apply_penalties_fit

View File

@@ -0,0 +1,139 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch_br
import torch_br.supa._debug as supa_debug
from fastcore.basics import patch_to
import vllm
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm_br import envs
@patch_to(UnquantizedEmbeddingMethod)
def process_weights_after_loading(self, module):
    """Repack the embedding weight into a supa 'SB'-sharded colmajor tensor
    when VLLM_BR_EMBEDDING_S0B is enabled; otherwise leave it untouched."""
    if not envs.VLLM_BR_EMBEDDING_S0B:
        return
    # Stage the original weight on the host, allocate the device-layout
    # tensor, then copy the values back into it.
    host_copy = module.weight.data.cpu()
    module.weight.data = torch_br._empty_ut_only(module.weight.shape,
                                                 "colmajor",
                                                 False,
                                                 0,
                                                 dtype=module.weight.dtype,
                                                 sbp='SB')
    module.weight.data.copy_(host_copy)
@patch_to(UnquantizedEmbeddingMethod)
def embedding(self, layer: torch.nn.Module,
              input_: torch.Tensor) -> torch.Tensor:
    """Embedding lookup; routes through the supa out_embedding kernel when
    the weight is stored in the S0B layout, else plain F.embedding."""
    if not envs.VLLM_BR_EMBEDDING_S0B:
        return F.embedding(input_, layer.weight)
    # Kernel writes into a pre-allocated [1, num_tokens, hidden] buffer;
    # the leading batch dim is squeezed away before returning.
    out = torch_br._empty_ut_only(
        [1, input_.shape[0], layer.weight.shape[-1]],
        is_numa=False,
        dtype=layer.weight.dtype,
        sbp='BB',
        tensor_type="colmajor",
    )
    torch_br.out_embedding(out, layer.weight.data, input_, -1, -1)
    out.squeeze_(0)
    return out
@torch.jit.script
def get_masked_input_and_mask(
    input_: torch.Tensor, org_vocab_start_index: int,
    org_vocab_end_index: int, num_org_vocab_padding: int,
    added_vocab_start_index: int,
    added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map global token ids into this shard's local index space and return
    (local_ids, out_of_shard_mask). Ids outside both the original and the
    added vocab ranges map to 0 and are flagged in the mask."""
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    if spc_num > 16:
        # Pointwise TorchScript path for larger SPC counts.
        org_vocab_mask = (input_ >= org_vocab_start_index) & (
            input_ < org_vocab_end_index)
        added_vocab_mask = (input_ >= added_vocab_start_index) & (
            input_ < added_vocab_end_index)
        # Offset that maps an added-vocab id down past the original-vocab
        # range (plus its padding) into the shard-local index space.
        added_offset = added_vocab_start_index - (
            org_vocab_end_index -
            org_vocab_start_index) - num_org_vocab_padding
        valid_offset = (org_vocab_start_index *
                        org_vocab_mask) + (added_offset * added_vocab_mask)
        vocab_mask = org_vocab_mask | added_vocab_mask
        # Out-of-range ids are zeroed by the multiplication with the mask.
        input_ = vocab_mask * (input_ - valid_offset)
        return input_, ~vocab_mask
    else:
        # Dedicated supa kernel for small SPC counts.
        input_, inv_vocab_mask = torch_br.supa_embedding_mask_infer(
            input_, org_vocab_start_index, org_vocab_end_index,
            num_org_vocab_padding, added_vocab_start_index,
            added_vocab_end_index)
        return input_, inv_vocab_mask
# Make the patched masking visible to vLLM's own module as well.
vllm.model_executor.layers.vocab_parallel_embedding.get_masked_input_and_mask = get_masked_input_and_mask
def vocab_parallel_embedding_forward(self, input_) -> torch.Tensor:
    """Forward for VocabParallelEmbedding: mask ids outside this rank's
    vocab shard, look up embeddings, zero the masked rows, and all-reduce
    across the tensor-parallel group."""
    input_mask = None
    if self.tp_size > 1:
        input_, input_mask = get_masked_input_and_mask(
            input_,
            self.shard_indices.org_vocab_start_index,
            self.shard_indices.org_vocab_end_index,
            self.shard_indices.num_org_vocab_padding,
            self.shard_indices.added_vocab_start_index,
            self.shard_indices.added_vocab_end_index,
        )
    # Look up this shard's slice of the embedding table.
    output_parallel = self.quant_method.embedding(self, input_.long())
    if input_mask is not None:
        # Rows for out-of-shard ids must contribute zero to the all-reduce.
        output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
    return tensor_model_parallel_all_reduce(output_parallel)
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Linear apply for the patched unquantized embedding method.

    A 3-D weight indicates the numa-packed layout and dispatches to the
    fused supa MLP kernel; otherwise a plain F.linear runs with the
    sublas API enabled for the duration of the call.

    Args:
        layer: module holding `weight` (and, for the numa path, either
            `output_size_per_partition` or `output_size`).
        x: input activations.
        bias: optional bias tensor.
    """
    # numa weight is 3-dims
    if len(layer.weight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        return torch_br.br_fused_mlp_infer(
            x, [layer.weight],
            output_w=output_size,
            bias=[bias] if bias is not None else None)
    supa_debug.set_enable_sublas_api(True)
    try:
        # BUGFIX: F.linear may raise (e.g. on a shape mismatch); without the
        # try/finally the global sublas flag would leak in the enabled state.
        return F.linear(x, layer.weight, bias)
    finally:
        supa_debug.set_enable_sublas_api(False)
# Install the patched linear apply and the shard-aware forward on vLLM's
# embedding classes.
UnquantizedEmbeddingMethod.apply = apply
VocabParallelEmbedding.forward = vocab_parallel_embedding_forward