### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| vllm_ascend/ops/\_\_init\_\_.py |
| vllm_ascend/ops/activation.py |
| vllm_ascend/ops/flashcomm2_oshard_manager.py |
| vllm_ascend/ops/layernorm.py |
| vllm_ascend/ops/mla.py |
| vllm_ascend/ops/mm_encoder_attention.py |
| vllm_ascend/ops/register_custom_ops.py |
| vllm_ascend/ops/vocab_parallel_embedding.py |
| vllm_ascend/ops/weight_prefetch.py |
| vllm_ascend/spec_decode/\_\_init\_\_.py |
| vllm_ascend/spec_decode/eagle_proposer.py |
| vllm_ascend/spec_decode/interface.py |
| vllm_ascend/spec_decode/mtp_proposer.py |
| vllm_ascend/spec_decode/ngram_proposer.py |
| vllm_ascend/spec_decode/suffix_proposer.py |
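
At least in the diff shown below, the edits are mechanical style and typing modernizations rather than behavioral changes: `typing.Optional`/`typing.Tuple` annotations become PEP 604 / built-in forms (`X | None`, `tuple[...]`), and yapf-style aligned continuations are reflowed into one-argument-per-line calls with trailing commas. A minimal before/after sketch of the pattern (the `lookup_*` helpers are hypothetical, for illustration only):

```python
from typing import Optional, Tuple

import torch


# Before: typing.Optional/Tuple and yapf-style aligned continuation lines.
def lookup_old(ids: torch.Tensor,
               table: torch.Tensor,
               padding_idx: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    mask = ids.ne(padding_idx) if padding_idx is not None else torch.ones_like(ids, dtype=torch.bool)
    return table[ids], mask


# After: PEP 604 unions, built-in generics, one argument per line with trailing commas.
def lookup_new(
    ids: torch.Tensor,
    table: torch.Tensor,
    padding_idx: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    mask = ids.ne(padding_idx) if padding_idx is not None else torch.ones_like(ids, dtype=torch.bool)
    return table[ids], mask
```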
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
**vllm_ascend/ops/vocab_parallel_embedding.py**

```diff
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-from typing import Optional, Tuple
 
 import torch
 from torch import nn
@@ -24,14 +23,20 @@ from vllm.distributed import divide
 from vllm.distributed.parallel_state import get_tp_group
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+    QuantizationConfig,
+    QuantizeMethodBase,
+    method_has_implemented_embedding,
+)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
-    VocabParallelEmbedding, pad_vocab_size)
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding,
+    pad_vocab_size,
+)
 from vllm.model_executor.utils import set_weight_attrs
 
-from vllm_ascend.distributed.parallel_state import (get_embed_tp_group,
-                                                    get_lmhead_tp_group)
+from vllm_ascend.distributed.parallel_state import get_embed_tp_group, get_lmhead_tp_group
 from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable
 
 
@@ -42,14 +47,16 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
     Added the feature of lmheadTP in pure dp scenario
     """
 
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 params_dtype: Optional[torch.dtype] = None,
-                 org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        params_dtype: torch.dtype | None = None,
+        org_num_embeddings: int | None = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
         nn.Module.__init__(self)
         self.forward_type = None
         if lmhead_tp_enable() and "head" in prefix:
@@ -67,18 +74,20 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
-        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
-                                                    self.padding_size)
+        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
         self.num_embeddings_padded = pad_vocab_size(
-            self.org_vocab_size_padded + num_added_embeddings,
-            self.padding_size)
+            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+        )
         assert self.org_vocab_size_padded <= self.num_embeddings_padded
 
-        self.shard_indices = self._get_indices(self.num_embeddings_padded,
-                                               self.org_vocab_size_padded,
-                                               self.num_embeddings,
-                                               self.org_vocab_size,
-                                               self.tp_rank, self.tp_size)
+        self.shard_indices = self._get_indices(
+            self.num_embeddings_padded,
+            self.org_vocab_size_padded,
+            self.num_embeddings,
+            self.org_vocab_size,
+            self.tp_rank,
+            self.tp_size,
+        )
         self.embedding_dim = embedding_dim
         quant_method = None
         if quant_config is not None:
@@ -90,12 +99,12 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
         # method must implement the embedding operation. If we are another
         # layer type like ParallelLMHead, this is not important.
         is_embedding_layer = type(self) is VocabParallelEmbedding
-        quant_method_implements_embedding = method_has_implemented_embedding(
-            type(quant_method))
+        quant_method_implements_embedding = method_has_implemented_embedding(type(quant_method))
         if is_embedding_layer and not quant_method_implements_embedding:
             raise NotImplementedError(
                 f"The class {type(quant_method).__name__} must implement "
-                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+                "the 'embedding' method, see UnquantizedEmbeddingMethod."
+            )
 
         self.quant_method: QuantizeMethodBase = quant_method
 
@@ -104,46 +113,47 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
         self.params_dtype = params_dtype
         # Divide the weight matrix along the vocaburaly dimension.
         self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
-        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
-                                                   self.tp_size)
-        assert (self.shard_indices.num_elements_padded ==
-                self.num_embeddings_per_partition)
+        self.num_embeddings_per_partition = divide(self.num_embeddings_padded, self.tp_size)
+        assert self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
         self.num_org_embeddings_per_partition = (
-            self.shard_indices.org_vocab_end_index -
-            self.shard_indices.org_vocab_start_index)
+            self.shard_indices.org_vocab_end_index - self.shard_indices.org_vocab_start_index
+        )
         self.num_added_embeddings_per_partition = (
-            self.shard_indices.added_vocab_end_index -
-            self.shard_indices.added_vocab_start_index)
+            self.shard_indices.added_vocab_end_index - self.shard_indices.added_vocab_start_index
+        )
 
-        self.quant_method.create_weights(self,
-                                         self.embedding_dim,
-                                         [self.num_embeddings_per_partition],
-                                         self.embedding_dim,
-                                         self.num_embeddings_padded,
-                                         params_dtype=params_dtype,
-                                         weight_loader=self.weight_loader)
+        self.quant_method.create_weights(
+            self,
+            self.embedding_dim,
+            [self.num_embeddings_per_partition],
+            self.embedding_dim,
+            self.num_embeddings_padded,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader,
+        )
 
     def _get_masked_input_and_mask(
-            self, input_: torch.Tensor, org_vocab_start_index: int,
-            org_vocab_end_index: int, num_org_vocab_padding: int,
-            added_vocab_start_index: int,
-            added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        self,
+        input_: torch.Tensor,
+        org_vocab_start_index: int,
+        org_vocab_end_index: int,
+        num_org_vocab_padding: int,
+        added_vocab_start_index: int,
+        added_vocab_end_index: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # torch.compile will fuse all of the pointwise ops below
         # into a single kernel, making it very fast
-        org_vocab_mask = (input_ >= org_vocab_start_index) & (
-            input_ < org_vocab_end_index)
+        org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
         # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
         if added_vocab_start_index == added_vocab_end_index:
-            valid_offset = (org_vocab_start_index * org_vocab_mask)
+            valid_offset = org_vocab_start_index * org_vocab_mask
             vocab_mask = org_vocab_mask
         else:
-            added_vocab_mask = (input_ >= added_vocab_start_index) & (
-                input_ < added_vocab_end_index)
-            added_offset = added_vocab_start_index - (
-                org_vocab_end_index -
-                org_vocab_start_index) - num_org_vocab_padding
-            valid_offset = (org_vocab_start_index *
-                            org_vocab_mask) + (added_offset * added_vocab_mask)
+            added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
+            added_offset = (
+                added_vocab_start_index - (org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
+            )
+            valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask)
             vocab_mask = org_vocab_mask | added_vocab_mask
         # Adapt end.
         input_ = vocab_mask * (input_ - valid_offset)
@@ -158,14 +168,15 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
     def _forward_embed_tp(self, input_):
         complete_input = self.comm_group.all_gather(input_, dim=0)
         masked_input, input_mask = self._get_masked_input_and_mask(
-            complete_input, self.shard_indices.org_vocab_start_index,
+            complete_input,
+            self.shard_indices.org_vocab_start_index,
             self.shard_indices.org_vocab_end_index,
             self.shard_indices.num_org_vocab_padding,
             self.shard_indices.added_vocab_start_index,
-            self.shard_indices.added_vocab_end_index)
+            self.shard_indices.added_vocab_end_index,
+        )
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self,
-                                                      masked_input.long())
+        output_parallel = self.quant_method.embedding(self, masked_input.long())
         output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
         output = self.comm_group.reduce_scatter(output_parallel, dim=0)
         output = output.view(input_.shape[0], -1)
@@ -175,16 +186,17 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = self._get_masked_input_and_mask(
-                input_, self.shard_indices.org_vocab_start_index,
+                input_,
+                self.shard_indices.org_vocab_start_index,
                 self.shard_indices.org_vocab_end_index,
                 self.shard_indices.num_org_vocab_padding,
                 self.shard_indices.added_vocab_start_index,
-                self.shard_indices.added_vocab_end_index)
+                self.shard_indices.added_vocab_end_index,
+            )
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self,
-                                                      masked_input.long())
+        output_parallel = self.quant_method.embedding(self, masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
@@ -197,29 +209,31 @@ class AscendParallelLMHead(ParallelLMHead):
     """
     Register ParallelLMHead as a custom op for Ascend."""
 
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 bias: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        AscendVocabParallelEmbedding.__init__(self, num_embeddings,
-                                              embedding_dim, params_dtype,
-                                              org_num_embeddings, padding_size,
-                                              quant_config, prefix)
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        bias: bool = False,
+        params_dtype: torch.dtype | None = None,
+        org_num_embeddings: int | None = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        AscendVocabParallelEmbedding.__init__(
+            self, num_embeddings, embedding_dim, params_dtype, org_num_embeddings, padding_size, quant_config, prefix
+        )
 
         self.quant_config = quant_config
         if bias:
-            self.bias = Parameter(
-                torch.empty(self.num_embeddings_per_partition,
-                            dtype=params_dtype))
-            set_weight_attrs(self.bias, {
-                "output_dim": 0,
-                "weight_loader": self.weight_loader,
-            })
+            self.bias = Parameter(torch.empty(self.num_embeddings_per_partition, dtype=params_dtype))
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": 0,
+                    "weight_loader": self.weight_loader,
+                },
+            )
         else:
             self.register_parameter("bias", None)
 
@@ -234,48 +248,41 @@ class AscendLogitsProcessor(LogitsProcessor):
         self,
         hidden_states: torch.Tensor,
         lm_head: AscendParallelLMHead,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[torch.Tensor]:
+        embedding_bias: torch.Tensor | None = None,
+    ) -> torch.Tensor | None:
         if lmhead_tp_enable():
-            return self._get_logits_lmheadtp(hidden_states, lm_head,
-                                             embedding_bias)
+            return self._get_logits_lmheadtp(hidden_states, lm_head, embedding_bias)
         else:
-            return self._get_logits_normal(hidden_states, lm_head,
-                                           embedding_bias)
+            return self._get_logits_normal(hidden_states, lm_head, embedding_bias)
 
     def _get_logits_lmheadtp(
         self,
         hidden_states: torch.Tensor,
         lm_head: AscendParallelLMHead,
-        embedding_bias: Optional[torch.Tensor],
-    ) -> Optional[torch.Tensor]:
+        embedding_bias: torch.Tensor | None,
+    ) -> torch.Tensor | None:
         # Gather hidden states from all devices in tensor parallel group
-        gathered_hidden_states = get_lmhead_tp_group().all_gather(
-            hidden_states, dim=0)
-        local_logits = lm_head.quant_method.apply(lm_head,
-                                                  gathered_hidden_states,
-                                                  bias=embedding_bias)
+        gathered_hidden_states = get_lmhead_tp_group().all_gather(hidden_states, dim=0)
+        local_logits = lm_head.quant_method.apply(lm_head, gathered_hidden_states, bias=embedding_bias)
         # Gather logits for tensor parallel
         logits = get_lmhead_tp_group().all_to_all(local_logits)
         # Remove paddings in vocab (if any)
         if logits is not None:
-            logits = logits[..., :self.org_vocab_size]
+            logits = logits[..., : self.org_vocab_size]
         return logits
 
     def _get_logits_normal(
         self,
         hidden_states: torch.Tensor,
         lm_head: AscendParallelLMHead,
-        embedding_bias: Optional[torch.Tensor],
-    ) -> Optional[torch.Tensor]:
-        local_logits = lm_head.quant_method.apply(lm_head,
-                                                  hidden_states,
-                                                  bias=embedding_bias)
+        embedding_bias: torch.Tensor | None,
+    ) -> torch.Tensor | None:
+        local_logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
         # Gather logits for tensor parallel
         logits = self._gather_logits(local_logits)
 
         # Remove paddings in vocab (if any)
         if logits is not None:
-            logits = logits[..., :self.org_vocab_size]
+            logits = logits[..., : self.org_vocab_size]
 
         return logits
```
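
For reference, the `_get_masked_input_and_mask` hunk above only reformats (it does not change) the logic that maps global token IDs onto a rank-local vocabulary shard: IDs inside this rank's original or added vocab ranges are shifted to local indices, and everything else is zeroed and flagged in a mask so its embedding rows can be blanked before the cross-rank reduction. A standalone, simplified sketch of that computation (`shard_token_ids` is a hypothetical name, not the project's API):

```python
import torch


def shard_token_ids(
    ids: torch.Tensor,
    org_start: int,
    org_end: int,
    num_org_padding: int,
    added_start: int,
    added_end: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Tokens that fall into this rank's slice of the original vocabulary.
    org_mask = (ids >= org_start) & (ids < org_end)
    if added_start == added_end:
        # No added vocabulary on this rank: only the original slice matters.
        offset = org_start * org_mask
        valid = org_mask
    else:
        # Tokens that fall into this rank's slice of the added vocabulary.
        added_mask = (ids >= added_start) & (ids < added_end)
        # Added tokens land right after the (padded) original slice in the local table.
        added_offset = added_start - (org_end - org_start) - num_org_padding
        offset = org_start * org_mask + added_offset * added_mask
        valid = org_mask | added_mask
    local_ids = valid * (ids - offset)  # out-of-shard tokens map to index 0
    return local_ids, ~valid            # ~valid marks rows to zero after the lookup


ids = torch.tensor([3, 50, 130, 999])
local_ids, to_zero = shard_token_ids(
    ids, org_start=0, org_end=64, num_org_padding=0, added_start=128, added_end=160
)
print(local_ids)  # tensor([ 3, 50, 66,  0])
print(to_zero)    # tensor([False, False, False,  True])
```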