Files
xc-llm-ascend/vllm_ascend/ops/vocab_parallel_embedding.py
Zetong Li a86ece5e39 [Bugfix][LoRA] Fix forward error and shape mismatch when using LoRA (#3153)
### What this PR does / why we need it?
Relying on #3044, this PR aims to further fix:
1. The forward error occurred when `LogitsProcessorWithLoRA` calls
`AscendLogitsProcessor.forward`. Since `LogitsProcessorWithLoRA`
bypasses the MRO to call it, `super().forward(...)` in
`AscendLogitsProcessor.forward` will raise an error. This PR fixes it by
directly invoking `LogitsProcessor.forward(self, ...)`;
2. The shape mismatch in `add_lora_logits` in punica_npu.py. The
`lora_a_stacked` and `lora_b_stacked` are organized as [num_loras, 1,
lora_rank, hidden_size] and [num_loras, 1, vocab_size, lora_rank] shapes
respectively, but they are misunderstood in #1583---the last two
dimensions were assumed in reverse order, which causes errors in
`bgmv_shrink` and `bgmv_expand`. This PR fixes it by reverting it to the
previous version to align with the implementation in punica_cpu.py in
vllm.

### Dependencies
This PR depends on changes introduced by #3044 (LoRA support for
`AscendQKVParallelLinear` and `AscendMergedQKVParallelLinear` layers).

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
The LoRA-related tests, e.g., test_ilama_lora.py and
test_ilama_lora_tp2.py, use ilama-3.2-1B, and this model is regarded as
`TransformersForCausalLM`, where `embedding_modules` attribute lacks
`lm_head`. However, `LlamaForCausalLM` and most other models include
both `embed_tokens` and `lm_head` in `embedding_modules`. This attribute
contributes to `supported_lora_modules` when using LoRA in vllm.
Therefore, without `lm_head` in `embedding_modules`, current tests using
ilama-3.2-1B are unable to find the above errors since
`LogitsProcessorWithLoRA` replacing `lm_head` is skipped. Simply using
Meta-Llama-3.1-8B-Instruct can reproduce the above errors and check
whether these fixes can work. What's more, it's necessary to add more
comprehensive tests for LoRA.

- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

Signed-off-by: Zetong Li <slippersss@126.com>
2025-09-28 17:30:50 +08:00

