### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| vllm_ascend/ops/\_\_init\_\_.py |
| vllm_ascend/ops/activation.py |
| vllm_ascend/ops/flashcomm2_oshard_manager.py |
| vllm_ascend/ops/layernorm.py |
| vllm_ascend/ops/mla.py |
| vllm_ascend/ops/mm_encoder_attention.py |
| vllm_ascend/ops/register_custom_ops.py |
| vllm_ascend/ops/vocab_parallel_embedding.py |
| vllm_ascend/ops/weight_prefetch.py |
| vllm_ascend/spec_decode/\_\_init\_\_.py |
| vllm_ascend/spec_decode/eagle_proposer.py |
| vllm_ascend/spec_decode/interface.py |
| vllm_ascend/spec_decode/mtp_proposer.py |
| vllm_ascend/spec_decode/ngram_proposer.py |
| vllm_ascend/spec_decode/suffix_proposer.py |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
@@ -19,15 +19,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.mla import (MLAModules,
|
||||
MultiHeadLatentAttentionWrapper)
|
||||
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||
@@ -36,20 +34,20 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("v0.15.0"):
|
||||
from vllm.attention.layer import MLAAttention # type: ignore
|
||||
from vllm.attention.layer import MLAAttention # type: ignore
|
||||
else:
|
||||
from vllm.model_executor.layers.attention import MLAAttention
|
||||
|
||||
|
||||
class IndexerWrapper(nn.Module):
|
||||
'''
|
||||
"""
|
||||
A wrapper of Indexer for Deepseek v3.2.
|
||||
This wrapper is currently used to solve the fp8 hard code issue of vllm's deepseek_v2.py.
|
||||
It wraps the original Indexer, inherits its module weights
|
||||
(including wq_b, wk, weights_proj, k_norm)
|
||||
while deletes the unused topk_indices_buffer and k_cache to save memory.
|
||||
while deletes the unused topk_indices_buffer and k_cache to save memory.
|
||||
TODO: Will be removed once original Indexer supports different quantization methods.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_indexer: nn.Module) -> None:
|
||||
super().__init__()
|
||||
@@ -71,7 +69,6 @@ class IndexerWrapper(nn.Module):
|
||||
|
||||
|
||||
class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -80,11 +77,11 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
mla_modules: MLAModules,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
@@ -97,8 +94,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
self.v_head_dim = v_head_dim
|
||||
self.prefix = prefix
|
||||
hf_config = get_current_vllm_config().model_config.hf_text_config
|
||||
self.enable_shared_expert_dp = get_ascend_config(
|
||||
).enable_shared_expert_dp
|
||||
self.enable_shared_expert_dp = get_ascend_config().enable_shared_expert_dp
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.layers = hf_config.num_hidden_layers
|
||||
if mla_modules.indexer is not None:
|
||||
@@ -134,6 +130,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
|
||||
def wrapped_process_weights(act_dtype: torch.dtype):
|
||||
from vllm_ascend.attention.sfa_v1 import AscendSFAImpl
|
||||
|
||||
if not isinstance(self.mla_attn.impl, AscendSFAImpl):
|
||||
original_process_weights(act_dtype)
|
||||
self.mla_attn.impl.process_weights_after_loading(act_dtype)
|
||||
@@ -146,19 +143,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor | None = None,
|
||||
attn_metadata: AttentionMetadata | None = None,
|
||||
) -> torch.Tensor:
|
||||
need_gather_q_kv = get_forward_context().sp_enabled
|
||||
output_shape = hidden_states.shape
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
|
||||
self.prefix)
|
||||
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix)
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
|
||||
@@ -176,9 +171,9 @@ def mla_forward(
|
||||
else:
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
||||
self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states,
|
||||
kv_cache, attn_metadata, need_gather_q_kv,
|
||||
output)
|
||||
self.mla_attn.impl.forward(
|
||||
self.mla_attn.layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv, output
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user