# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""A MLU quant class."""
import functools
from collections import defaultdict
from typing import Dict, Any, List, Optional, Union

import numpy as np
import torch
import torch.distributed

from vllm.distributed import (
    get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size,
    get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size,
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
import vllm.envs as envs
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear,
    RowParallelLinear)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.logger import init_logger

logger = init_logger(__name__)


def default_act_range_value() -> Dict[str, Any]:
    """Return a fresh per-layer activation-range record.

    Used as the ``default_factory`` of the ``defaultdict`` that maps
    layer name -> calibration state, so every layer starts from the
    same empty template.
    """
    return {
        "x": None,
        "split": None,
        "is_linear": False,
        "is_qkv": False,
        "q_proj_size": 0,
        "num_kv_head_replicas": 1,
        "is_merge": False,
        "input_id": [],
        "self_rank": 0,
        "rank": None,
        "tensor_rank": None,
        "tp_world_size": None,
        "moe_tp_rank": None,
        "moe_tp_world_size": None,
        "moe_ep_rank": None,
        "moe_ep_world_size": None,
        "weight": None,
    }


def _str_to_torch_dtype(dtype: str) -> torch.dtype:
    """Map a stringified dtype (e.g. ``"torch.float16"``) back to a torch dtype.

    Only the last dot-separated component is used, so both ``"torch.bfloat16"``
    and ``"bfloat16"`` are accepted.
    """
    dtype = dtype.split(".")[-1]
    # STR_DTYPE_TO_TORCH_DTYPE dict does not have float16 type
    return STR_DTYPE_TO_TORCH_DTYPE[dtype] if dtype != "float16" else torch.float16


class ActRangeValue:
    """ ActRangeValue for v1 MsgpackEncoder and MsgpackDecoder.
    This is a *WorkAround*. The decode tensor can be wrong if we pass
    act range dict directly.
    NOTE: here, we transfer torch.Tensor to numpy ndarray because
    torch.Tensor may cause core dump.
    """

    def __init__(self):
        # Name of the layer this record belongs to.
        self.layer_name: str = ""
        # Observed per-channel activation max (numpy after serial()).
        self.x: Optional[np.ndarray] = None
        # "row" or "col" depending on the parallel-linear type.
        self.split: Optional[str] = None
        self.is_linear: bool = False
        self.is_qkv: bool = False
        self.q_proj_size: int = 0
        self.num_kv_head_replicas: int = 1
        self.is_merge: bool = False
        # Original torch dtype of input_id, kept so deserial() can restore it.
        self.input_id_dtype: Optional[str] = None
        self.input_id: Optional[List[np.ndarray]] = []
        self.self_rank: int = 0
        self.rank: Optional[int] = None
        self.tensor_rank: Optional[int] = None
        self.tp_world_size: Optional[int] = None
        self.moe_tp_rank: Optional[int] = None
        self.moe_tp_world_size: Optional[int] = None
        self.moe_ep_rank: Optional[int] = None
        self.moe_ep_world_size: Optional[int] = None
        self.weight: Optional[np.ndarray] = None
        # Original torch dtype of weight; also acts as a "weight present" flag.
        self.weight_dtype: Optional[str] = None

    @classmethod
    def serial(cls, layer_name: str,
               act_range: Dict[str, Any]) -> 'ActRangeValue':
        """Build an ActRangeValue from an act-range dict.

        All torch.Tensor fields are converted to numpy ndarrays (the
        original dtypes are remembered in ``*_dtype`` fields) so the
        object survives msgpack encoding without corruption.
        """
        instance = cls()
        instance.layer_name = layer_name
        instance.x = act_range.get("x")
        instance.split = act_range.get("split")
        instance.is_linear = act_range.get("is_linear", False)
        instance.is_qkv = act_range.get("is_qkv", False)
        instance.q_proj_size = act_range.get("q_proj_size", 0)
        instance.num_kv_head_replicas = act_range.get("num_kv_head_replicas", 1)
        instance.is_merge = act_range.get("is_merge", False)
        instance.input_id = act_range.get("input_id", [])
        instance.self_rank = act_range.get("self_rank", 0)
        instance.rank = act_range.get("rank")
        instance.tensor_rank = act_range.get("tensor_rank")
        instance.tp_world_size = act_range.get("tp_world_size")
        instance.moe_tp_rank = act_range.get("moe_tp_rank")
        instance.moe_tp_world_size = act_range.get("moe_tp_world_size")
        instance.moe_ep_rank = act_range.get("moe_ep_rank")
        instance.moe_ep_world_size = act_range.get("moe_ep_world_size")
        instance.weight = act_range.get("weight")
        if instance.x is not None:
            # x is already float32 (stat_tensor stores .float() maxima),
            # so a plain .numpy() is lossless.
            instance.x = instance.x.numpy()
        # input_id and weight are used for debug
        if isinstance(instance.input_id, torch.Tensor):
            instance.input_id_dtype = str(instance.input_id.dtype)
            instance.input_id = instance.input_id.float().numpy()
        else:
            input_id_np = []
            for input_id in instance.input_id:
                instance.input_id_dtype = str(input_id.dtype)
                input_id_np.append(input_id.float().numpy())
            instance.input_id = input_id_np
        if instance.weight is not None:
            instance.weight_dtype = str(instance.weight.dtype)
            instance.weight = instance.weight.float().numpy()
        return instance

    def deserial(self) -> Dict[str, Any]:
        """Convert back to an act-range dict, restoring torch.Tensors.

        Inverse of :meth:`serial`: numpy arrays become tensors again and
        input_id / weight are cast back to their recorded dtypes.
        """
        act_range = self.to_dict()
        if self.x is not None:
            act_range["x"] = torch.from_numpy(self.x)
        if self.input_id is not None:
            # BUGFIX: after serial() a single input_id is a numpy ndarray,
            # never a torch.Tensor. The previous isinstance(..., torch.Tensor)
            # check could not match, so a single array fell into the list
            # branch and was wrongly iterated row by row.
            if isinstance(self.input_id, np.ndarray):
                act_range["input_id"] = torch.from_numpy(self.input_id).to(
                    _str_to_torch_dtype(self.input_id_dtype))
            else:
                input_id_tensor = []
                for input_id in self.input_id:
                    input_id_tensor.append(torch.from_numpy(input_id).to(
                        _str_to_torch_dtype(self.input_id_dtype)))
                act_range["input_id"] = input_id_tensor
        # weight_dtype doubles as the "weight was serialized" flag.
        if self.weight_dtype is not None:
            act_range["weight"] = torch.from_numpy(self.weight).to(
                _str_to_torch_dtype(self.weight_dtype))
        return act_range

    def to_dict(self) -> Dict[str, Any]:
        """Return the raw field values as a plain act-range dict
        (no tensor conversion; layer_name and dtype bookkeeping excluded)."""
        return {
            "x": self.x,
            "split": self.split,
            "is_linear": self.is_linear,
            "is_qkv": self.is_qkv,
            "q_proj_size": self.q_proj_size,
            "num_kv_head_replicas": self.num_kv_head_replicas,
            "is_merge": self.is_merge,
            "input_id": self.input_id,
            "self_rank": self.self_rank,
            "rank": self.rank,
            "tensor_rank": self.tensor_rank,
            "tp_world_size": self.tp_world_size,
            "moe_tp_rank": self.moe_tp_rank,
            "moe_tp_world_size": self.moe_tp_world_size,
            "moe_ep_rank": self.moe_ep_rank,
            "moe_ep_world_size": self.moe_ep_world_size,
            "weight": self.weight,
        }

    def __repr__(self) -> str:
        return f"layer: {self.layer_name}, ActRangeValue({self.to_dict()})"