269 lines
12 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import divide, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
VocabParallelEmbedding, pad_vocab_size)
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
from vllm_ascend.utils import lmhead_tp_enable
class AscendVocabParallelEmbedding(VocabParallelEmbedding):
    """Vocab-parallel embedding registered as a custom op for Ascend.

    The communication group is selectable: when ``lmhead_tp_enable()`` is
    true and ``prefix`` contains ``"lm_head"``, the layer is sharded over the
    LM-head tensor-parallel group (lmhead TP in the pure-DP scenario);
    otherwise the regular tensor-parallel group is used.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Set up sharding metadata and create the partitioned weights.

        Args:
            num_embeddings: Vocabulary size, including any added embeddings.
            embedding_dim: Size of each embedding vector.
            params_dtype: Weight dtype; defaults to
                ``torch.get_default_dtype()`` when ``None``.
            org_num_embeddings: Original vocabulary size (without added
                embeddings); defaults to ``num_embeddings`` when falsy.
            padding_size: Multiple to which vocabulary sizes are padded.
            quant_config: Optional quantization config used to pick the
                quantization method for this layer.
            prefix: Layer name prefix; also selects the communication group.

        Raises:
            NotImplementedError: If this is a plain embedding layer and the
                chosen quantization method does not implement ``embedding``.
        """
        # Only nn.Module is initialised here; every attribute is set up below
        # using the communication group selected for this layer.
        nn.Module.__init__(self)

        if lmhead_tp_enable() and "lm_head" in prefix:
            self.comm_group = get_lmhead_tp_group()
        else:
            self.comm_group = get_tp_group()

        self.tp_size = self.comm_group.world_size
        self.tp_rank = self.comm_group.rank_in_group

        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size,
                                               self.tp_rank, self.tp_size)
        self.embedding_dim = embedding_dim

        quant_method = None
        if quant_config is not None:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if quant_method is None:
            quant_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        # Inside this subclass type(self) is never the base
        # VocabParallelEmbedding, so the Ascend class must be matched too or
        # the check below is dead code.
        is_embedding_layer = type(self) in (VocabParallelEmbedding,
                                            AscendVocabParallelEmbedding)
        quant_method_implements_embedding = method_has_implemented_embedding(
            type(quant_method))
        if is_embedding_layer and not quant_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(quant_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

        self.quant_method: QuantizeMethodBase = quant_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = num_added_embeddings
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.quant_method.create_weights(self,
                                         self.embedding_dim,
                                         [self.num_embeddings_per_partition],
                                         self.embedding_dim,
                                         self.num_embeddings_padded,
                                         params_dtype=params_dtype,
                                         weight_loader=self.weight_loader)

    def _get_masked_input_and_mask(
            self, input_: torch.Tensor, org_vocab_start_index: int,
            org_vocab_end_index: int, num_org_vocab_padding: int,
            added_vocab_start_index: int,
            added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map global token ids to shard-local ids and flag foreign ids.

        Args:
            input_: Tensor of token ids.
            org_vocab_start_index: First original-vocab id of this shard.
            org_vocab_end_index: One past the last original-vocab id.
            num_org_vocab_padding: Padding rows after the original vocab.
            added_vocab_start_index: First added-vocab id of this shard.
            added_vocab_end_index: One past the last added-vocab id.

        Returns:
            A tuple ``(local_ids, invalid_mask)``. ``local_ids`` holds the
            shard-local index for ids inside this shard's ranges and 0 for
            all others; ``invalid_mask`` is True for ids outside the ranges.
        """
        # torch.compile will fuse all of the pointwise ops below
        # into a single kernel, making it very fast
        org_vocab_mask = (input_ >= org_vocab_start_index) & (
            input_ < org_vocab_end_index)
        # Adapt: avoid create added_vocab_mask when
        # added_vocab_start_index == added_vocab_end_index.
        if added_vocab_start_index == added_vocab_end_index:
            # Empty added-vocab range: only the original range can match.
            valid_offset = (org_vocab_start_index * org_vocab_mask)
            vocab_mask = org_vocab_mask
        else:
            added_vocab_mask = (input_ >= added_vocab_start_index) & (
                input_ < added_vocab_end_index)
            # Shifts added-vocab ids so they land right after the (padded)
            # original-vocab rows of this shard.
            added_offset = added_vocab_start_index - (
                org_vocab_end_index -
                org_vocab_start_index) - num_org_vocab_padding
            valid_offset = (org_vocab_start_index *
                            org_vocab_mask) + (added_offset * added_vocab_mask)
            vocab_mask = org_vocab_mask | added_vocab_mask
        # Adapt end.
        # Multiplying by the mask zeroes the index of every foreign id.
        input_ = vocab_mask * (input_ - valid_offset)
        return input_, ~vocab_mask

    def forward(self, input_):
        """Look up embeddings for ``input_`` and all-reduce the result.

        With more than one rank in the communication group, ids that do not
        belong to this shard are looked up at index 0 and their output rows
        are zeroed before the reduction.
        """
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = self._get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.quant_method.embedding(self,
                                                      masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        # NOTE(review): presumably this reduces over the default TP group,
        # while tp_size above comes from self.comm_group -- confirm the two
        # always coincide on the paths that call forward().
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output
class AscendParallelLMHead(ParallelLMHead):
    """Ascend custom-op counterpart of ParallelLMHead.

    Construction is delegated to ``AscendVocabParallelEmbedding.__init__``;
    an optional per-partition bias parameter is then attached.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the LM head.

        Args:
            num_embeddings: Vocabulary size, including added embeddings.
            embedding_dim: Hidden size of the model output.
            bias: Whether to create a bias parameter for this partition.
            params_dtype: Dtype passed to the embedding initialiser and used
                for the bias tensor.
            org_num_embeddings: Original vocabulary size.
            padding_size: Multiple to which vocabulary sizes are padded.
            quant_config: Optional quantization config; also stored on the
                instance as ``quant_config``.
            prefix: Layer name prefix.
        """
        # Reuse the Ascend embedding initialiser as a plain function so the
        # head gets the same sharding attributes and weights.
        AscendVocabParallelEmbedding.__init__(
            self,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            params_dtype=params_dtype,
            org_num_embeddings=org_num_embeddings,
            padding_size=padding_size,
            quant_config=quant_config,
            prefix=prefix,
        )
        self.quant_config = quant_config

        if not bias:
            # Keep a ``bias`` attribute that is explicitly None.
            self.register_parameter("bias", None)
            return

        bias_shape = self.num_embeddings_per_partition
        self.bias = Parameter(torch.empty(bias_shape, dtype=params_dtype))
        bias_attrs = {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        }
        set_weight_attrs(self.bias, bias_attrs)
class AscendLogitsProcessor(LogitsProcessor):
    """Ascend custom-op counterpart of LogitsProcessor.

    Adds an lmhead-TP path (for the pure-DP scenario) next to the regular
    logits computation; ``lmhead_tp_enable()`` selects between the two.
    """

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Dispatch to the lmhead-TP or the regular logits implementation."""
        compute = (self._get_logits_lmheadtp
                   if lmhead_tp_enable() else self._get_logits_normal)
        return compute(hidden_states, lm_head, embedding_bias)

    def _get_logits_lmheadtp(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Compute logits using the LM-head tensor-parallel group.

        Hidden states are all-gathered along dim 0 over the group, projected
        by this rank's LM-head partition, and the partial logits are then
        exchanged with ``all_to_all``. Returns ``None`` if the exchange
        yields ``None``; otherwise the logits cut to ``org_vocab_size``.
        """
        lmhead_group = get_lmhead_tp_group()

        # Collect the hidden states of every rank in the LM-head group.
        all_hidden = lmhead_group.all_gather(hidden_states, dim=0)
        partial_logits = lm_head.quant_method.apply(lm_head,
                                                    all_hidden,
                                                    bias=embedding_bias)

        # Exchange the per-rank logits across the group.
        full_logits = lmhead_group.all_to_all(partial_logits)
        if full_logits is None:
            return None

        # Drop the vocabulary padding columns (if any).
        return full_logits[..., :self.org_vocab_size]

    def _get_logits_normal(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Compute logits on the regular path.

        Projects the local hidden states with the LM head, gathers the result
        through ``self._gather_logits`` and, when that is not ``None``, cuts
        it to ``org_vocab_size``.
        """
        partial_logits = lm_head.quant_method.apply(lm_head,
                                                    hidden_states,
                                                    bias=embedding_bias)

        # Combine the per-rank logits.
        full_logits = self._gather_logits(partial_logits)
        if full_logits is None:
            return None

        # Drop the vocabulary padding columns (if any).
        return full_logits[..., :self.org_vocab_size]

    def forward(
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        # keep this for version compatibility
        sampling_metadata=None,  # type: ignore
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Run the base ``LogitsProcessor.forward`` for this instance.

        ``sampling_metadata`` is accepted but not forwarded.
        """
        # The base implementation is named explicitly rather than reached via
        # super(): LogitsProcessorWithLoRA calls this method bypassing the
        # MRO, in which case super().forward(...) raises an error.
        logits = LogitsProcessor.forward(self,
                                         lm_head,
                                         hidden_states,
                                         embedding_bias=embedding_bias)
        return logits