first commit
This commit is contained in:
139
vllm_br/model_executor/layers/vocab_parallel_embedding.py
Normal file
139
vllm_br/model_executor/layers/vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,139 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_br
|
||||
import torch_br.supa._debug as supa_debug
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
@patch_to(UnquantizedEmbeddingMethod)
|
||||
def process_weights_after_loading(self, module):
|
||||
if envs.VLLM_BR_EMBEDDING_S0B:
|
||||
ori_weight = module.weight.data.cpu()
|
||||
module.weight.data = torch_br._empty_ut_only(module.weight.shape,
|
||||
"colmajor",
|
||||
False,
|
||||
0,
|
||||
dtype=module.weight.dtype,
|
||||
sbp='SB')
|
||||
module.weight.data.copy_(ori_weight)
|
||||
|
||||
|
||||
@patch_to(UnquantizedEmbeddingMethod)
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
input_: torch.Tensor) -> torch.Tensor:
|
||||
if envs.VLLM_BR_EMBEDDING_S0B:
|
||||
y_supa = torch_br._empty_ut_only(
|
||||
[1, input_.shape[0], layer.weight.shape[-1]],
|
||||
is_numa=False,
|
||||
dtype=layer.weight.dtype,
|
||||
sbp='BB',
|
||||
tensor_type="colmajor",
|
||||
)
|
||||
torch_br.out_embedding(y_supa, layer.weight.data, input_, -1, -1)
|
||||
y_supa.squeeze_(0)
|
||||
return y_supa
|
||||
return F.embedding(input_, layer.weight)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def get_masked_input_and_mask(
|
||||
input_: torch.Tensor, org_vocab_start_index: int,
|
||||
org_vocab_end_index: int, num_org_vocab_padding: int,
|
||||
added_vocab_start_index: int,
|
||||
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# torch.jit.script will fuse all of the pointwise ops below
|
||||
# into a single kernel, making it very fast
|
||||
spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
|
||||
if spc_num > 16:
|
||||
org_vocab_mask = (input_ >= org_vocab_start_index) & (
|
||||
input_ < org_vocab_end_index)
|
||||
added_vocab_mask = (input_ >= added_vocab_start_index) & (
|
||||
input_ < added_vocab_end_index)
|
||||
added_offset = added_vocab_start_index - (
|
||||
org_vocab_end_index -
|
||||
org_vocab_start_index) - num_org_vocab_padding
|
||||
valid_offset = (org_vocab_start_index *
|
||||
org_vocab_mask) + (added_offset * added_vocab_mask)
|
||||
vocab_mask = org_vocab_mask | added_vocab_mask
|
||||
input_ = vocab_mask * (input_ - valid_offset)
|
||||
return input_, ~vocab_mask
|
||||
else:
|
||||
input_, inv_vocab_mask = torch_br.supa_embedding_mask_infer(
|
||||
input_, org_vocab_start_index, org_vocab_end_index,
|
||||
num_org_vocab_padding, added_vocab_start_index,
|
||||
added_vocab_end_index)
|
||||
return input_, inv_vocab_mask
|
||||
|
||||
|
||||
vllm.model_executor.layers.vocab_parallel_embedding.get_masked_input_and_mask = get_masked_input_and_mask
|
||||
|
||||
|
||||
def vocab_parallel_embedding_forward(self, input_) -> torch.Tensor:
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = get_masked_input_and_mask(
|
||||
input_,
|
||||
self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index,
|
||||
)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1),
|
||||
0) # type: ignore
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# numa weight is 3-dims
|
||||
if len(layer.weight.shape) == 3:
|
||||
output_size = (layer.output_size_per_partition if hasattr(
|
||||
layer, "output_size_per_partition") else layer.output_size)
|
||||
return torch_br.br_fused_mlp_infer(
|
||||
x, [layer.weight],
|
||||
output_w=output_size,
|
||||
bias=[bias] if bias is not None else None)
|
||||
supa_debug.set_enable_sublas_api(True)
|
||||
output = F.linear(x, layer.weight, bias)
|
||||
supa_debug.set_enable_sublas_api(False)
|
||||
return output
|
||||
|
||||
|
||||
UnquantizedEmbeddingMethod.apply = apply
|
||||
|
||||
VocabParallelEmbedding.forward = vocab_parallel_embedding_forward
|
||||
Reference in New Issue
Block a user