class MLUWorkerQuant(object):
    ''' mlu quant

    Calibration mixin: installs forward hooks on parallel-linear /
    embedding layers to collect per-channel activation maxima, then
    exports them as serializable ActRangeValue records.

    NOTE(review): relies on ``self.model_runner`` and ``self.rank``
    being provided by the class it is mixed into — confirm at call site.
    '''

    def stat_tensor(self, name: str, tensor: torch.Tensor, act_range,
                    key: str, dim: int) -> None:
        """Fold `tensor`'s per-channel abs-max into act_range[name][key].

        The tensor is flattened to (-1, hidden_dim) first; the running
        maximum is kept in float32.
        """
        logger.debug(f"name:{name}, key:{key}, dim:{dim}, tensor.shape:{tensor.shape}")
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs()
        coming_max = torch.max(tensor, dim=dim)[0].float()
        if act_range[name][key] is None:
            act_range[name][key] = coming_max
        else:
            # Element-wise running max across forward passes.
            act_range[name][key] = torch.max(act_range[name][key], coming_max)

    def stat_input_hook(self, m, x, y, name: str, act_range,
                        is_linear: bool, is_save_input_id: bool) -> None:
        """Forward hook: record input activation stats for module `name`.

        For QKV layers of the first decoder layer (".0." in name) the raw
        input is optionally snapshotted to CPU for debugging.
        """
        if isinstance(x, tuple):
            x = x[0]
        if isinstance(y, tuple):
            y = y[0]
        logger.debug(f"name:{name}, x.shape:{x.shape}, y.shape:{y.shape}, m.weight.shape:{m.weight.shape}")
        if is_linear:
            # dim=0 -> one max per hidden channel.
            self.stat_tensor(name, x, act_range, "x", 0)
            if act_range[name]["is_qkv"] and is_save_input_id and ".0." in name:
                x_cpu = x.clone().to("cpu")
                act_range[name]["input_id"].append(x_cpu)

    def setup_smooth_hook(self,
                          is_save_input_id: bool = False,
                          is_save_moe_info: bool = False) -> None:
        """Register calibration hooks on all supported layers.

        Also disables fused-kernel paths (bt_ffn, fused moe, fused mla qkv)
        so the hooks observe the unfused per-layer inputs.

        Args:
            is_save_input_id: also keep raw inputs of first-layer QKV (debug).
            is_save_moe_info: record rank / parallel-topology info per layer.
        """
        models = [self.model_runner.model]
        # Include the speculative-decoding drafter model when present.
        if hasattr(self.model_runner, "drafter") and self.model_runner.drafter is not None:
            models += [self.model_runner.drafter.model]
        self.act_range = defaultdict(default_act_range_value)
        self.hooks = []
        linear_class_list = (ColumnParallelLinear, MergedColumnParallelLinear,
                             QKVParallelLinear, RowParallelLinear)
        other_class_list = (VocabParallelEmbedding, ParallelLMHead)
        class_list = linear_class_list + other_class_list
        # BUGFIX: was `(RowParallelLinear)` (no comma) — the bare class, not
        # a tuple. isinstance() tolerated it, but make it a real tuple for
        # consistency with the sibling *_class_list definitions.
        row_class_list = (RowParallelLinear,)
        for model in models:
            for name, m in model.named_modules():
                # Force unfused execution paths so hooks see real inputs.
                if isinstance(m, FeedForward):
                    m.use_bt_ffn = False
                if isinstance(m, SparseMoeMlp):
                    m.is_use_fused_moe = False
                if isinstance(m, DeepseekV2MLAAttention):
                    m.use_fused_mla_qkv = False
                if isinstance(m, class_list):
                    is_linear = True if isinstance(m, linear_class_list) else False
                    split_type = "row" if isinstance(m, row_class_list) else "col"
                    self.act_range[name]["split"] = split_type
                    self.act_range[name]["is_linear"] = is_linear
                    if isinstance(m, QKVParallelLinear):
                        self.act_range[name]["is_qkv"] = True
                        self.act_range[name]["q_proj_size"] = m.num_heads * m.head_size
                        self.act_range[name]["num_kv_head_replicas"] = m.num_kv_head_replicas
                    self.act_range[name]["is_merge"] = isinstance(m, MergedColumnParallelLinear)
                    if is_save_moe_info:
                        self.act_range[name]["rank"] = torch.distributed.get_rank()
                        self.act_range[name]["tensor_rank"] = get_tensor_model_parallel_rank()
                        self.act_range[name]["tp_world_size"] = get_tensor_model_parallel_world_size()
                        self.act_range[name]["moe_tp_rank"] = get_moe_tensor_parallel_rank()
                        self.act_range[name]["moe_tp_world_size"] = get_moe_tensor_parallel_world_size()
                        self.act_range[name]["moe_ep_rank"] = get_moe_expert_parallel_rank()
                        self.act_range[name]["moe_ep_world_size"] = get_moe_expert_parallel_world_size()
                        if ".expert." in name:
                            self.act_range[name]["weight"] = m.weight
                    logger.info(f"rank:{self.rank}, add hook to {name}, is_linear:{is_linear}, split_type:{split_type}")
                    self.hooks.append(m.register_forward_hook(
                        functools.partial(self.stat_input_hook,
                                          name=name,
                                          act_range=self.act_range,
                                          is_linear=is_linear,
                                          is_save_input_id=is_save_input_id)))

    def remove_hooks(self) -> None:
        """Detach all forward hooks installed by setup_smooth_hook."""
        for h in self.hooks:
            h.remove()

    def get_act_range(self) -> List[ActRangeValue]:
        """Move collected stats to CPU and serialize them.

        Returns one ActRangeValue per hooked layer, safe to ship through
        the v1 msgpack encoder/decoder.
        """
        act_range = defaultdict(default_act_range_value)
        for layer_name, layer_range in self.act_range.items():
            for tensor_key, tensor_value in layer_range.items():
                if isinstance(tensor_value, torch.Tensor):
                    act_range[layer_name][tensor_key] = tensor_value.to("cpu")
                elif tensor_key == "input_id" and isinstance(tensor_value, list):
                    input_id_len = len(tensor_value)
                    for i in range(input_id_len):
                        if isinstance(tensor_value[i], torch.Tensor):
                            act_range[layer_name][tensor_key].append(tensor_value[i].to("cpu"))
                        else:
                            act_range[layer_name][tensor_key].append(tensor_value[i])
                else:
                    act_range[layer_name][tensor_key] = tensor_value
        serialization_result = []
        for layer_name, layer_range in act_range.items():
            serialization_result.append(ActRangeValue.serial(layer_name, layer_range))
        return serialization_result

    @torch.no_grad()
    def get_named_parameters(self) -> Dict[str, torch.Tensor]:
        """Return all model parameters copied to CPU, keyed by name."""
        name_parameters = {}
        for name, param in self.model_runner.model.named_parameters():
            name_parameters[name] = param.to("cpu")
        return name_parameters