From 1a1f9a6d894a3947fcff4c5c52fa0846c35e5759 Mon Sep 17 00:00:00 2001 From: Pleaplusone <38376071+ganyi1996ppo@users.noreply.github.com> Date: Sat, 19 Apr 2025 17:38:18 +0800 Subject: [PATCH] port deepseekv2 and mtp to main branch (#429) ### What this PR does / why we need it? This PR ports all the deepseek graph mode code and mtp code from v0.7.3 to the main branch --------- Signed-off-by: SidaoY <1024863041@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: Yizhou Liu Signed-off-by: mengwei805 Signed-off-by: libaokui Signed-off-by: q00832892 Signed-off-by: ganyi Co-authored-by: SidaoY <1024863041@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: Yizhou Liu Co-authored-by: mengwei805 Co-authored-by: libaokui --- examples/disaggregated_prefill_hccl.py | 128 ++++ examples/dp_offline/data_parallel.py | 86 +++ examples/dp_offline/run_dp.sh | 21 + examples/offline_inference_npu_v1.py | 49 ++ tests/conftest.py | 11 +- vllm_ascend/attention/attention.py | 325 +++++++--- vllm_ascend/attention/attention_v1.py | 107 +++- vllm_ascend/attention/mla_v1.py | 561 ++++++++++++++++++ .../distributed/llmdatadist_connector.py | 2 +- vllm_ascend/distributed/parallel_state.py | 75 +++ vllm_ascend/models/__init__.py | 7 + vllm_ascend/models/deepseek_mtp.py | 173 ++++++ vllm_ascend/models/deepseek_v2.py | 324 +++++++++- vllm_ascend/ops/__init__.py | 1 + vllm_ascend/ops/attention.py | 293 +++++++++ vllm_ascend/ops/cache.py | 35 ++ vllm_ascend/ops/fused_moe.py | 476 +++++++++++++-- vllm_ascend/ops/rotary_embedding.py | 169 +++++- vllm_ascend/ops/vocab_parallel_embedding.py | 67 +++ vllm_ascend/patch/__init__.py | 16 + .../patch/platform/patch_0_8_4/__init__.py | 3 +- .../platform/patch_0_8_4/patch_distributed.py | 138 +++++ .../patch/platform/patch_main/__init__.py | 3 +- .../platform/patch_main/patch_distributed.py | 138 +++++ vllm_ascend/platform.py | 29 + vllm_ascend/quantization/quant_config.py | 74 +-- vllm_ascend/utils.py | 4 + 
vllm_ascend/worker/__init__.py | 17 + vllm_ascend/worker/cache_engine.py | 69 +++ vllm_ascend/worker/model_runner.py | 183 ++++-- vllm_ascend/worker/model_runner_v1.py | 62 +- vllm_ascend/worker/worker.py | 20 +- vllm_ascend/worker/worker_v1.py | 10 + 33 files changed, 3361 insertions(+), 315 deletions(-) create mode 100644 examples/disaggregated_prefill_hccl.py create mode 100644 examples/dp_offline/data_parallel.py create mode 100644 examples/dp_offline/run_dp.sh create mode 100644 examples/offline_inference_npu_v1.py create mode 100644 vllm_ascend/attention/mla_v1.py create mode 100644 vllm_ascend/distributed/parallel_state.py create mode 100644 vllm_ascend/models/deepseek_mtp.py create mode 100644 vllm_ascend/ops/attention.py create mode 100644 vllm_ascend/ops/cache.py create mode 100644 vllm_ascend/ops/vocab_parallel_embedding.py create mode 100644 vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py create mode 100644 vllm_ascend/patch/platform/patch_main/patch_distributed.py create mode 100644 vllm_ascend/worker/cache_engine.py diff --git a/examples/disaggregated_prefill_hccl.py b/examples/disaggregated_prefill_hccl.py new file mode 100644 index 0000000..ab82abc --- /dev/null +++ b/examples/disaggregated_prefill_hccl.py @@ -0,0 +1,128 @@ +""" + This file demonstrates the example usage of disaggregated prefilling + We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode), + and then transfer the KV cache between them. 
+ """ +import multiprocessing as mp +import os +import time +from multiprocessing import Event, Process + + +def clean_up(): + import gc + + import torch + from vllm.distributed.parallel_state import ( + destroy_distributed_environment, destroy_model_parallel) + destroy_model_parallel() + destroy_distributed_environment() + gc.collect() + torch.npu.empty_cache() + + +def run_prefill(prefill_done, process_close): + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" + + from vllm import LLM, SamplingParams + from vllm.config import KVTransferConfig + + prompts = [ + "Hello, how are you today?", "Hi, what is your name?", + "Tell me a very long story.", "what is your favourite book?" + ] + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}' + ) + + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # memory. You may need to adjust the value to fit your GPU. + llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8, + tensor_parallel_size=2) + + llm.generate(prompts, sampling_params) + print("Prefill node is finished.") + prefill_done.set() + + # To keep the prefill node running in case the decode node is not done; + # otherwise, the script might exit prematurely, causing incomplete decoding. + try: + while not process_close.is_set(): + time.sleep(1) + except KeyboardInterrupt: + print("Script stopped by user.") + finally: + print("Cleanup prefill resources") + del llm + clean_up() + + +def run_decode(prefill_done): + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3" + + from vllm import LLM, SamplingParams + from vllm.config import KVTransferConfig + + prompts = [ + "Hello, how are you today?", "Hi, what is your name?", + "Tell me a very long story.", "what is your favourite book?" 
+ ] + sampling_params = SamplingParams(temperature=0, top_p=0.95) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}' + ) + + llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8, + tensor_parallel_size=2) + + # Wait for the producer to start the comsumer + print("Waiting for prefill node to finish...") + prefill_done.wait() + + # At this point when the prefill_done is set, the kv-cache should have been + # transferred to this decode node, so we can start decoding. + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + del llm + clean_up() + + +if __name__ == "__main__": + mp.get_context('spawn') + + prefill_done = Event() + process_close = Event() + prefill_process = Process(target=run_prefill, + args=( + prefill_done, + process_close, + )) + decode_process = Process(target=run_decode, args=(prefill_done, )) + + # Start prefill node + prefill_process.start() + + # Start decode node + decode_process.start() + + # Terminate the prefill node when decode is finished + decode_process.join() + + # Terminate prefill process + process_close.set() + prefill_process.join() + prefill_process.terminate() + print("All process done!") diff --git a/examples/dp_offline/data_parallel.py b/examples/dp_offline/data_parallel.py new file mode 100644 index 0000000..ae5b104 --- /dev/null +++ b/examples/dp_offline/data_parallel.py @@ -0,0 +1,86 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py +# SPDX-License-Identifier: Apache-2.0 +# usage: +# python examples/offline_inference_data_parallel.py +# we need to have a launcher to create multiple data parallel +# ranks. And each rank will create a vLLM instance to process its own prompts. + +import gc +import os + +VLLM_ENABLE_GRAPGH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1" + + +def main(): + dp_rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + dp_size = int(os.environ['WORLD_SIZE']) + master_addr = os.environ['MASTER_ADDR'] + master_port = os.environ['MASTER_PORT'] + tp_size = 4 + etp_size = 2 + + os.environ["VLLM_DP_RANK"] = str(dp_rank) + os.environ["VLLM_DP_SIZE"] = str(dp_size) + os.environ["VLLM_DP_MASTER_IP"] = master_addr + os.environ["VLLM_DP_MASTER_PORT"] = master_port + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join( + str(i) + for i in range(local_rank * tp_size, (local_rank + 1) * tp_size)) + + import torch + import torch_npu # noqa + from vllm import LLM, SamplingParams + from vllm.distributed.parallel_state import ( + destroy_distributed_environment, destroy_model_parallel) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 4 + + promts_per_rank = len(prompts) // dp_size + start = dp_rank * promts_per_rank + end = start + promts_per_rank + prompts = prompts[start:end] + if len(prompts) == 0: + prompts = ["Placeholder"] + print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts") + num_seqs = len(prompts) + + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=4, + min_tokens=4) + # Create an LLM. 
+ llm = LLM( + model="deepseek-ai/DeepSeek-V2-Lite-Chat", + tensor_parallel_size=tp_size, + trust_remote_code=True, + expert_tensor_parallel_size=etp_size, + max_model_len=4096, + max_num_seqs=num_seqs, + compilation_config=1 if VLLM_ENABLE_GRAPGH_MODE else 0, + ) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"DP rank {dp_rank}, Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") + + del llm + destroy_model_parallel() + destroy_distributed_environment() + gc.collect() + torch.npu.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/examples/dp_offline/run_dp.sh b/examples/dp_offline/run_dp.sh new file mode 100644 index 0000000..0e525f4 --- /dev/null +++ b/examples/dp_offline/run_dp.sh @@ -0,0 +1,21 @@ +export HCCL_IF_IP=${local_ip} +export GLOO_SOCKET_IFNAME=${ifname} +export TP_SOCKET_IFNAME=${ifname} +export HCCL_SOCKET_IFNAME=${ifname} + +# dp_size = node_size * dp_per_node +node_size=1 +node_rank=0 +dp_per_node=2 +master_addr=127.0.0.1 +master_port=12345 + +rm -rf ./.torchair_cache/ +rm -rf ./dynamo_* +rm -rf /root/ascend/log/debug/plog/* +export VLLM_ENABLE_GRAPH_MODE=0 +export VLLM_ENABLE_MC2=0 + +torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \ + --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \ + data_parallel.py diff --git a/examples/offline_inference_npu_v1.py b/examples/offline_inference_npu_v1.py new file mode 100644 index 0000000..939d84b --- /dev/null +++ b/examples/offline_inference_npu_v1.py @@ -0,0 +1,49 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/examples/offline_inference/basic.py +# Copyright 2023 The vLLM team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +from vllm import LLM, SamplingParams + +os.environ["VLLM_USE_V1"] = "1" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +if __name__ == "__main__": + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Create a sampling params object. + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + # Create an LLM. + llm = LLM(model="/data/weights/deepseek-ai/deepseekv3-lite-base-latest", + tensor_parallel_size=2, + enforce_eager=True, + trust_remote_code=True, + max_model_len=1024) + + # Generate texts from the prompts. 
+ outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/conftest.py b/tests/conftest.py index 48430a7..606ff83 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,8 +26,6 @@ import torch from PIL import Image from vllm import LLM, SamplingParams from vllm.config import TaskOption -from vllm.distributed.parallel_state import (destroy_distributed_environment, - destroy_model_parallel) from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams @@ -35,6 +33,15 @@ from vllm.utils import is_list_of from tests.model_utils import (TokensTextLogprobs, TokensTextLogprobsPromptLogprobs) +# TODO: remove this part after the patch merged into vllm, if +# we not explicitly patch here, some of them might be effectiveless +# in pytest scenario +from vllm_ascend.utils import adapt_patch # noqa E402 + +adapt_patch(True) + +from vllm.distributed.parallel_state import ( # noqa E402 + destroy_distributed_environment, destroy_model_parallel) _M = TypeVar("_M") diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index 7a2d179..f3acf08 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -31,13 +31,14 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionMetadata, AttentionType, MLAAttentionImpl) -from vllm.attention.backends.utils import (CommonAttentionState, +from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, CommonMetadataBuilder, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE from 
vllm_ascend.worker.model_runner import ( ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) @@ -222,7 +223,7 @@ class AscendMLAAttentionBackend(AscendAttentionBackend): num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return (1, num_blocks, block_size, num_kv_heads * head_size) + return (num_blocks, block_size, num_kv_heads, head_size) @dataclass @@ -552,10 +553,33 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): inter_data.block_tables, ) + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables # [:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + def build( self, seq_lens: List[int], query_lens: List[int], + graph_pad_size: int, ): """Build attention metadata with on-device tensors. 
@@ -568,6 +592,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): self.input_builder.chunked_prefill_enabled) device = self.runner.device + use_torchair_graph = graph_pad_size != -1 max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) @@ -582,12 +607,36 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): self.attn_mask = None num_decode_tokens = self.num_decode_tokens + if self.num_prefills == 0 and use_torchair_graph: + num_seqs = len(seq_lens) + self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size) + self.block_tables.extend([[]] * graph_pad_size) + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int32, + device=device, + ) + + if self.num_prefills > 0: + self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore + max_prefill_seq_len, + self.input_builder.runner.model_config.dtype, + self.input_builder.runner.device) + else: + self.attn_mask = None + num_decode_tokens = self.num_decode_tokens + block_tables = make_tensor_with_pad( self.block_tables, pad=0, dtype=torch.int32, device=device, ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) assert device is not None @@ -855,14 +904,100 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): self.q_proj = extra_impl_args['q_proj'] self.kv_b_proj = extra_impl_args['kv_b_proj'] self.o_proj = extra_impl_args['o_proj'] + self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa', + None) + self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None) + self.k_pe_cache = None + self.k_nope_cache = None self.w_kc = None self.w_vc = None + def exec_kv( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + B = hidden_states.shape[0] + N = self.num_kv_heads + S = 1 + kv = 
self.kv_a_proj_with_mqa(hidden_states)[0] + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + + k_pe, k_nope = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache( + kv, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode="PA", + ) + + return k_pe, k_nope + + def apply_rotary_emb( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + ) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + def rope_single( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + B, N, D = x.shape + S = 1 + x = x.view(B, N, S, D) + x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin) + return x.view(B, N, D) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + if self.w_kc is None or self.w_vc is None: + kv_b_proj_weight = self.kv_b_proj.weight.reshape( + self.num_heads, self.qk_nope_head_dim + self.v_head_dim, + self.kv_lora_rank) + self.w_kc = kv_b_proj_weight[:, :self. 
+ qk_nope_head_dim, :].contiguous() + self.w_vc = kv_b_proj_weight[:, + self.qk_nope_head_dim:, :].transpose( + 1, 2).contiguous() + def forward( self, layer: AttentionLayer, hidden_states_or_q_c: torch.Tensor, - kv_c_normed: torch.Tensor, + hidden_states_or_kv_c_normed: torch.Tensor, k_pe: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AscendMetadata, @@ -873,7 +1008,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): Args: hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size] num_tokens = batch_size * seq_len - kv_c_normed: shape = [num_tokens, num_kv_heads * head_size] + hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size] k_pe: shape = [num_tokens, num_kv_heads * head_size] kv_cache: shape = [1, num_blocks, block_size, num_kv_heads * head_size] @@ -889,71 +1024,86 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): "are not implemented for " "PallasAttentionBackendImpl") + if attn_metadata is None: + # for profile run + return hidden_states_or_q_c + num_tokens = hidden_states_or_q_c.shape[0] q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads, self.qk_head_dim) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + if k_pe is None and attn_metadata.decode_metadata: + seq_len = self.rotary_emb.max_position_embeddings - k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1) + cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype) + sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype) + cos = cos[attn_metadata.input_positions] + sin = sin[attn_metadata.input_positions] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] - if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding': - ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape - q_pe = q_pe.reshape(num_tokens, -1) - k_pe = k_pe.reshape(num_tokens, -1) - q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, - k_pe) - q_pe = q_pe.view(ori_q_pe_shape) - k_pe = 
k_pe.view(ori_k_pe_shape) + q_pe = self.rope_single(q_pe, cos, sin) + k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin, + kv_cache, attn_metadata.slot_mapping) else: - q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, - k_pe) - - if self.w_kc is None or self.w_vc is None: - kv_b_proj_weight = self.kv_b_proj.weight.reshape( - self.num_heads, self.qk_nope_head_dim + self.v_head_dim, - self.kv_lora_rank) - self.w_kc = kv_b_proj_weight[:, :self. - qk_nope_head_dim, :].contiguous() - self.w_vc = kv_b_proj_weight[:, - self.qk_nope_head_dim:, :].transpose( - 1, 2).contiguous() + if k_pe is None: + # NOTE: k_pe is None when graph mode enabled + kv_c, k_pe = self.kv_a_proj_with_mqa( + hidden_states_or_kv_c_normed)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + else: + kv_c_normed = hidden_states_or_kv_c_normed + k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1) + if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding': + # NOTE: When scaling not specified + ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape + q_pe = q_pe.reshape(num_tokens, -1) + k_pe = k_pe.reshape(num_tokens, -1) + q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, + q_pe, k_pe) + q_pe = q_pe.view(ori_q_pe_shape) + k_pe = k_pe.view(ori_k_pe_shape) + else: + q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, + q_pe, k_pe) if attn_metadata.num_prefills > 0: - kv_heads_num = self.num_heads - kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, kv_heads_num, - -1) + kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, + self.num_heads, -1) k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_cache = torch.cat( - [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe], - dim=2) - k_pe = k_pe.expand(-1, self.num_heads, -1) - key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe], - dim=2) else: - kv_heads_num = self.num_kv_heads q_nope_t 
= torch.transpose(q_nope, 0, 1) q_nope_out = torch.bmm(q_nope_t, self.w_kc) q_nope = torch.transpose(q_nope_out, 0, 1) - k_cache = torch.cat( - [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe], - dim=2) query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens, self.num_heads, -1) - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - num_blocks, block_size, _ = key_cache.shape - - key_cache = key_cache.view( - num_blocks, block_size, self.num_kv_heads, - self.qk_rope_head_dim + self.kv_lora_rank) - slots = attn_metadata.slot_mapping - torch_npu._npu_reshape_and_cache_siso(key=k_cache, - key_cache=key_cache, - slot_indices=slots) + # TODO: Replace the env with more flexible expressions + if VLLM_ENABLE_GRAPH_MODE == '1': + if len(kv_cache) > 0 and kv_cache[0].numel( + ) > 0 and attn_metadata.num_prefills > 0: + slots = attn_metadata.slot_mapping + # NOTE: Seperate the kv cache in advance to avoid OOM or other issues + torch_npu._npu_reshape_and_cache(key=kv_c_normed.view( + num_tokens, self.num_kv_heads, -1), + value=k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=slots) + else: + if kv_cache.numel() > 0: + key = torch.cat([ + kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe + ], + dim=2) + slots = attn_metadata.slot_mapping + torch_npu._npu_reshape_and_cache_siso(key=key, + key_cache=kv_cache, + slot_indices=slots) if attn_metadata.num_prefills > 0: attn_output = torch.empty(num_tokens, @@ -964,12 +1114,15 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): if (attn_metadata.block_tables is None or attn_metadata.block_tables.numel() == 0): assert attn_metadata.attn_mask is not None - mask = attn_metadata.attn_mask assert attn_metadata.prefill_metadata is not None assert attn_metadata.prefill_metadata.seq_lens is not None + mask = attn_metadata.attn_mask self.seq_lens_tensor_cpu = torch.from_numpy( np.array(attn_metadata.prefill_metadata.seq_lens).astype( np.int32)) + k_pe = k_pe.repeat(1, self.num_heads, 1) + key 
= torch.cat( + [k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2) torch_npu._npu_flash_attention( query=query, key=key, @@ -987,29 +1140,55 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): ) elif attn_metadata.decode_metadata: assert kv_cache is not None - # if torch.empty is used here, the preemptive scheduling case of - # test_mtp_correctness.py will fail to run. - attn_output = torch.randn( - [num_tokens, self.num_heads, self.kv_lora_rank], - dtype=query.dtype, - device=query.device) - self.seq_lens_tensor_cpu = torch.from_numpy( - np.array(attn_metadata.decode_metadata.seq_lens).astype( - np.int32)) - block_tables = attn_metadata.decode_metadata.block_tables - torch_npu._npu_paged_attention_mla( - query=query, - key_cache=key_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=block_tables, - context_lens=self.seq_lens_tensor_cpu, - mla_vheadsize=self.kv_lora_rank, - out=attn_output) - attn_output_t = torch.transpose(attn_output, 0, 1) - attn_output_t = torch.bmm(attn_output_t, self.w_vc) - attn_output = torch.transpose(attn_output_t, 0, 1) + if VLLM_ENABLE_GRAPH_MODE == '1': + # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim] + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=self.num_heads, + num_key_value_heads=1, + input_layout="BNSD", + atten_mask=attn_metadata.attn_mask, + scale=self.scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=attn_metadata.block_tables, + block_size=kv_cache[0].shape[1], + actual_seq_lengths_kv=attn_metadata.seq_lens, + ) + attn_output = attn_output.view(num_tokens, -1, + self.kv_lora_rank).transpose( + 0, 1) + attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1) + else: + # if torch.empty is used here, the preemptive 
scheduling case of + # test_mtp_correctness.py will fail to run. + attn_output = torch.randn( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=query.dtype, + device=query.device) + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array(attn_metadata.decode_metadata.seq_lens).astype( + np.int32)) + block_tables = attn_metadata.decode_metadata.block_tables + torch_npu._npu_paged_attention_mla( + query=query, + key_cache=kv_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=block_tables, + context_lens=self.seq_lens_tensor_cpu, + mla_vheadsize=self.kv_lora_rank, + out=attn_output) + attn_output_t = torch.transpose(attn_output, 0, 1) + attn_output_t = torch.bmm(attn_output_t, self.w_vc) + attn_output = torch.transpose(attn_output_t, 0, 1) output, _ = self.o_proj(attn_output.reshape(num_tokens, -1)) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 22e6b35..39878ce 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -24,6 +24,10 @@ import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.worker.gpu_input_batch import InputBatch + +from vllm_ascend.ops.attention import vanilla_chunked_prefill class AscendAttentionBackend(AttentionBackend): @@ -44,6 +48,10 @@ class AscendAttentionBackend(AttentionBackend): def get_state_cls() -> Type["CommonAttentionState"]: return CommonAttentionState + @staticmethod + def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: + return AscendAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -94,11 +102,11 @@ class AscendAttentionState(Enum): class AscendMetadata: # (batch_size, max_blocks_per_seq). # Block addresses per sequence. 
(Seq id -> list of physical block) - block_tables: Optional[torch.Tensor] + block_tables: torch.Tensor # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] = None - context_lens: Optional[List[int]] = None + query_lens: torch.Tensor + seq_lens: torch.Tensor # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None # (num_tokens,). The indices of the token slots that input tokens will be @@ -117,6 +125,36 @@ class AscendMetadata: attn_mask: Optional[torch.Tensor] = None +class AscendAttentionMetadataBuilder: + + def __init__(self, runner): + self.runner = runner + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False + + def build(self, num_reqs, num_actual_tokens, max_query_len, + common_prefix_len): + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + query_lens = self.runner.query_lens + seq_lens = self.runner.seq_lens_cpu[:num_reqs] + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + self.runner.device, non_blocking=True) + attn_mask = self.runner.attn_mask + attn_state = self.runner.attn_state + + attn_metadata = AscendMetadata(block_tables=block_table, + query_lens=query_lens, + seq_lens=seq_lens, + max_query_len=max_query_len, + slot_mapping=slot_mapping, + attn_mask=attn_mask, + attn_state=attn_state) + return attn_metadata + + class AscendAttentionBackendImpl(AttentionImpl): def __init__( @@ -229,29 +267,46 @@ class AscendAttentionBackendImpl(AttentionImpl): out=output) elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: block_tables = attn_metadata.block_tables - torch_npu._npu_paged_attention( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=block_tables, 
- context_lens=attn_metadata.context_lens, - out=output) + torch_npu._npu_paged_attention(query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=block_tables, + context_lens=attn_metadata.seq_lens, + out=output) # Normal V1 situation. else: - # use paged attention - torch_npu._npu_paged_attention_splitfuse( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - mask=attn_metadata.attn_mask, - block_table=attn_metadata.block_tables, - seq_len=attn_metadata.seq_lens, - context_lens=attn_metadata.context_lens, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - out=output) + # use chunked prefill for head size 192 scenario, like deepseek + # paged_attention_splitfuse maybe crash at such scenario + # TODO: vanilla path will be removed after the kernel support + # head_size 192 scenario + if self.head_size == 192: + cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() + cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist() + cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu") + cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu") + cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) + cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0) + max_seqlen_q = torch.max(attn_metadata.query_lens) + max_seqlen_k = torch.max(attn_metadata.seq_lens) + vanilla_chunked_prefill(output, query, self.key_cache, + self.value_cache, + attn_metadata.block_tables, + cu_seqlen_q, cu_seqlen_k, max_seqlen_q, + max_seqlen_k, self.scale, None, True) + else: + torch_npu._npu_paged_attention_splitfuse( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + mask=attn_metadata.attn_mask, + block_table=attn_metadata.block_tables, + seq_len=attn_metadata.query_lens, + context_lens=attn_metadata.seq_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) return 
output.view(num_tokens, self.hidden_size) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py new file mode 100644 index 0000000..e5b7e73 --- /dev/null +++ b/vllm_ascend/attention/mla_v1.py @@ -0,0 +1,561 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar + +import torch +import torch_npu +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + MLAAttentionImpl) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla +from vllm_ascend.ops.cache import concat_and_cache_mla +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +logger = init_logger(__name__) + + +class AscendMLABackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "VLLM_ASCEND_MLA" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendMLAMetadata + + @staticmethod + def get_builder_cls(): + return AscendMLAMetadataBuilder + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["MLAAttentionImpl"]: + return AscendMLAImpl + + +@dataclass +class AscendMLAPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + attn_mask: torch.Tensor + query_lens: list[int] + context_lens: torch.Tensor + input_positions: torch.Tensor + block_table: 
torch.Tensor
+    max_query_len: int
+    max_context_len: int
+
+
+@dataclass
+class AscendMLADecodeMetadata:
+    # Input positions for rotary embeddings since for MLA the rotary
+    # position embeddings are applied inside the attention backend
+    input_positions: torch.Tensor
+    block_table: torch.Tensor
+    seq_lens: torch.Tensor
+
+
+@dataclass
+class AscendMLAMetadata:
+    """Metadata for MLACommon.
+
+    NOTE: Please read the comment at the top of the file before trying to
+    understand this class
+    """
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seq_len ---------------------|
+    #                                   |-- query_len ---|
+
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    slot_mapping: torch.Tensor
+
+    # New for MLA (compared to FlashAttention)
+    # For handling prefill decode split
+    num_decodes: int
+    num_decode_tokens: int
+    num_prefills: int
+
+    # For logging.
+    num_input_tokens: int = 0  # Number of tokens including padding.
+ + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendMLADecodeMetadata] = None + prefill: Optional[AscendMLAPrefillMetadata] = None + + def __post_init__(self): + pass + # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + +M = TypeVar("M", bound=AscendMLAMetadata) + + +class AscendMLAMetadataBuilder: + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + runner: "NPUModelRunner", + metadata_cls: Optional[AscendMLAMetadata] = None): + self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \ + if metadata_cls is not None else AscendMLAMetadata # type: ignore + self.runner = runner + scheduler_config = runner.scheduler_config + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + # self.attn_mask = None + # if AscendMLAMetadataBuilder._attn_mask_builder is None: + # AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len( + # 128, self.runner.model_config.dtype + # ) + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests
+        # where attention is likely memory-bound and "prefill" to mean requests
+        # where attention is likely compute-bound, TODO(lucas): figure out a
+        # better naming here)
+        decodes = []
+        prefills = []
+        num_decode_tokens = 0
+        num_prefill_tokens = 0
+
+        for i, req_id in enumerate(input_batch.req_ids):
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            # for now treat 1 scheduled token as "decode" even if it's not,
+            # we should update this to something like < 8 in the future but
+            # currently the TritonMLA._forward_decode only supports
+            # num_tokens = 1
+            if num_tokens == 1:
+                decodes.append(i)
+                num_decode_tokens += num_tokens
+            else:
+                prefills.append(i)
+                num_prefill_tokens += num_tokens
+
+        # We hope that this is fairly minimal since decodes
+        # should be around for a number of iterations so hopefully they are
+        # relatively stationary (and new request are generally appended to the
+        # persistent batch so already should be at the back)
+        # To achieve this we loop over the decodes in descending order and
+        # the prefills in ascending order. We swap decodes from the "back"
+        # i.e. past where the last decode should be in the reordered with
+        # prefills from the front of the batch.
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_prefills = num_prefills + self._num_decode_tokens = num_decode_tokens + self._num_prefill_tokens = num_prefill_tokens + + return modified_batch + + def build(self, + num_reqs: int, + num_actual_tokens: int, + max_query_len: int, + common_prefix_len: Optional[int] = None) -> AscendMLAMetadata: + assert self._num_decodes + self._num_prefills == num_reqs + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
+ device = self.runner.device + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + device, non_blocking=True).long() + input_positions = self.runner.positions_cpu[:num_actual_tokens].to( + device, non_blocking=True).long() + + seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] + query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[: + num_reqs] + seq_lens = seq_lens_cpu + max_query_len = query_lens.max().item() + max_context_len = seq_lens.max().item() + + prefill_metadata = None + if self._num_prefills > 0: + reqs_start = self._num_decodes # prefill_start + tokens_start = self._num_decode_tokens + + prefill_metadata = AscendMLAPrefillMetadata( + attn_mask=self.runner.attn_mask, + query_lens=query_lens[tokens_start:], + context_lens=seq_lens[tokens_start:], + input_positions=input_positions[tokens_start:], + block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_context_len=max_context_len, + ) + + decode_metadata = None + if self._num_decodes > 0: + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions[:self._num_decode_tokens], + block_table=block_table[:self._num_decode_tokens, ...], + seq_lens=seq_lens[:self._num_decode_tokens]) + + return self.metadata_cls( # type: ignore + num_actual_tokens=num_actual_tokens, + slot_mapping=slot_mapping, + head_dim=self.runner.model_config.get_head_size(), + num_decodes=self._num_decodes, + num_decode_tokens=self._num_decode_tokens, + num_prefills=self._num_prefills, + attn_mask=self.runner.attn_mask, + attn_state=self.runner.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + +class AscendMLAImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + 
alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + rotary_emb: RotaryEmbedding, + # q_proj should be q_b_proj if q_lora_rank is not None, but from an + # attention backend perspective we rely on the layer to pass in the + # correct matrix + q_proj: ColumnParallelLinear, + kv_b_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + + # Hack for V1 for now to avoid torch library overhead (since we are + # already inside an attention custom op), pull out the forward + # method from the rotary embedding and call it directly + # TODO(lucas): we should probably find a cleaner way to do this + self.rotary_emb = rotary_emb.forward_native + + self.q_proj = q_proj + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + + # Handle the differences between the flash_attn_varlen from flash_attn + # and the one from vllm_flash_attn. 
The former is used on RoCM and the + # latter has an additional parameter to control FA2 vs FA3 + # self.flash_attn_varlen_func = flash_attn_varlen_func + # if self.vllm_flash_attn_version is not None: + # self.flash_attn_varlen_func = \ + # functools.partial(flash_attn_varlen_func, + # fa_version=self.vllm_flash_attn_version) + + def _v_up_proj_and_o_proj(self, x): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return self.o_proj(x)[0] + + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = self.q_proj(x)[0]\ + .view(-1, self.num_heads, self.qk_head_dim)\ + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have 
quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _forward_prefill( + self, + query: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AscendMLAMetadata, + ) -> torch.Tensor: + assert attn_metadata.prefill is not None + + # TODO: enable this compute for flash attention computation + # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + # -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + # v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]], + # value=0) + num_tokens = query.size(0) + attn_output = torch.empty(num_tokens, + self.num_heads, + self.v_head_dim, + dtype=query.dtype, + device=query.device) + # current requests is chunked in prefill, disable flash attention with chunked prefill + vanilla_chunked_prefill_mla( + output=attn_output, + query=query, + kv_cache=kv_c_and_k_pe_cache, + block_tables=attn_metadata.prefill.block_table, + 
query_lens=attn_metadata.prefill.query_lens, + context_lens=attn_metadata.prefill.context_lens, + kv_b_proj=self.kv_b_proj, + max_query_len=attn_metadata.prefill.max_query_len, + max_context_len=attn_metadata.prefill.max_context_len, + nope_dim=self.qk_nope_head_dim, + rope_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + scale=self.scale, + alibi_slopes=None, + causal=True) + attn_output = attn_output.view( + [num_tokens, self.num_heads * self.v_head_dim]) + return self.o_proj(attn_output)[0] + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AscendMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1) + num_tokens = q.size(0) + attn_output = torch.randn( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device) + torch_npu._npu_paged_attention_mla( + query=q, + key_cache=kv_c_and_k_pe_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.decode.block_table, # type:ignore + context_lens=attn_metadata.decode.seq_lens, # type:ignore + mla_vheadsize=self.kv_lora_rank, + out=attn_output) + return self._v_up_proj_and_o_proj(attn_output) + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: M, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] 
+ hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] + + # Restore head dim (for rotary embedding) + k_pe = k_pe.unsqueeze(1) + + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + decode_k_pe = k_pe[:num_decode_tokens] + + prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + if has_decode: + assert attn_metadata.decode is not None + decode_ql_nope, decode_q_pe = \ + self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( + attn_metadata.decode.input_positions, decode_q_pe.contiguous(), + decode_k_pe) + + if has_prefill: + assert attn_metadata.prefill is not None + prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + + prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( + attn_metadata.prefill.input_positions, + prefill_q_pe.contiguous(), prefill_k_pe) + + if kv_cache.numel() > 0: + concat_and_cache_mla(k_c_normed, k_pe, kv_cache, + attn_metadata.slot_mapping.flatten()) + # TODO: replaced back to ascend ops + # key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2) + # torch_npu._npu_reshape_and_cache_siso( + # key=key, + # key_cache=kv_cache, + # slot_indices=attn_metadata.slot_mapping.flatten()) + + if has_prefill: + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) + + if has_decode: + output[:num_decode_tokens] = self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) + + return output_padded diff --git a/vllm_ascend/distributed/llmdatadist_connector.py b/vllm_ascend/distributed/llmdatadist_connector.py index 6e0d4e5..69c8ce7 100644 --- a/vllm_ascend/distributed/llmdatadist_connector.py +++ b/vllm_ascend/distributed/llmdatadist_connector.py @@ -462,4 +462,4 @@ class LLMDataDistConnector(KVConnectorBase): def close(self, ): self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster], - 5000) \ No newline at end of file + 5000) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py new file mode 100644 index 0000000..acb5048 --- /dev/null +++ b/vllm_ascend/distributed/parallel_state.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch +from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, + init_model_parallel_group) + +# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for +# customize parallel solution +_EP: Optional[GroupCoordinator] = None +_ETP: Optional[list[GroupCoordinator]] = None + + +def get_ep_group() -> GroupCoordinator: + assert _EP is not None, ("expert model parallel group is not initialized") + return _EP + + +def get_etp_group() -> 
GroupCoordinator: + assert _ETP is not None, ( + "expert tensor parallel group is not initialized") + return _ETP + + +def init_ascend_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + expert_tensor_parallel_size: int = 1, + backend: Optional[str] = None, +): + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + num_expert_parallel_groups: int = expert_tensor_parallel_size + num_expert_tensor_parallel_groups: int = (world_size // + expert_tensor_parallel_size) + + global _EP + assert _EP is None, ("expert parallel group is already initialized") + group_ranks = [] + for i in range(num_expert_parallel_groups): + ranks = list(range(i, world_size, num_expert_parallel_groups)) + group_ranks.append(ranks) + + _EP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="ep") + + group_ranks = [] + global _ETP + assert _ETP is None, ( + "expert tensor parallel group is already initialized") + for i in range(num_expert_tensor_parallel_groups): + ranks = list( + range(i * expert_tensor_parallel_size, + (i + 1) * expert_tensor_parallel_size)) + group_ranks.append(ranks) + + _ETP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="etp") + + +def destory_ascend_model_parallel(): + global _EP + if _EP: + _EP.destroy() + _EP = None + + global _ETP + if _ETP: + _ETP.destroy() + _ETP = None \ No newline at end of file diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index e2c3fd9..3cb497f 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -2,8 +2,15 @@ from vllm import ModelRegistry def register_model(): + from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 + from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 + from .deepseek_v2 import 
CustomDeepseekV3ForCausalLM # noqa: F401 from .qwen2_vl import CustomQwen2VLForConditionalGeneration # noqa: F401 + ModelRegistry.register_model( + "DeepSeekMTPModel", + "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP") + ModelRegistry.register_model( "Qwen2VLForConditionalGeneration", "vllm_ascend.models.qwen2_vl:CustomQwen2VLForConditionalGeneration") diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py new file mode 100644 index 0000000..983b7fd --- /dev/null +++ b/vllm_ascend/models/deepseek_mtp.py @@ -0,0 +1,173 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Adapted from vllm/model_executor/models/qwen2_vl.py +# Copyright 2023 The vLLM team. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import \ + VocabParallelEmbedding +from vllm.model_executor.models.deepseek_mtp import ( + DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, + SharedHead) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.sampling_metadata import SamplingMetadata + +from .deepseek_v2 import CustomDeepseekV2DecoderLayer + + +class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix, + model_config, + cache_config, + quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + 
spec_step_index: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds = torch.where((positions == 0).unsqueeze(-1), + torch.zeros_like(inputs_embeds), + inputs_embeds) + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): CustomDeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + # Note: torch._dynamo.exc.Unsupported: builtin: str + self.layers_list = [ + self.layers[str(idx)] + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + ] + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + 
spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + return self.layers_list[current_step_idx]( + input_ids, + positions, + kv_caches[current_step_idx], + attn_metadata, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers_list[current_step_idx] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states), + sampling_metadata) + return logits + + +class CustomDeepSeekMTP(DeepSeekMTP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config.model_config.hf_config + self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + self.sampler = get_sampler() diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index f38206e..1b092e2 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -19,36 +19,77 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# Adapted from -# vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py -"""Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Optional, Union +# <<<<<<< HEAD +# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" +# from typing import Optional, Union + +# import torch +# from torch import nn +# from transformers import PretrainedConfig +# from vllm.config import CacheConfig, ModelConfig, VllmConfig +# from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +# from vllm.model_executor.layers.fused_moe import FusedMoE +# from vllm.model_executor.layers.layernorm import RMSNorm +# from vllm.model_executor.layers.linear import ReplicatedLinear +# from vllm.model_executor.layers.logits_processor import LogitsProcessor +# from vllm.model_executor.layers.quantization import QuantizationConfig +# from vllm.model_executor.layers.sampler import get_sampler +# from vllm.model_executor.layers.vocab_parallel_embedding import ( +# ParallelLMHead, VocabParallelEmbedding) +# from vllm.model_executor.models.deepseek_v2 import ( # noqa +# DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, +# DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE) +# ======= + +import os +from typing import Any, Dict, Optional, Union import torch +import torch.distributed as dist from torch import nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from 
vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.attention import Attention +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_dp_group, get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce) +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.deepseek_v2 import ( # noqa - DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, - DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE) +from vllm.model_executor.models.deepseek_v2 import \ + DeepseekV2ForCausalLM # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, + DeepseekV2DecoderLayer, + DeepseekV2MLAAttention, + DeepseekV2MLP) from vllm.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +# >>>>>>> dcd5c73 (Feat: Graph mode for deepseek v2/v3.) 
from vllm.sequence import IntermediateTensors +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE -class CustomDeepseekV2MoE(DeepseekV2MoE): + +class CustomDeepseekV2MoE(nn.Module): + + top_k: int def __init__( self, @@ -56,10 +97,15 @@ class CustomDeepseekV2MoE(DeepseekV2MoE): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): - nn.Module.__init__(self) + super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " @@ -76,7 +122,7 @@ class CustomDeepseekV2MoE(DeepseekV2MoE): else: self.gate.e_score_correction_bias = None - self.experts = FusedMoE( + self.experts = AscendFusedMoE( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -99,9 +145,248 @@ class CustomDeepseekV2MoE(DeepseekV2MoE): intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + reduce_results=True, prefix=f"{prefix}.shared_experts", ) + CustomDeepseekV2MoE.top_k = config.num_experts_per_tok + + vllm_config = get_current_vllm_config() + self.dp_size = get_dp_group().world_size + batch_size = vllm_config.scheduler_config.max_num_seqs + self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1 + + params_dtype = torch.get_default_dtype() + self.final_hidden_states = torch.zeros( + [batch_size, config.hidden_size], dtype=params_dtype, device="npu") + self.tp_group = get_tp_group().device_group + + def forward(self, hidden_states: torch.Tensor) -> 
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
    """DeepseekV2 MLA attention with Ascend-specific forward paths.

    Rebuilds the full MLA projection stack and binds, once at construction
    time, either the torchair graph-mode forward or the eager forward based
    on ``VLLM_ENABLE_GRAPH_MODE``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        # Intentionally skip DeepseekV2MLAAttention.__init__: every submodule
        # is rebuilt here so the inner Attention can be fed the MLA args.
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
        )

        self.prefix = prefix
        self.debug_layer_idx = int(self.prefix.split(".")[-2])
        # Bind the forward implementation once; avoids a per-call env check.
        if VLLM_ENABLE_GRAPH_MODE == "1":
            self.forward = self.forward_torchair
        else:
            self.forward = self.forward_eager  # type: ignore

    def forward_torchair(self,
                         positions: torch.Tensor,
                         hidden_states: torch.Tensor,
                         kv_cache: Optional[torch.Tensor] = None,
                         attn_metadata=None):
        """Graph-mode path: pass raw hidden states and the kv cache straight
        to the MLA attention op (it does the kv projection internally)."""
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
                             kv_cache, attn_metadata)

    def forward_eager(self, positions: torch.Tensor,
                      hidden_states: torch.Tensor):
        """Eager path: project and normalize kv_c / split off k_pe here,
        then hand the pieces to the MLA attention op."""
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        return self.mla_attn(hidden_states_or_q_c,
                             kv_c_normed,
                             k_pe,
                             output_shape=hidden_states.shape)
layer_idx = int(prefix.split(sep='.')[-1]) self.layer_idx = layer_idx + # TODO: enable mla in vllm-ascend if model_config.use_mla: - attn_cls = DeepseekV2MLAAttention + attn_cls = CustomDeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( @@ -180,8 +466,8 @@ class CustomDeepseekV2Model(nn.Module): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - self.config = config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size if get_pp_group().is_first_rank: diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index 2e59391..1947799 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -18,3 +18,4 @@ import vllm_ascend.ops.activation # noqa import vllm_ascend.ops.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.rotary_embedding # noqa +import vllm_ascend.ops.vocab_parallel_embedding # noqa diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py new file mode 100644 index 0000000..4d38255 --- /dev/null +++ b/vllm_ascend/ops/attention.py @@ -0,0 +1,293 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/kernels/test_moe.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Vanilla (reference) chunked prefill; remove once the fused kernel handles
# every corner case.
def vanilla_chunked_prefill(
    output: torch.Tensor,
    query: torch.Tensor,  # (num_tokens, heads, head_size)
    key_cache: torch.Tensor,  # (num_blocks, block_size, kv_heads, head_size)
    value_cache: torch.Tensor,  # (num_blocks, block_size, kv_heads, head_size)
    block_tables: torch.Tensor,  # (num_seqs, max_num_blocks_per_seq)
    cu_seqlen_q: torch.Tensor,  # (num_seqs + 1,)
    cu_seqlen_k: torch.Tensor,  # (num_seqs + 1,)
    max_seqlen_q: int,
    max_seqlen_k: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    causal: bool = True,
) -> None:
    """Dense reference attention over a paged KV cache.

    Gathers each sequence's KV blocks into padded dense tensors, applies
    padding and (optionally) causal masks, and writes the packed result into
    ``output``. ``alibi_slopes`` is accepted for signature parity but unused.
    """
    num_query_heads = query.shape[1]
    head_dim = value_cache.shape[3]
    num_kv_heads = value_cache.shape[2]
    block_size = value_cache.shape[1]
    num_batch = cu_seqlen_q.shape[0] - 1
    max_num_blocks_per_seq = block_tables.shape[1]

    # Gather paged KV into dense [batch, seq, kv_heads, head_dim] and trim to
    # the longest real context.
    key = key_cache[block_tables].view(num_batch,
                                       max_num_blocks_per_seq * block_size,
                                       num_kv_heads,
                                       head_dim)[:, :max_seqlen_k, :, :]
    value = value_cache[block_tables].view(num_batch,
                                           max_num_blocks_per_seq * block_size,
                                           num_kv_heads,
                                           head_dim)[:, :max_seqlen_k, :, :]

    # Per-sequence lengths as column vectors.
    seqlen_k = (cu_seqlen_k[1:] - cu_seqlen_k[:-1]).view(-1, 1)
    seqlen_q = (cu_seqlen_q[1:] - cu_seqlen_q[:-1]).view(-1, 1)
    seqlen_diff = seqlen_k - seqlen_q
    q_idx_mask = torch.arange(0, max_seqlen_q,
                              device="npu").view(1, -1).repeat(num_batch, 1)
    k_idx_mask = torch.arange(0, max_seqlen_k,
                              device="npu").view(1, -1).repeat(num_batch, 1)
    q_mask = q_idx_mask < seqlen_q
    k_mask = k_idx_mask < seqlen_k

    # Row of the causal mask each valid query token should use.
    causal_mask_idx = (q_idx_mask + seqlen_diff)[q_mask]

    # Causal mask [batch, max_seqlen_q, max_seqlen_k]: 0 = keep, -inf = drop.
    tril_mask = torch.tril(
        torch.ones(max_seqlen_k, max_seqlen_k, device="npu"))
    tril_mask[tril_mask == 0] = float("-inf")
    tril_mask[tril_mask == 1] = 0
    causal_mask = tril_mask[causal_mask_idx]
    causal_mask_padding = torch.empty([num_batch, max_seqlen_q, max_seqlen_k],
                                      device="npu").fill_(float("-inf"))
    causal_mask_padding[q_mask] = causal_mask
    # Broadcast over heads: [batch, 1, max_seqlen_q, max_seqlen_k].
    causal_mask_padding = causal_mask_padding.unsqueeze(1)

    # Scatter the packed tokens into zero-padded dense tensors.
    pad_q = torch.zeros(
        [num_batch, max_seqlen_q, num_query_heads, head_dim],
        device="npu",
        dtype=query.dtype,
    )
    pad_k = torch.zeros(
        [num_batch, max_seqlen_k, num_kv_heads, head_dim],
        device="npu",
        dtype=key.dtype,
    )
    pad_v = torch.zeros(
        [num_batch, max_seqlen_k, num_kv_heads, head_dim],
        device="npu",
        dtype=value.dtype,
    )
    pad_q[q_mask] = query
    pad_k[k_mask] = key[k_mask]
    pad_v[k_mask] = value[k_mask]

    if num_query_heads > num_kv_heads:
        # GQA/MQA: replicate each KV head across its query-head group.
        group = num_query_heads // num_kv_heads
        pad_k = pad_k.view(
            [num_batch, max_seqlen_k, num_kv_heads, 1, head_dim]).repeat(
                1, 1, 1, group, 1).view(
                    [num_batch, max_seqlen_k, num_query_heads, head_dim])
        pad_v = pad_v.view(
            [num_batch, max_seqlen_k, num_kv_heads, 1, head_dim]).repeat(
                1, 1, 1, group, 1).view(
                    [num_batch, max_seqlen_k, num_query_heads, head_dim])
    # [b, seq, h, d] -> [b, h, seq, d]
    pad_q = pad_q.permute(0, 2, 1, 3)
    pad_k = pad_k.permute(0, 2, 1, 3)
    pad_v = pad_v.permute(0, 2, 1, 3)
    # Padding mask over keys: 0 at valid positions, -inf elsewhere.
    attn_mask = torch.empty([num_batch, 1, 1, max_seqlen_k],
                            device="npu").fill_(float("-inf"))
    attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
    # Scores [b, h, q, k].
    attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
    attn_weights *= scale
    attn_weights = attn_weights + attn_mask.float()
    if causal:
        attn_weights = attn_weights + causal_mask_padding

    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights,
                               pad_v.float())
    attn_output = attn_output.permute(0, 2, 1, 3)

    # Re-pack the valid tokens and write them out.
    attn_output = (attn_output[q_mask].view([-1, num_query_heads,
                                             head_dim]).to(output.dtype))
    output.copy_(attn_output)
    return attn_output
def vanilla_chunked_prefill_mla(
        output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
        query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
        kv_cache: torch.Tensor,  # 4-D; last dim is latent_kv + rope_dim
        block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
        query_lens: torch.Tensor,  # (batch_size)
        context_lens: torch.Tensor,  # (batch_size)
        kv_b_proj: ColumnParallelLinear,
        max_query_len: int,
        max_context_len: int,
        nope_dim: int,
        rope_dim: int,
        v_head_dim: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        causal: bool = True) -> None:
    """Dense reference chunked prefill for MLA.

    Up-projects the cached latent KV via ``kv_b_proj``, runs padded dense
    attention with padding + causal masks, and writes the packed result into
    ``output``. ``alibi_slopes`` is accepted for signature parity but unused.
    """
    batch_size = block_tables.size(0)
    assert query_lens.size(0) == batch_size
    num_heads = query.size(1)
    block_size = kv_cache.size(1)
    # Latent dim is whatever remains after the rope part.
    latent_kv_dim = kv_cache.size(3) - rope_dim
    max_num_blocks_per_seq = block_tables.size(1)
    kv_cache = kv_cache.squeeze()
    # Gather [kv_c | k_pe]: [batch, max_context_len, latent_kv + rope_dim].
    cache_kv_c_pe = kv_cache[block_tables].view(
        batch_size, max_num_blocks_per_seq * block_size,
        latent_kv_dim + rope_dim)[:, :max_context_len, :]
    cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
    cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
    # Up-project latent kv into per-head nope keys and values:
    # k_nope: [batch, ctx, heads, nope_dim]; value: [batch, ctx, heads, v_head_dim].
    k_nope, value = kv_b_proj(cache_kv_c)[0].view(
        batch_size, max_context_len, num_heads,
        nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
    # Shared rope key is broadcast (expand, no copy) across heads.
    key = torch.cat(
        [k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
        dim=-1)

    context_lens = context_lens.view(-1, 1).to("npu")
    query_lens = query_lens.view(-1, 1).to("npu")
    seq_diff = context_lens - query_lens

    q_idx_mask = torch.arange(0, max_query_len,
                              device="npu").view(1, -1).repeat(batch_size, 1)
    kv_c_idx_mask = torch.arange(0, max_context_len,
                                 device="npu").view(1,
                                                    -1).repeat(batch_size, 1)
    kv_c_mask = kv_c_idx_mask < context_lens
    q_mask = q_idx_mask < query_lens

    # Row of the causal mask each valid query token should use.
    causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]

    # Causal mask [batch, max_query_len, max_context_len]: 0 keep, -inf drop.
    tril_mask = torch.tril(
        torch.ones(max_context_len, max_context_len, device="npu"))
    tril_mask[tril_mask == 0] = float("-inf")
    tril_mask[tril_mask == 1] = 0
    causal_mask = tril_mask[causal_mask_idx]
    causal_mask_padding = torch.empty(
        [batch_size, max_query_len, max_context_len],
        device="npu").fill_(float("-inf"))
    causal_mask_padding[q_mask] = causal_mask
    # Broadcast over heads.
    causal_mask_padding = causal_mask_padding.unsqueeze(1)

    # Scatter packed tokens into zero-padded dense tensors.
    pad_q = torch.zeros(
        [batch_size, max_query_len, num_heads, rope_dim + nope_dim],
        device="npu",
        dtype=query.dtype,
    )
    pad_k = torch.zeros(
        [batch_size, max_context_len, num_heads, rope_dim + nope_dim],
        device="npu",
        dtype=key.dtype,
    )
    pad_v = torch.zeros(
        [batch_size, max_context_len, num_heads, v_head_dim],
        device="npu",
        dtype=value.dtype,
    )
    pad_q[q_mask] = query
    pad_k[kv_c_mask] = key[kv_c_mask]
    pad_v[kv_c_mask] = value[kv_c_mask]

    # [b, seq, h, d] -> [b, h, seq, d]
    pad_q = pad_q.permute(0, 2, 1, 3)
    pad_k = pad_k.permute(0, 2, 1, 3)
    pad_v = pad_v.permute(0, 2, 1, 3)
    # Padding mask over keys: 0 at valid positions, -inf elsewhere.
    attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
                            device="npu").fill_(float("-inf"))
    attn_mask[:, :, :, :max_context_len].masked_fill_(
        kv_c_mask[:, None, None, :], 0)
    # Scores [b, h, q, k].
    attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
    attn_weights *= scale
    attn_weights = attn_weights + attn_mask.float()
    if causal:
        attn_weights = attn_weights + causal_mask_padding

    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights,
                               pad_v.float())
    attn_output = attn_output.permute(0, 2, 1, 3)

    # Re-pack valid tokens and write them out.
    attn_output = (attn_output[q_mask].view([-1, num_heads,
                                             v_head_dim]).to(output.dtype))
    output.copy_(attn_output)
    return attn_output
def vanilla_decode_mla(
        query: torch.Tensor,  # [num_tokens, num_heads, latent_dim + rope_dim]
        key_cache: torch.
    Tensor,  # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
        num_kv_heads: int,
        num_heads: int,
        scale: float,
        block_table: torch.Tensor,  # [batch_size, max_block_size]
        context_lens: List[int],
        mla_vhead_size: int,
        rope_dim: int,
        output: torch.Tensor):
    """Reference MLA decode attention (one query token per sequence).

    Writes the attended latent values into ``output`` and returns it.

    Fixes over the original:
    - ``expand(..., num_heads, 1)`` tried to "expand" a size-``reduce_dim``
      last dim to 1, which raises at runtime; the intent is ``-1`` (keep).
    - ``kv_idx_mask[:, -1, -1, :]`` indexed a 2-D mask with four indices
      (IndexError), and adding the raw 0/1 mask does not mask anything;
      padded positions must contribute -inf before the softmax.
    """
    batch_size = block_table.size()[0]
    max_block_size = block_table.size()[1]
    reduce_dim = key_cache.size()[-1]
    block_size = key_cache.size()[1]
    latent_dim = reduce_dim - rope_dim
    # Gather paged cache: [batch, blocks*block_size, kv_heads, reduce_dim].
    kv_c_and_pe = key_cache[block_table].view(
        [batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
    # Deepseek uses a single kv head; expand (no copy) across query heads.
    kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
        -1, -1, num_heads, -1)
    kv_c = kv_c_and_pe[..., :latent_dim]
    kv_idx_mask = (torch.arange(0, max_context_len,
                                device="npu").view(1,
                                                   -1).repeat(batch_size, 1))
    # [batch_size, max_context_len]: True at valid cache positions.
    kv_idx_mask = kv_idx_mask < context_lens
    # Single decode token per sequence -> q length 1.
    query = query.unsqueeze(1)
    attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
    attn_weights *= scale
    # Mask padded key positions with -inf (broadcast over heads and q).
    attn_weights = attn_weights.masked_fill(~kv_idx_mask[:, None, None, :],
                                            float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
                               kv_c.float()).view(-1, num_heads, latent_dim)
    output.copy_(attn_output)
    return output
def concat_and_cache_mla(
    kv_c_normed: torch.Tensor,  # [num_tokens, nope]
    k_pe: torch.Tensor,  # [num_tokens, num_kv_head, rope]
    kv_cache: torch.
    Tensor,  # [num_blocks, block_size, num_kv_head, nope + rope]
    slot_mapping,  # [num_tokens] flat slot index per token
):
    """Write [kv_c_normed | k_pe] for each token into its cache slot.

    Mutates ``kv_cache`` in place through a flat view. Note the original
    shape comment claimed kv_c_normed carried a kv-head dim, but the
    ``unsqueeze(1)`` below only type-checks for a 2-D [num_tokens, nope]
    input — TODO confirm against callers.
    """
    num_blocks = kv_cache.size()[0]
    block_size = kv_cache.size()[1]
    num_kv_head = k_pe.size()[1]

    # The original computed
    #   slot_mapping // block_size * block_size + slot_mapping % block_size
    # which is the identity for non-negative slots; use slot_mapping directly.
    flat_cache = kv_cache.view(num_blocks * block_size, num_kv_head, -1)
    flat_cache[slot_mapping] = torch.cat([kv_c_normed.unsqueeze(1), k_pe],
                                         dim=-1)
def fused_experts_with_mc2(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: torch.Tensor = None,
    moe_all_to_all_group_name: Optional[str] = None,
) -> torch.Tensor:
    """Run routed experts via the MC2 fused dispatch/combine collectives.

    Dispatches tokens to their experts across the EP group, runs the
    grouped gate-up / swiglu / down matmuls, then combines the results
    back, weighted by ``topk_weights``.

    NOTE(review): despite the ``= None`` default, ``expert_map`` is required
    (``len(expert_map)`` below) — confirm callers always pass it.
    """
    global_bs = 0
    moe_expert_num = len(expert_map)
    kwargs = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
    }

    rank = torch.distributed.get_rank()

    quant_mode = 0
    ep_group = get_ep_group().device_group
    local_rank = torch.distributed.get_rank(group=ep_group)
    all_to_all_group_size = torch.distributed.get_world_size(ep_group)

    # Fix: variable was misspelled ``world_szie``.
    world_size = torch.distributed.get_world_size()
    tp_size = world_size // all_to_all_group_size
    tp_rank = rank % tp_size

    stage1_kwargs = {
        "scales": None,
        "quant_mode": quant_mode,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        # TP deliberately reuses the EP HCCL comm name here.
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    kwargs.update(stage1_kwargs)

    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
        0:5]

    w1 = w1.transpose(1, 2)
    # Prefix sums delimit each expert's token group for the grouped matmul.
    group_list = torch.cumsum(expert_token_nums, dim=0, dtype=torch.int64)
    gate_up_out_list = torch_npu.npu_grouped_matmul(
        x=[expand_x],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )

    # TODO: Remove this in the future.
    gate_up_out = torch.cat(gate_up_out_list, dim=0)
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    w2 = w2.transpose(1, 2)
    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )

    down_out_list = torch.cat(down_out_list, dim=0)

    # moeCombine
    kwargs = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,
        "expert_scales": topk_weights.to(torch.float32),
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": 0,
    }
    tp_recv_counts = output[5]
    stage3_kwargs = {
        "ep_send_counts": ep_recv_counts,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        "tp_send_counts": tp_recv_counts,
        # Same comm-name reuse as in the dispatch stage.
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    kwargs.update(stage3_kwargs)

    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)

    return hidden_states
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" + """ + # if torch.distributed.get_rank() == 0: + # print(w1.shape) + # print(hidden_states.shape) original_shape = hidden_states.shape - assert len(original_shape) == 2 + # assert len(original_shape) == 2 num_tokens = hidden_states.shape[:-1].numel() num_experts = w1.shape[0] dtype = hidden_states.dtype device = hidden_states.device - assert dtype in [torch.float32, torch.float16, torch.bfloat16 - ], "Only float32, float16, and bfloat16 are supported" + # assert dtype in [torch.float32, torch.float16, torch.bfloat16 + # ], "Only float32, float16, and bfloat16 are supported" if expert_map is not None: # Generate token indices and flatten @@ -152,11 +276,18 @@ def fused_experts( final_hidden_states = torch.zeros(*original_shape, device=hidden_states.device, dtype=dtype) - final_hidden_states.index_add_(0, sorted_token_indices, - weighted_down_out) - # TODO: This should not happen! Look into it! - # fill nan with 0.0 - final_hidden_states[torch.isnan(final_hidden_states)] = 0.0 + + # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] 
+ # This created multiple NaN and index_add_ will mix them up which harms accracy + # remove this mask and filter after it being fixed + num_valid_tokens = mask.sum() + valid_token_mask = torch.arange( + 0, sorted_token_indices.shape[0], + device=device).unsqueeze(1) < num_valid_tokens + valid_output = torch.where( + valid_token_mask, weighted_down_out, + torch.zeros_like(weighted_down_out)).to(dtype) + final_hidden_states.index_add_(0, sorted_token_indices, valid_output) else: # TODO: Reorder device memory 2 times here, replace the current # implementation here when suitable operators become available. @@ -199,16 +330,17 @@ def native_grouped_topk( def select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: Optional[bool] = True ) -> tuple[torch.Tensor, torch.Tensor]: """ Select top-k experts based on router logits. @@ -232,8 +364,23 @@ def select_experts( Raises: ValueError: If an unsupported scoring function is provided. 
""" - assert hidden_states.shape[0] == router_logits.shape[0], ( - "Number of tokens mismatch") + # assert hidden_states.shape[0] == router_logits.shape[0], ( + # "Number of tokens mismatch") + # if os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1" and not is_prefill: + # topk_weight, topk_idx, _ = torch.ops.npu_inference.npu_moe_gating_top_k( + # router_logits, + # k=top_k, # topk当前写8 + # bias=e_score_correction_bias, + # k_group=topk_group, # fix: 4 + # group_count=num_expert_group, # fix 8 + # group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + # renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + # norm_type=1, # 0: softmax; 1: sigmoid(fix) + # # out_flag=False, # todo new api; 第三个输出是否输出 + # # y2_flag=False, # old api; 第三个输出是否输出 + # routed_scaling_factor=1, + # eps=float(1e-20)) + # return topk_weight, topk_idx if custom_routing_function is not None: raise NotImplementedError( @@ -261,14 +408,16 @@ def select_experts( # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) topk_weights = native_grouped_topk(topk_weights, num_expert_group, topk_group) - + # TODO bfloat16 is not supported in torch.topk with ge graph. 
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
    """Unquantized FusedMoE method with Ascend-specific expert execution.

    Caches the EP group's HCCL comm name so the MC2 dispatch/combine path
    can be used for decode when ``VLLM_ENABLE_MC2`` is set.
    """

    def __init__(self):
        super().__init__()
        vllm_config = get_current_vllm_config()

        ep_group = get_ep_group()
        self.ep_size = ep_group.world_size
        self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.local_batch_size = self.global_batch_size // self.ep_size

        try:
            device_group = ep_group.device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
                local_rank)
        except AttributeError:
            # Backend without HCCL (e.g. CPU/test runs): MC2 is unavailable.
            self.moe_all_to_all_group_name = None

    def process_weights_after_loading(self, layer):
        # Deliberately skip UnquantizedFusedMoEMethod's override and call the
        # grandparent implementation, then pad the expert weights for the NPU.
        super(UnquantizedFusedMoEMethod,
              self).process_weights_after_loading(layer)
        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w13_weight.data),
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w2_weight.data),
                                             requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill=False,
        **kwargs,
    ):
        """Select experts for ``x`` and run them.

        Uses the MC2 fused collectives for decode (``is_prefill=False``)
        when ``VLLM_ENABLE_MC2=1``; otherwise falls back to the plain
        ``fused_experts`` path.
        """
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            is_prefill=is_prefill)

        if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill:
            return fused_experts_with_mc2(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
        return fused_experts(hidden_states=x,
                             w1=layer.w13_weight,
                             w2=layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             top_k=top_k,
                             expert_map=expert_map)
get_ep_group().rank_in_group, self.global_num_experts) + self.tp_rank = get_etp_group().rank_in_group + self.ep_rank = get_ep_group().rank_in_group + else: + # Adjust TP size for DP attention + # haven't test its functionality yet, may remove in the future + self.tp_rank = self.tp_size * self.dp_rank + self.ep_rank = 0 + self.tp_size = self.tp_size * self.dp_size + self.ep_size = 1 + self.local_num_experts = self.global_num_experts + self.expert_map = None + + if self.scoring_func != "softmax" and not self.use_grouped_topk: + raise ValueError("Only softmax scoring function is supported for " + "non-grouped topk.") + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = ( + AscendUnquantizedFusedMoEMethod()) + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + local_num_experts = torch.sum(self.expert_map != -1) \ + if self.expert_map is not None else num_experts + + moe_quant_params = { + "num_experts": local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + + self.quant_method.create_weights(layer=self, **moe_quant_params) + + def forward(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + top_k=None): + assert self.quant_method is not None + + if top_k: + real_top_k = top_k + else: + real_top_k = self.top_k + + if self.dp_size > 1: + if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore + ) == 1 and not is_prefill: + ... 
+ elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore + hidden_states = get_dp_group().all_gather( + hidden_states, 0, False) + router_logits = get_dp_group().all_gather( + router_logits, 0, False) + else: + hidden_states = get_dp_group().all_gather(hidden_states, 0) + router_logits = get_dp_group().all_gather(router_logits, 0) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=real_top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + is_prefill=is_prefill) + + if self.dp_size > 1: + if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore + ) == 1 and not is_prefill: + ... + else: + final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( + final_hidden_states, + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + + # if self.reduce_results and self.tp_size > 1: + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 5fc30ee..5e6af79 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -15,6 +15,7 @@ # This file is a part of the vllm-ascend project. 
# +import math from typing import Optional, Tuple import torch @@ -38,7 +39,7 @@ def rope_forward_oot( if self.cos_sin_cache.dtype != query.dtype: self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) # adopt custom kernel path for rotary_embedding - if CUSTOM_OP_ENABLED and self.is_neox_style: + if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0: return torch.ops._C.rotary_embedding( positions, query, @@ -66,5 +67,169 @@ def rope_forward_oot( return query.view(query_shape), key.view(key_shape) +def native_rope_deepseek_forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, +): + # seq_len = positions.max() + 1 + seq_len = self.max_position_embeddings + + # x: [bs, num_attention_heads, seq_len, head_size] + # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + # self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype) + self._set_cos_sin_cache(seq_len=seq_len, + device=query.device, + dtype=query.dtype) + + cos = self.cos_cached[:seq_len].to(dtype=query.dtype) + sin = self.sin_cached[:seq_len].to(dtype=query.dtype) + + q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions) + + return q_pe, k_pe + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim(num_rotations, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. 
+ return (dim * torch.log( + torch.tensor(max_position_embeddings) / + (num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base))) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +# Find dim range bounds based on rotations +def yarn_find_correction_range(low_rot, + high_rot, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. + low = torch.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = torch.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + # Note: use torch instead of max/min to solve MTP compilation error. + return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1) + + +def yarn_linear_ramp_mask(min_value, max_value, dim): + # Note: The if conditional branch is not used here + # to solve MTP compilation error. + max_value += (min_value == max_value).float() * 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - + min_value) / (max_value - min_value) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids] + sin = sin[position_ids] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] + + if len(q.shape) == 3: + q = q[:, :, None, :] + if len(k.shape) == 2: + k = k[:, None, None, :] + elif len(k.shape) == 3: + k = k[:, :, None, :] + + b, h_q, s, d = q.shape + q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d) + + b, h_k, s, d = k.shape + k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = q_embed.view(b, h_q, d) + k_embed = k_embed.view(b, h_k, d) + + return q_embed, k_embed + + +def _set_cos_sin_cache(self, seq_len, device, dtype): + seq_len = self.max_position_embeddings + self.max_seq_len_cached = seq_len + dim = self.rotary_dim + + freq_extra = 1.0 / (self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = 1.0 - 
yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + # _mscale = float( + # yarn_get_mscale(self.scaling_factor, self.mscale) + # / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + # ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), + persistent=False) + self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), + persistent=False) + + +# TODO: Patch when aclnn ops avaiable RotaryEmbedding.forward_oot = rope_forward_oot -DeepseekScalingRotaryEmbedding.forward = rope_forward_oot +# DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot +DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward +DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache +DeepseekScalingRotaryEmbedding.max_seq_len_cached = None diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py new file mode 100644 index 0000000..b326f0c --- /dev/null +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Tuple + +import torch +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.model_executor.layers.vocab_parallel_embedding import \ + VocabParallelEmbedding + + +def get_masked_input_and_mask( + input_: torch.Tensor, org_vocab_start_index: int, + org_vocab_end_index: int, num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + # torch.compile will fuse all of the pointwise ops below + # into a single kernel, making it very fast + org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < + org_vocab_end_index) + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index) + added_offset = added_vocab_start_index - ( + org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding + valid_offset = (org_vocab_start_index * + org_vocab_mask) + (added_offset * added_vocab_mask) + vocab_mask = org_vocab_mask | added_vocab_mask + input_ = vocab_mask * (input_ - valid_offset) + return input_, ~vocab_mask + + +def vocab_parallel_embedding_forward(self, input_): + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. 
+ output = tensor_model_parallel_all_reduce(output_parallel) + return output + + +VocabParallelEmbedding.forward = vocab_parallel_embedding_forward diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index e69de29..2ed088b 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# \ No newline at end of file diff --git a/vllm_ascend/patch/platform/patch_0_8_4/__init__.py b/vllm_ascend/patch/platform/patch_0_8_4/__init__.py index d1c5ac2..cdbe66f 100644 --- a/vllm_ascend/patch/platform/patch_0_8_4/__init__.py +++ b/vllm_ascend/patch/platform/patch_0_8_4/__init__.py @@ -15,4 +15,5 @@ # limitations under the License. 
# -import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa \ No newline at end of file +import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa +import vllm_ascend.patch.platform.patch_0_8_4.patch_distributed # noqa \ No newline at end of file diff --git a/vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py b/vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py new file mode 100644 index 0000000..83f4a53 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py @@ -0,0 +1,138 @@ +import torch +import vllm +import vllm.distributed +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import (Backend, PrefixStore, + _get_default_timeout, + is_nccl_available) +from torch.distributed.rendezvous import rendezvous +from vllm.config import ParallelConfig + + +def ascend_destroy_model_parallel(): + """Set the groups to none and destroy them.""" + from vllm.distributed.parallel_state import _DP, _PP, _TP + if _TP: + _TP.destroy() + _TP = None + + if _PP: + _PP.destroy() + _PP = None + + if _DP: + _DP.destroy() + _DP = None + from vllm.platforms import current_platform + current_platform.destroy_platform_model_parallel() + + +def ascend_stateless_init_torch_distributed_process_group( + host: str, port: int, rank: int, world_size: int, + backend: str) -> ProcessGroup: + """ + A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. The created ProcessGroup object can be used for + some operations such as `allreduce`, because it does not depend on the + global rank. However, some operations such as `broadcast` cannot be used + because it depends on the global rank. + + # TODO: ask for help from PyTorch team if we need the `broadcast` operation. + + This function is useful when we are not sure about the total number of + processes in the process group. 
For example, we may have process + 1, 2, ..., 8 who want to communicate, and process 9 might be the same + process as process 1, or it might be a different process; process 10 + might be the same process as process 5, or it might be a different process. + In this case, how can we reliably form a communication channel within + process 9 and 10, without affecting the communication channel within + process 1, 2, ..., 8? + + One possible solution is to figure out if process 9 and 10 are the same + as process 1 and 5 beforehand, and then form a communication channel + based on the information, adjusting the ranks and world_size etc. However, + figuring out the information is not always easy, and it will interfere + with the main communication channel. + + Our solution is to always form a communication channel with process 1, 2, + ..., 8, and then use this function to form another communication channel + with process 9 and 10. This way, regardless of whether process 9 and 10 + are the same as process 1 and 5, the main communication channel is + always formed with process 1, 2, ..., 8, and the additional communication + channel is formed with process 9 and 10. + """ + init_method = f"tcp://{host}:{port}" + backend = Backend(backend) # it is basically string + timeout = _get_default_timeout(backend) + + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout)) + store.set_timeout(timeout) + + group_rank = rank + group_size = world_size + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. 
+ prefix_store = PrefixStore(init_method, store) + + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + from vllm.platforms import current_platform + if backend == "gloo": + from torch.distributed.distributed_c10d import ProcessGroupGloo + backend_class = ProcessGroupGloo(prefix_store, + group_rank, + group_size, + timeout=timeout) + backend_type = ProcessGroup.BackendType.GLOO + device = torch.device("cpu") + elif backend == "nccl": + assert is_nccl_available() + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, + backend_options) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + elif current_platform.platform_has_backend_register(): + current_platform.platform_register_backend() + return pg + else: + raise RuntimeError(f"Unsupported torch distributed backend: {backend}") + + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + + return pg + + +def parallel_config_get_dp_port(self) -> int: + """ + We might need to initialize process groups in multiple + processes that is related to data parallelism, + e.g. both in the worker and in the engine, which + can live in different processes. To avoid port conflicts, we + increment the port number each time we need to initialize a + new process group related to data parallelism. 
+ """ + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + import os + + # NOTE: Get port from envs directly when using torchrun + port = int(os.environ.get("MASTER_PORT", answer)) # type: ignore + return port + + +vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel +vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group +ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port diff --git a/vllm_ascend/patch/platform/patch_main/__init__.py b/vllm_ascend/patch/platform/patch_main/__init__.py index 2ed088b..d430dbe 100644 --- a/vllm_ascend/patch/platform/patch_main/__init__.py +++ b/vllm_ascend/patch/platform/patch_main/__init__.py @@ -13,4 +13,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# +import vllm_ascend.patch.platform.patch_main.patch_distributed # noqa F401 \ No newline at end of file diff --git a/vllm_ascend/patch/platform/patch_main/patch_distributed.py b/vllm_ascend/patch/platform/patch_main/patch_distributed.py new file mode 100644 index 0000000..83f4a53 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_main/patch_distributed.py @@ -0,0 +1,138 @@ +import torch +import vllm +import vllm.distributed +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import (Backend, PrefixStore, + _get_default_timeout, + is_nccl_available) +from torch.distributed.rendezvous import rendezvous +from vllm.config import ParallelConfig + + +def ascend_destroy_model_parallel(): + """Set the groups to none and destroy them.""" + from vllm.distributed.parallel_state import _DP, _PP, _TP + if _TP: + _TP.destroy() + _TP = None + + if _PP: + _PP.destroy() + _PP = None + + if _DP: + _DP.destroy() + _DP = None + from vllm.platforms import 
current_platform + current_platform.destroy_platform_model_parallel() + + +def ascend_stateless_init_torch_distributed_process_group( + host: str, port: int, rank: int, world_size: int, + backend: str) -> ProcessGroup: + """ + A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. The created ProcessGroup object can be used for + some operations such as `allreduce`, because it does not depend on the + global rank. However, some operations such as `broadcast` cannot be used + because it depends on the global rank. + + # TODO: ask for help from PyTorch team if we need the `broadcast` operation. + + This function is useful when we are not sure about the total number of + processes in the process group. For example, we may have process + 1, 2, ..., 8 who want to communicate, and process 9 might be the same + process as process 1, or it might be a different process; process 10 + might be the same process as process 5, or it might be a different process. + In this case, how can we reliably form a communication channel within + process 9 and 10, without affecting the communication channel within + process 1, 2, ..., 8? + + One possible solution is to figure out if process 9 and 10 are the same + as process 1 and 5 beforehand, and then form a communication channel + based on the information, adjusting the ranks and world_size etc. However, + figuring out the information is not always easy, and it will interfere + with the main communication channel. + + Our solution is to always form a communication channel with process 1, 2, + ..., 8, and then use this function to form another communication channel + with process 9 and 10. This way, regardless of whether process 9 and 10 + are the same as process 1 and 5, the main communication channel is + always formed with process 1, 2, ..., 8, and the additional communication + channel is formed with process 9 and 10. 
+ """ + init_method = f"tcp://{host}:{port}" + backend = Backend(backend) # it is basically string + timeout = _get_default_timeout(backend) + + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout)) + store.set_timeout(timeout) + + group_rank = rank + group_size = world_size + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + prefix_store = PrefixStore(init_method, store) + + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + from vllm.platforms import current_platform + if backend == "gloo": + from torch.distributed.distributed_c10d import ProcessGroupGloo + backend_class = ProcessGroupGloo(prefix_store, + group_rank, + group_size, + timeout=timeout) + backend_type = ProcessGroup.BackendType.GLOO + device = torch.device("cpu") + elif backend == "nccl": + assert is_nccl_available() + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, + backend_options) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + elif current_platform.platform_has_backend_register(): + current_platform.platform_register_backend() + return pg + else: + raise RuntimeError(f"Unsupported torch distributed backend: {backend}") + + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + + return pg + + +def parallel_config_get_dp_port(self) -> int: + """ + We might need to initialize process groups in multiple + processes that is related to data parallelism, + e.g. both in the worker and in the engine, which + can live in different processes. 
To avoid port conflicts, we + increment the port number each time we need to initialize a + new process group related to data parallelism. + """ + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + import os + + # NOTE: Get port from envs directly when using torchrun + port = int(os.environ.get("MASTER_PORT", answer)) # type: ignore + return port + + +vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel +vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group +ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index fad885e..174f84b 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -160,6 +160,8 @@ class NPUPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): + if use_v1 and use_mla: + return "vllm_ascend.attention.mla_v1.AscendMLABackend" if use_v1: return "vllm_ascend.attention.attention_v1.AscendAttentionBackend" if use_mla: @@ -191,3 +193,30 @@ class NPUPlatform(Platform): model configuration. 
""" return True + + @classmethod + def destroy_platform_model_parallel(cls) -> None: + from vllm_ascend.distributed.parallel_state import \ + destory_ascend_model_parallel + destory_ascend_model_parallel() + + @classmethod + def platform_has_backend_register(cls) -> bool: + return True + + @classmethod + def platform_register_backend(cls, pg, prefix_store, group_rank, + group_size, backend_options, + timeout) -> None: + from torch.distributed import ProcessGroup, is_hccl_available + assert is_hccl_available() + import torch_npu # noqa + from torch_npu._C._distributed_c10d import ProcessGroupHCCL + backend_options = ProcessGroupHCCL.Options() + backend_options._timeout = timeout + backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, + backend_options) + device = torch.device("npu") + backend_class._set_sequence_number_for_group() + backend_type = ProcessGroup.BackendType.CUSTOM + pg._register_backend(device, backend_type, backend_class) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 96d3300..da5f8a9 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -33,9 +33,7 @@ from vllm.model_executor.layers.quantization import \ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) +from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs from .quantizer import AscendQuantizer @@ -171,12 +169,10 @@ class AscendLinearMethod(LinearMethodBase): output_size_per_partition, params_dtype) for weight_name, weight_param in weight_dict.items(): - layer.register_parameter( - weight_name, - ModelWeightParameter(data=weight_param, - input_dim=1, - 
output_dim=0, - weight_loader=weight_loader)) + param = torch.nn.Parameter(weight_param, requires_grad=False) + set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter(weight_name, param) + set_weight_attrs(param, extra_weight_attrs) pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) for pertensor_name, pertensor_param in pertensor_dict.items(): @@ -189,11 +185,10 @@ class AscendLinearMethod(LinearMethodBase): perchannel_dict = self.quant_method.get_perchannel_param( output_size_per_partition, params_dtype) for perchannel_name, perchannel_param in perchannel_dict.items(): - layer.register_parameter( - perchannel_name, - ChannelQuantScaleParameter(data=perchannel_param, - output_dim=0, - weight_loader=weight_loader)) + param = torch.nn.Parameter(perchannel_param, requires_grad=False) + set_weight_attrs(param, {"output_dim": 0}) + layer.register_parameter(perchannel_name, param) + set_weight_attrs(param, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): @@ -264,48 +259,6 @@ class AscendKVCacheMethod(BaseKVCacheMethod): seq_lens_tensor_cpu=seq_lens_tensor_cpu) -def fused_moe_perchannel_weight_loader(param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, shard_id: str, - expert_id: int) -> None: - - if shard_id not in ("w1", "w2", "w3"): - raise ValueError(f"shard_id must be ['w1','w2','w3'] but " - f"got {shard_id}.") - - # Fetch the dim to shard the parameter/loaded weight - # based on the shard id. This will be whatever - # dimension intermediate_size_per_partition is used. - SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} - - expert_data = param.data[expert_id] - tp_rank = get_tensor_model_parallel_rank() - - # is_transposed: if the dim to shard the weight - # should be flipped. 
Required by GPTQ, compressed-tensors - # should be whatever dimension intermediate_size_per_partition is - is_transposed = getattr(param, "is_transposed", False) - shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] - if is_transposed: - shard_dim = int(not shard_dim) - - if shard_id == "w2": - expert_data.copy_(loaded_weight) - elif shard_id in ("w1", "w3"): - shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) - # Narrow parameter and load. - # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) - # w3, up_proj: Load into second logical weight of w13. - else: - assert shard_id == "w3" - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - expert_data.copy_(loaded_weight) - - class AscendFusedMoEMethod(FusedMoEMethodBase): """FusedMoE method for Ascend quantization. @@ -341,9 +294,6 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) - # load `offset` weight in `fused_moe_perchannel_weight_loader`, the original weight load in vllm 0.7.3 could only load `scale` and `zero` - extra_weight_attrs.update( - {"weight_loader": fused_moe_perchannel_weight_loader}) dynamic_quant_param = self.quant_method.get_dynamic_quant_param( num_experts, intermediate_size_per_partition, hidden_size, params_dtype) @@ -360,15 +310,19 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): top_k: int, router_logits: torch.Tensor, renormalize: bool, + global_num_experts: int, + expert_map: torch.Tensor, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + is_prefill: bool = True, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None ) -> torch.Tensor: return self.quant_method.apply(layer, x, use_grouped_topk, top_k, 
router_logits, renormalize, topk_group, - num_expert_group, + num_expert_group, global_num_experts, + expert_map, is_prefill, custom_routing_function, scoring_func, e_score_correction_bias) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index dfd0f68..54279d0 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -16,6 +16,8 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm-project/vllm/vllm/worker/worker.py # +import os + import torch import torch_npu # noqa: F401 from packaging.version import Version @@ -23,6 +25,8 @@ from vllm.logger import logger import vllm_ascend.envs as envs +VLLM_ENABLE_GRAPH_MODE = os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0') + def try_register_lib(lib_name: str, lib_info: str = ""): import importlib diff --git a/vllm_ascend/worker/__init__.py b/vllm_ascend/worker/__init__.py index e69de29..ee59a05 100644 --- a/vllm_ascend/worker/__init__.py +++ b/vllm_ascend/worker/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import vllm_ascend.worker.cache_engine # noqa \ No newline at end of file diff --git a/vllm_ascend/worker/cache_engine.py b/vllm_ascend/worker/cache_engine.py new file mode 100644 index 0000000..018a66b --- /dev/null +++ b/vllm_ascend/worker/cache_engine.py @@ -0,0 +1,69 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. 
All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/vllm/worker/model_runner.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Tuple + +import torch +from vllm.utils import is_pin_memory_available +from vllm.worker.cache_engine import CacheEngine + +from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE + + +def allocate_kv_cache( + self, + num_blocks: int, + device: str, +) -> List[Tuple]: + """Allocates KV cache on the specified device.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + pin_memory = is_pin_memory_available() if device == "cpu" else False + kv_cache: List[Tuple] = [] + + # Align entries so they are 256 byte aligned for better performance + # Primarily targets MLA as this typically only ends up having entries + # be 128 byte aligned. + alloc_shape = kv_cache_shape + + for _ in range(self.num_attention_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. 
+ layer_kv_cache_nope = torch.zeros( + alloc_shape[:-1] + + (self.model_config.hf_text_config.kv_lora_rank, ), + dtype=self.dtype, + pin_memory=pin_memory, + device=device) + layer_kv_cache_pe = torch.zeros( + alloc_shape[:-1] + + (self.model_config.hf_text_config.qk_rope_head_dim, ), + dtype=self.dtype, + pin_memory=pin_memory, + device=device) + + # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases + # when entry_shape is higher than 1D + kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe)) + return kv_cache + + +if VLLM_ENABLE_GRAPH_MODE == '1': + CacheEngine._allocate_kv_cache = allocate_kv_cache \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index 67a055f..d30070c 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -18,18 +18,21 @@ # import dataclasses +import itertools import weakref from contextlib import contextmanager from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union) +import numpy as np import torch import torch.nn as nn import torch_npu +import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.utils import CommonAttentionState -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context @@ -53,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists, - is_pin_memory_available) + is_pin_memory_available, supports_dynamo) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, 
ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -72,6 +75,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU") +ENCODER_NUM = 0 @dataclass(frozen=True) @@ -526,6 +530,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): seq_lens = [] max_decode_seq_len = 0 + is_prompt = self.inter_data_list[0].is_prompt for inter_data in self.inter_data_list: seq_lens.extend(inter_data.seq_lens) if not inter_data.is_prompt: @@ -540,7 +545,26 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): for data in self.inter_data_list } - input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens), + # Add graph_pad_size here + if self.runner.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + graph_pad_size = self.runner.scheduler_config.max_num_seqs - len( + seq_lens) + else: + graph_pad_size = -1 + + #print(f"before tensor input_tokens: {input_tokens}") + #print(f"before tensor input_positions: {input_positions}") + #print(f"before list seq_lens: {seq_lens}") + input_tokens = flatten_2d_lists(input_tokens) + input_positions = flatten_2d_lists(input_positions) + if graph_pad_size != -1 and not is_prompt: + input_tokens.extend(itertools.repeat(0, graph_pad_size)) + input_positions.extend( # type: ignore + itertools.repeat(0, graph_pad_size)) + seq_lens.extend(itertools.repeat(1, graph_pad_size)) + query_lens.extend(itertools.repeat(1, graph_pad_size)) + input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.runner.device) if mrope_input_positions is not None: @@ -548,13 +572,16 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): dtype=torch.long, device=self.runner.device) else: - input_positions_tensor = torch.tensor( - flatten_2d_lists(input_positions), - dtype=torch.long, - 
device=self.runner.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.runner.device) + #print(f"after tensor input_tokens_tensor: {input_tokens_tensor}") + #print(f"after tensor input_positions_tensor: {input_positions_tensor}") + #print(f"after list seq_lens: {seq_lens}") # Attention metadata. - attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens) + attn_metadata = self.attn_metadata_builder.build( + seq_lens, query_lens, graph_pad_size) # LoRA data. lora_requests = set() @@ -582,6 +609,13 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): ] multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + if self.runner.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + torch._dynamo.mark_static(input_tokens_tensor) + torch._dynamo.mark_static(input_positions_tensor) + torch._dynamo.mark_static(attn_metadata.block_tables) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + return self.model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, @@ -841,6 +875,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): self.in_profile_run = False + self.graph_block_tables = np.zeros( + (self.vllm_config.scheduler_config.max_num_seqs, + (model_config.max_model_len + self.block_size - 1) // + self.block_size), + dtype=np.int32) + # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal state. 
# However we must bypass attention selection altogether for some models @@ -930,6 +970,26 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): ) self.model = self.lora_manager.create_lora_manager(self.model) + # adapter torch compile with npu_backend + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + import torchair # type: ignore + from torchair import patch_for_hcom # type: ignore + + # 通信算子成图 + patch_for_hcom() + # 设置npu的config,如果不设置config,可以使用默认的,那可以设置npu_backend="npu" + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + torch.npu.set_compile_mode(jit_compile=False) + self.compile_model = torchair.inference.cache_compile( + self.model.forward, + dynamic=True, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + config=config, + ge_cache=False) + def save_sharded_state( self, path: str, @@ -1219,10 +1279,43 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): self.attn_state.begin_forward(model_input) assert model_input.attn_metadata is not None + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + torch._dynamo.mark_static(model_input.input_tokens) + torch._dynamo.mark_static(model_input.input_positions) + torch._dynamo.mark_static(model_input.attn_metadata.block_tables) + torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping) + torch._dynamo.mark_static( + model_input.attn_metadata.query_start_loc) + torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc) + for kv in kv_caches: + if isinstance(kv, tuple): + torch._dynamo.mark_static(kv[0]) + torch._dynamo.mark_static(kv[1]) + # TODO(andoorve): We can remove this once all # virtual engines share the same kv cache. 
virtual_engine = model_input.virtual_engine - model_executable = self.model + prefill_meta = model_input.attn_metadata.prefill_metadata + previous_hidden_states = kwargs.get("previous_hidden_states") + if prefill_meta is None and self.vllm_config.compilation_config.level > 0: + model_executable = self.compile_model + # Note: graph_batch_size value not same as GPU + graph_batch_size = model_input.input_tokens.shape[ # type: ignore + 0] # type: ignore + # Note: previous_hidden_states maybe None not same as GPU + if previous_hidden_states is not None: + previous_hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) + else: + model_executable = self.model # Receive KV cache in distributed KV cache transfer setting # In disagg prefill setting, it will also recv hidden states and bypass @@ -1248,8 +1341,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_inner_state else {} - previous_hidden_states = kwargs.get("previous_hidden_states") - model_kwargs = {} + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + model_kwargs = {"inputs_embeds": None} + else: + model_kwargs = {} if previous_hidden_states is not None: model_kwargs["previous_hidden_states"] = previous_hidden_states @@ -1273,44 +1369,30 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): **seqlen_agnostic_kwargs, **model_kwargs) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.record() - - # Sending KV cache in distributed KV cache transfer setting - # NOTE: the send operation is non-blocking - if self.need_send_kv(model_input, kv_caches): - 
get_kv_transfer_group().send_kv_caches_and_hidden_states( - # model_executable is used to know which layer the current - # worker is working on, so that we can send KV for only those - # layers. - model_executable, - model_input, - kv_caches, - hidden_or_intermediate_states, - ) - - # Compute the logits in the last pipeline stage. - if not get_pp_group().is_last_rank: - if (self.is_driver_worker - and hidden_or_intermediate_states is not None - and isinstance(hidden_or_intermediate_states, - IntermediateTensors) - and self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - hidden_or_intermediate_states.tensors["model_forward_time"] = ( - torch.tensor(model_forward_time + orig_model_forward_time)) - return hidden_or_intermediate_states - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) + # Compute the logits in the last pipeline stage. 
+ if not get_pp_group().is_last_rank: + if (self.is_driver_worker + and hidden_or_intermediate_states is not None + and isinstance(hidden_or_intermediate_states, + IntermediateTensors) + and self.observability_config is not None and + self.observability_config.collect_model_forward_time): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + hidden_or_intermediate_states.tensors[ + "model_forward_time"] = ( + torch.tensor(model_forward_time + + orig_model_forward_time)) + return hidden_or_intermediate_states + # TODO: remove the synchronize here + torch.npu.synchronize() + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) if not self.is_driver_worker: return [] @@ -1348,6 +1430,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): hidden_states = hidden_or_intermediate_states.index_select( 0, indices) output.prefill_hidden_states = hidden_or_intermediate_states + elif self.vllm_config.compilation_config.level == \ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + hidden_states = hidden_or_intermediate_states[:len(indices)] else: hidden_states = hidden_or_intermediate_states diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3425ce7..a67b3c4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -19,6 +19,7 @@ import gc import os +import weakref from typing import TYPE_CHECKING, Dict, List, Optional, Union import numpy as np @@ -47,8 +48,7 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.attention.attention import AttentionMaskBuilder -from 
vllm_ascend.attention.attention_v1 import (AscendAttentionState, - AscendMetadata) +from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.platform import NPUPlatform if TYPE_CHECKING: @@ -104,6 +104,27 @@ class NPUModelRunner: raise NotImplementedError( "Non-Attention backend is not supported by V1 NPUModelRunner.") + self.attn_backend = get_attn_backend( + self.head_size, + self.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + if self.attn_backend is None: + error_msg = ( + f"Error with get_att_backend: {self.head_size=}, " + f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{self.model_config.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 GPUModelRunner.") + + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self)) + # Multi-modal data support self.input_registry = INPUT_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY @@ -191,6 +212,12 @@ class NPUModelRunner: pin_memory=True) self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + self.seq_lens_cpu = torch.zeros(self.max_num_reqs, dtype=torch.int32, device="cpu", @@ -200,6 +227,8 @@ class NPUModelRunner: self.input_positions_cpu = torch.arange(0, self.max_num_tokens, device="cpu") + self.attn_mask = None + self.attn_state = None # NOTE: Pre-construct a mask matrix to improve the efficiency of # attention mask construction during inference. @@ -396,7 +425,11 @@ class NPUModelRunner: num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # Copy the blocks from CPU to NPU. 
+ modified_batch = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) + if modified_batch: + self.input_batch.refresh_sampling_metadata() + # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit(num_reqs) @@ -430,14 +463,13 @@ class NPUModelRunner: self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) positions = self.positions[:total_num_scheduled_tokens] + self.query_lens = torch.from_numpy(num_scheduled_tokens) self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) seq_lens = self.seq_lens_cpu[:num_reqs] - query_lens = torch.from_numpy(num_scheduled_tokens) - block_table_indices = (req_indices * self.max_num_blocks_per_req + positions_np // self.block_size) block_table_cpu = self.input_batch.block_table.get_cpu_tensor() @@ -446,8 +478,6 @@ class NPUModelRunner: np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) - slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( - self.device, non_blocking=True) attn_state = AscendAttentionState.ChunkedPrefill if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): @@ -461,15 +491,14 @@ class NPUModelRunner: query_lens=num_scheduled_tokens, position=positions, attn_state=attn_state) + self.attn_mask = attn_mask + self.attn_state = attn_state # type: ignore - attn_metadata = AscendMetadata( - seq_lens=query_lens, - context_lens=seq_lens, - slot_mapping=slot_mapping, - block_tables=( - self.input_batch.block_table.get_device_tensor()[:num_reqs]), - attn_mask=attn_mask, - attn_state=attn_state, + attn_metadata = self.attn_metadata_builder.build( # type: ignore + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + 
common_prefix_len=None, ) # Prepare input_ids @@ -804,6 +833,9 @@ class NPUModelRunner: # different GPUs, and `kv_cache_config.num_blocks` is set to # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks + # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may + # encounter OOM issue + num_blocks = num_blocks // 4 if isinstance(kv_cache_spec, FullAttentionSpec): kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index f1bf0b8..1479bdc 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -44,6 +44,7 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) from vllm_ascend.device_allocator.camem import CaMemAllocator +from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import try_register_lib, vllm_version_is from vllm_ascend.worker.model_runner import NPUModelRunner @@ -313,8 +314,14 @@ class NPUWorker(LocalOrDistributedWorkerBase): for ve in range(self.parallel_config.pipeline_parallel_size): num_layers = len(self.cache_engine[ve].gpu_cache) for i in range(num_layers): - torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i], - 2) + if torch.is_tensor(self.cache_engine[ve].gpu_cache[i]): + torch_npu.npu_format_cast( + self.cache_engine[ve].gpu_cache[i], 2) + else: + torch_npu.npu_format_cast( + self.cache_engine[ve].gpu_cache[i][0], 2) + torch_npu.npu_format_cast( + self.cache_engine[ve].gpu_cache[i][1], 2) self.gpu_cache = [ self.cache_engine[ve].gpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) @@ -495,6 +502,7 @@ class NPUWorker(LocalOrDistributedWorkerBase): backend: str = "hccl") -> None: """Initialize the distributed environment.""" parallel_config = self.parallel_config + additional_config = 
self.vllm_config.additional_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, @@ -502,6 +510,14 @@ class NPUWorker(LocalOrDistributedWorkerBase): ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + expert_tensor_parallel_size = 1 + if additional_config is not None and hasattr( + additional_config, "expert_tensor_parallel_size"): + expert_tensor_parallel_size = getattr( + additional_config, "expert_tensor_parallel_size") + init_ascend_model_parallel(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + expert_tensor_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 0cdb00a..7e98d4b 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -38,6 +38,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.worker_base import WorkerBase +from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import try_register_lib, vllm_version_is from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -209,6 +210,8 @@ class NPUWorker(WorkerBase): def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment.""" + additional_config = self.vllm_config.additional_config + parallel_config = self.vllm_config.parallel_config set_custom_all_reduce( not self.parallel_config.disable_custom_all_reduce) init_distributed_environment(self.parallel_config.world_size, @@ -217,6 +220,13 @@ class NPUWorker(WorkerBase): ensure_model_parallel_initialized( self.parallel_config.tensor_parallel_size, self.parallel_config.pipeline_parallel_size) + expert_tensor_parallel_size = 1 + if 
additional_config is not None and "expert_tensor_parallel_size" in additional_config: + expert_tensor_parallel_size = int( + additional_config["expert_tensor_parallel_size"]) + init_ascend_model_parallel(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + expert_tensor_parallel_size) ensure_kv_transfer_initialized(self.vllm_config) def _init_profiler(self):