diff --git a/pyproject.toml b/pyproject.toml
index 514b755..aebad2c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,12 +12,14 @@ requires = [
     "scipy",
     "setuptools>=64",
     "setuptools-scm>=8",
-    "torch-npu==2.5.1.post1.dev20250528",
+    "torch-npu==2.5.1.post1.dev20250619",
     "torch>=2.5.1",
     "torchvision<0.21.0",
     "wheel",
     "msgpack",
     "quart",
     "numba",
+    # Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
+    "transformers<4.53.0",
 ]
 build-backend = "setuptools.build_meta"
diff --git a/requirements.txt b/requirements.txt
index 375b554..6d84ec6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,3 +25,6 @@ numba
 --pre
 --extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
 torch-npu==2.5.1.post1.dev20250619
+
+# Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
+transformers<4.53.0
diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py
index bace69d..b49b4e4 100644
--- a/vllm_ascend/models/deepseek_dbo.py
+++ b/vllm_ascend/models/deepseek_dbo.py
@@ -25,7 +25,7 @@
 #
 # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -49,16 +49,18 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
     yarn_get_mscale  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
-                                                    DeepseekV2DecoderLayer,
-                                                    DeepseekV2MLAAttention)
+from vllm.model_executor.models.deepseek_v2 import (
+    DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention,
+    get_spec_layer_idx_from_weight_name)
 from vllm.model_executor.models.utils import (
-    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    PPMissingLayer, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
 import vllm_ascend.envs as envs_ascend
@@ -76,7 +78,7 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig,
                                               make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.utils import dispose_tensor, vllm_version_is
 
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
 
@@ -963,6 +965,107 @@ class CustomDeepseekDBOForCausalLM(DeepseekV2ForCausalLM):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    # NOTE: This `load_weights` is mainly copied from
+    # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
+    # to fix CI, and it is different from the implementation in main
+    # TODO: support eplb style load_weights
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        """"""
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    if vllm_version_is("0.9.1"):
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=expert_id)
+                    else:
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=expert_id,
+                                      return_success=False)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
     def forward(
         self,
         input_ids: torch.Tensor,
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 807c0a2..d7f68a1 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -25,7 +25,7 @@
 #
 # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch_npu
@@ -55,16 +55,18 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
     yarn_get_mscale  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
-                                                    DeepseekV2DecoderLayer,
-                                                    DeepseekV2MLAAttention)
+from vllm.model_executor.models.deepseek_v2 import (
+    DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention,
+    get_spec_layer_idx_from_weight_name)
 from vllm.model_executor.models.utils import (
-    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    PPMissingLayer, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.ascend_config import get_ascend_config
@@ -73,7 +75,7 @@ from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+                               npu_wait_tensor, vllm_version_is)
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -867,6 +869,107 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    # NOTE: This `load_weights` is mainly copied from
+    # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
+    # to fix CI, and it is different from the implementation in main
+    # TODO: support eplb style load_weights
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        """"""
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    if vllm_version_is("0.9.1"):
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=expert_id)
+                    else:
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=expert_id,
+                                      return_success=False)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
     def forward(
         self,
         input_ids: torch.Tensor,