################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Biren SUPA overrides for vLLM's vocab-parallel embedding layer.

Importing this module monkey-patches, in place:

* ``UnquantizedEmbeddingMethod.process_weights_after_loading`` and
  ``UnquantizedEmbeddingMethod.embedding`` (via ``fastcore.patch_to``);
* ``vllm.model_executor.layers.vocab_parallel_embedding.get_masked_input_and_mask``;
* ``UnquantizedEmbeddingMethod.apply``;
* ``VocabParallelEmbedding.forward``.
"""
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import torch_br
import torch_br.supa._debug as supa_debug
from fastcore.basics import patch_to

import vllm
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.vocab_parallel_embedding import (
    UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm_br import envs


@patch_to(UnquantizedEmbeddingMethod)
def process_weights_after_loading(self, module):
    """Re-allocate the embedding weight in the SUPA layout when S0B is on.

    When ``envs.VLLM_BR_EMBEDDING_S0B`` is truthy, the loaded weight is
    staged on the CPU. A new tensor of the same shape and dtype is then
    allocated with ``torch_br._empty_ut_only`` ("colmajor", ``sbp='SB'``),
    and the staged values are copied into it. Otherwise this is a no-op.

    Args:
        module: The embedding layer whose ``weight`` parameter is replaced.
    """
    if envs.VLLM_BR_EMBEDDING_S0B:
        # Stage on the host first: module.weight.data is rebound below, so
        # the original values must be held somewhere before the copy.
        ori_weight = module.weight.data.cpu()
        module.weight.data = torch_br._empty_ut_only(module.weight.shape,
                                                     "colmajor",
                                                     False,
                                                     0,
                                                     dtype=module.weight.dtype,
                                                     sbp='SB')
        module.weight.data.copy_(ori_weight)


@patch_to(UnquantizedEmbeddingMethod)
def embedding(self, layer: torch.nn.Module,
              input_: torch.Tensor) -> torch.Tensor:
    """Look up embedding rows for ``input_``.

    With ``envs.VLLM_BR_EMBEDDING_S0B`` enabled, the lookup is done by
    ``torch_br.out_embedding`` into a pre-allocated
    ``[1, input_.shape[0], hidden]`` tensor (``sbp='BB'``, "colmajor").
    The leading dimension is squeezed away in place before returning.
    Otherwise this falls back to ``F.embedding``.

    Args:
        layer: Layer owning the embedding ``weight``.
        input_: Token ids to look up.

    Returns:
        The gathered embedding rows.
    """
    if envs.VLLM_BR_EMBEDDING_S0B:
        # NOTE(review): the output shape uses only input_.shape[0], so this
        # path appears to assume a 1-D (flattened) token-id tensor — confirm.
        y_supa = torch_br._empty_ut_only(
            [1, input_.shape[0], layer.weight.shape[-1]],
            is_numa=False,
            dtype=layer.weight.dtype,
            sbp='BB',
            tensor_type="colmajor",
        )
        torch_br.out_embedding(y_supa, layer.weight.data, input_, -1, -1)
        y_supa.squeeze_(0)
        return y_supa
    return F.embedding(input_, layer.weight)


@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map global token ids to shard-local ids and flag out-of-shard tokens.

    Returns:
        ``(masked_input, inv_mask)``. In the pure-torch branch, ids that lie
        in neither ``[org_vocab_start_index, org_vocab_end_index)`` nor
        ``[added_vocab_start_index, added_vocab_end_index)`` are set to 0.
        For those same positions ``inv_mask`` is True. The other branch
        delegates to ``torch_br.supa_embedding_mask_infer``.
    """
    # NOTE(review): under torch.jit.script this env value is presumably
    # captured when the function is compiled (at import), not re-read per
    # call — confirm.
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    if spc_num > 16:
        # Pure pointwise path; torch.jit.script can fuse these ops into a
        # single kernel.
        org_vocab_mask = (input_ >= org_vocab_start_index) & (
            input_ < org_vocab_end_index)
        added_vocab_mask = (input_ >= added_vocab_start_index) & (
            input_ < added_vocab_end_index)
        # Added-vocab ids are placed right after the (padded) original-vocab
        # slice of this shard.
        added_offset = added_vocab_start_index - (
            org_vocab_end_index -
            org_vocab_start_index) - num_org_vocab_padding
        valid_offset = (org_vocab_start_index *
                        org_vocab_mask) + (added_offset * added_vocab_mask)
        vocab_mask = org_vocab_mask | added_vocab_mask
        input_ = vocab_mask * (input_ - valid_offset)
        return input_, ~vocab_mask
    else:
        input_, inv_vocab_mask = torch_br.supa_embedding_mask_infer(
            input_, org_vocab_start_index, org_vocab_end_index,
            num_org_vocab_padding, added_vocab_start_index,
            added_vocab_end_index)
        return input_, inv_vocab_mask


vllm.model_executor.layers.vocab_parallel_embedding.get_masked_input_and_mask = get_masked_input_and_mask


def vocab_parallel_embedding_forward(self, input_) -> torch.Tensor:
    """Replacement for ``VocabParallelEmbedding.forward``.

    With ``tp_size > 1``, token ids are remapped to shard-local ids. The
    embedding rows of out-of-shard tokens are then zeroed. The per-rank
    results are summed with ``tensor_model_parallel_all_reduce``.

    Args:
        input_: Token ids.

    Returns:
        The all-reduced embedding output.
    """
    if self.tp_size > 1:
        # Build the mask.
        masked_input, input_mask = get_masked_input_and_mask(
            input_,
            self.shard_indices.org_vocab_start_index,
            self.shard_indices.org_vocab_end_index,
            self.shard_indices.num_org_vocab_padding,
            self.shard_indices.added_vocab_start_index,
            self.shard_indices.added_vocab_end_index,
        )
    else:
        masked_input = input_
    # Get the embeddings.
    output_parallel = self.quant_method.embedding(self, masked_input.long())
    # Mask the output embedding.
    if self.tp_size > 1:
        # input_mask is only bound when tp_size > 1, i.e. exactly here.
        output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)  # type: ignore
    # Reduce across all the model parallel GPUs.
    output = tensor_model_parallel_all_reduce(output_parallel)
    return output


def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Replacement for ``UnquantizedEmbeddingMethod.apply`` (linear projection).

    A 3-D ``layer.weight`` (the NUMA weight layout) is dispatched to
    ``torch_br.br_fused_mlp_infer``. Any other weight goes through
    ``F.linear`` with the SUPA SUBLAS API switched on for that call only.

    Args:
        layer: Layer owning ``weight``. ``output_size_per_partition`` is
            used when present, otherwise ``output_size``.
        x: Input activations.
        bias: Optional bias.

    Returns:
        The projected tensor.
    """
    # numa weight is 3-dims
    if len(layer.weight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        return torch_br.br_fused_mlp_infer(
            x, [layer.weight],
            output_w=output_size,
            bias=[bias] if bias is not None else None)
    supa_debug.set_enable_sublas_api(True)
    try:
        return F.linear(x, layer.weight, bias)
    finally:
        # Always switch SUBLAS back off, even if F.linear raises. Otherwise
        # the switch stays enabled for every subsequent op.
        supa_debug.set_enable_sublas_api(False)


UnquantizedEmbeddingMethod.apply = apply
VocabParallelEmbedding.forward = vocab_parallel_embedding_forward