init
0
vllm_vacc/vllm/model_executor/__init__.py
Normal file
Binary file not shown.
38
vllm_vacc/vllm/model_executor/custom_op.py
Normal file
@@ -0,0 +1,38 @@
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


class CustomOp(nn.Module):

    def forward_vacc(self, *args, **kwargs):
        raise NotImplementedError

    def dispatch_forward(self):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.

        enabled = self.enabled()
        logger.debug("custom op %s %s", self.__class__.__name__,
                     "enabled" if enabled else "disabled")

        if not enabled:
            return self.forward_native

        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_hpu():
            return self.forward_hpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_vacc():
            return self.forward
        else:
            return self.forward_cuda
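For context, a minimal sketch of how a backend-specific op would plug into this dispatcher. The `enabled()` hook and `forward_native()` fallback are assumed here to follow upstream vLLM's CustomOp contract, and `GeluAndMul` is a hypothetical example subclass, not part of this commit:

import torch
import torch.nn.functional as F

class GeluAndMul(CustomOp):  # hypothetical example subclass

    @classmethod
    def enabled(cls) -> bool:
        # Assumed hook: custom kernels are switched on per build.
        return True

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Portable fallback used when the op is disabled.
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d]) * x[..., d:]

op = GeluAndMul()
fwd = op.dispatch_forward()  # on a vacc build this resolves to forward()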
0
vllm_vacc/vllm/model_executor/layers/__init__.py
Normal file
Binary file not shown.
8
vllm_vacc/vllm/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,8 @@
import torch


def SiluAndMul_forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
    return torch.vacc.swiglu(x)


def QuickGELU_forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    return x * torch.sigmoid(1.702 * x)
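As a sanity reference, assuming `torch.vacc.swiglu` follows the usual convention of vLLM's SiluAndMul (gate half of the last dimension first, then the up half), a pure-PyTorch equivalent would be:

import torch
import torch.nn.functional as F

def swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: SiLU(gate) * up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]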
Binary file not shown.
82
vllm_vacc/vllm/model_executor/layers/fused_moe/fused_moe.py
Normal file
@@ -0,0 +1,82 @@
import torch
from typing import Optional, Tuple


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    dtype = hidden_states.dtype

    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, -1)
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)
    return topk_weights, selected_experts


def grouped_topk_with_itype(hidden_states: torch.Tensor,
                            gating_output: torch.Tensor,
                            topk: int,
                            renormalize: bool,
                            num_expert_group: int = 0,
                            topk_group: int = 0,
                            scoring_func: str = "softmax",
                            e_score_correction_bias: Optional[torch.Tensor] = None):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    try:
        from torch_vacc.vacc.custom_ops import fused_moe_preprocess
        return fused_moe_preprocess(gating_output, e_score_correction_bias)
    except Exception as e:
        print(f"Fused grouped topk failed; falling back to the unfused "
              f"implementation: {e}")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights.
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (scores.view(num_token, num_expert_group,
                                    -1).topk(2, dim=-1)[0].sum(dim=-1))
    else:
        group_scores = scores.view(num_token, num_expert_group,
                                   -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(),
                                    float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights.
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores,
                                            k=topk,
                                            dim=-1,
                                            sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int32)
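A small sketch of how the routing helper above behaves on dummy tensors (shapes are illustrative):

import torch

# 4 tokens, hidden size 8, 6 experts, route each token to its top-2.
hidden = torch.randn(4, 8, dtype=torch.bfloat16)
logits = torch.randn(4, 6, dtype=torch.bfloat16)

weights, experts = fused_topk(hidden, logits, topk=2, renormalize=True)
print(weights.shape, experts.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
print(weights.sum(dim=-1))           # ~1.0 per token after renormalization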
715
vllm_vacc/vllm/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,715 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op

if current_platform.is_cuda_alike():
    from .fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore
if current_platform.is_tpu():
    # The iterative MoE implementation is used until moe_pallas is fixed.
    from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
    fused_moe_pallas = None  # type: ignore

logger = init_logger(__name__)

class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"


def FusedMoE_init_(
    self,
    num_experts: int,  # Global number of experts
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = False,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    tp_size: Optional[int] = None,
    ep_size: Optional[int] = None,
    dp_size: Optional[int] = None,
    prefix: str = "",
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    num_redundant_experts: int = 0,
    has_bias: bool = False,
    is_sequence_parallel: bool = False,
):
    from vllm.model_executor.layers.fused_moe import FusedMoE
    super(FusedMoE, self).__init__()

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.params_dtype = params_dtype

    vllm_config = get_current_vllm_config()

    # FIXME (varun): We should have a better way of inferring the activation
    # datatype. This works for now as the tensor datatype entering the MoE
    # operation is typically unquantized (i.e. float16/bfloat16).
    if vllm_config.model_config is not None:
        moe_in_dtype = vllm_config.model_config.dtype
    else:
        # TODO (bnell): This is a hack to get test_mixtral_moe to work
        # since model_config is not set in the pytest test.
        moe_in_dtype = params_dtype

    tp_size_ = (tp_size if tp_size is not None else
                get_tensor_model_parallel_world_size())
    dp_size_ = (dp_size
                if dp_size is not None else get_dp_group().world_size)

    self.moe_parallel_config: FusedMoEParallelConfig = (
        FusedMoEParallelConfig.make(
            tp_size_=tp_size_,
            dp_size_=dp_size_,
            vllm_parallel_config=vllm_config.parallel_config))

    self.global_num_experts = num_experts + num_redundant_experts

    # For smuggling this layer into the fused moe custom op
    compilation_config = vllm_config.compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError("Duplicate layer name: {}".format(prefix))
    compilation_config.static_forward_context[prefix] = self
    self.layer_name = prefix

    self.enable_eplb = enable_eplb
    self.expert_load_view: Optional[torch.Tensor] = None
    self.logical_to_physical_map: Optional[torch.Tensor] = None
    self.logical_replica_count: Optional[torch.Tensor] = None

    # Determine expert maps
    if self.use_ep:
        if self.enable_eplb:
            assert self.global_num_experts % self.ep_size == 0, \
                "EPLB currently only supports even distribution of " \
                "experts across ranks."
        else:
            assert num_redundant_experts == 0, \
                "Redundant experts are only supported with EPLB."
        self.local_num_experts, self.expert_map = determine_expert_map(
            ep_size=self.ep_size,
            ep_rank=self.ep_rank,
            global_num_experts=self.global_num_experts)
    else:
        self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                   None)

    self.top_k = top_k

    assert intermediate_size % self.tp_size == 0
    self.hidden_size = hidden_size
    self.intermediate_size_per_partition = intermediate_size // self.tp_size
    self.reduce_results = reduce_results
    self.renormalize = renormalize
    self.use_grouped_topk = use_grouped_topk
    if self.use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
    self.num_expert_group = num_expert_group
    self.topk_group = topk_group
    self.custom_routing_function = custom_routing_function
    self.scoring_func = scoring_func
    self.e_score_correction_bias = e_score_correction_bias
    self.apply_router_weight_on_input = apply_router_weight_on_input
    self.activation = activation

    if self.scoring_func != "softmax" and not self.use_grouped_topk:
        raise ValueError("Only softmax scoring function is supported for "
                         "non-grouped topk.")

    moe = FusedMoEConfig(
        num_experts=self.global_num_experts,
        experts_per_token=top_k,
        hidden_dim=hidden_size,
        num_local_experts=self.local_num_experts,
        moe_parallel_config=self.moe_parallel_config,
        in_dtype=moe_in_dtype,
        max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
        has_bias=has_bias,
    )
    self.moe_config = moe
    self.quant_config = quant_config

    # Note: get_quant_method will look at the layer's local_num_experts
    # for heuristic purposes, so it must be initialized first.
    quant_method: Optional[QuantizeMethodBase] = None
    quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
                    else quant_config.get_quant_method(self, prefix))

    assert quant_method is not None
    assert isinstance(quant_method, FusedMoEMethodBase)
    self.quant_method = quant_method

    if self.enable_eplb:
        from vllm.model_executor.layers.quantization.fp8 import (
            Fp8MoEMethod)
        if not isinstance(quant_method, Fp8MoEMethod):
            # TODO: Add support for additional quantization methods.
            # The implementation for other quantization methods does not
            # contain essential differences, but the current quant API
            # design causes duplicated work when extending to new
            # quantization methods, so I'm leaving it for now.
            # If you plan to add support for more quantization methods,
            # please refer to the implementation in `Fp8MoEMethod`.
            raise NotImplementedError("EPLB is only supported for FP8 "
                                      "quantization for now.")

    moe_quant_params = {
        "num_experts": self.local_num_experts,
        "hidden_size": hidden_size,
        "intermediate_size_per_partition":
        self.intermediate_size_per_partition,
        "params_dtype": params_dtype,
        "weight_loader": self.weight_loader,
    }
    # need full intermediate size pre-sharding for WNA16 act order
    if (self.quant_method.__class__.__name__
            in ("GPTQMarlinMoEMethod",
                "CompressedTensorsWNA16MarlinMoEMethod",
                "CompressedTensorsWNA16MoEMethod")):
        moe_quant_params["intermediate_size_full"] = intermediate_size

    self.quant_method.create_weights(layer=self, **moe_quant_params)

    # Default scale factors; overridden below when the quant method
    # provides its own.
    self.scale_n = 1
    self.scale_k = 1
    self.scale_n_prefill = 1
    if hasattr(self.quant_method, "scale_n") and hasattr(
            self.quant_method, "scale_k"):
        self.scale_n = self.quant_method.scale_n
        self.scale_k = self.quant_method.scale_k
    if hasattr(self.quant_method, "scale_n_prefill"):
        self.scale_n_prefill = self.quant_method.scale_n_prefill

    # Chunked all2all staging tensor
    self.batched_hidden_states: Optional[torch.Tensor] = None
    self.batched_router_logits: Optional[torch.Tensor] = None
    if (self.moe_parallel_config.use_pplx_kernels
            or self.moe_parallel_config.use_deepep_ll_kernels):
        self.batched_hidden_states = torch.zeros(
            (moe.max_num_tokens, self.hidden_size),
            dtype=moe.in_dtype,
            device=torch.cuda.current_device())

        # Note: here we use `num_experts`, which is the logical expert count.
        self.batched_router_logits = torch.zeros(
            (moe.max_num_tokens, num_experts),
            dtype=moe.in_dtype,
            device=torch.cuda.current_device())

class FusedMoE(torch.nn.Module):

    def _load_w13(self,
                  expert_data: torch.Tensor,
                  shard_dim: int,
                  shard_id: str,
                  loaded_weight: torch.Tensor,
                  tp_rank: int,
                  load_full: bool = False,
                  expert_id=0):
        if self.scale_n > 1 and len(loaded_weight.shape) == 2 and \
                torch.finfo(loaded_weight.dtype).bits > 8:
            n_v, k_v = loaded_weight.shape
            loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v)  # [1, n, 1, k]
            if self.scale_n_prefill != self.scale_n and hasattr(
                    self, 'w13_weight_scale_inv_prefill'):
                loaded_weight0 = loaded_weight.repeat(
                    self.scale_n_prefill, 1, self.scale_k, 1).permute(
                        1, 0, 3, 2).reshape(
                            [n_v * self.scale_n_prefill, k_v * self.scale_k])
                shard_size = self.w13_weight_scale_inv_prefill.data[
                    expert_id].shape[shard_dim] // 2
                loaded_weight0 = loaded_weight0.narrow(
                    shard_dim, shard_size * tp_rank, shard_size)

                if shard_id == "w1":
                    self.w13_weight_scale_inv_prefill.data[
                        expert_id, :loaded_weight0.shape[0],
                        :loaded_weight0.shape[1]] = loaded_weight0
                elif shard_id == "w3":
                    self.w13_weight_scale_inv_prefill.data[
                        expert_id, -loaded_weight0.shape[0]:,
                        -loaded_weight0.shape[1]:] = loaded_weight0
                else:
                    raise ValueError(f"Invalid shard_id: {shard_id}")

            loaded_weight = loaded_weight.repeat(
                self.scale_n, 1, self.scale_k, 1).permute(1, 0, 3, 2).reshape(
                    [n_v * self.scale_n, k_v * self.scale_k])

        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)

        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False,
                 expert_id=0):
        if self.scale_n > 1 and len(loaded_weight.shape) == 2 and \
                torch.finfo(loaded_weight.dtype).bits > 8:
            n_v, k_v = loaded_weight.shape
            loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v)  # [1, n, 1, k]
            if self.scale_n_prefill != self.scale_n and hasattr(
                    self, 'w2_weight_scale_inv_prefill'):
                loaded_weight0 = loaded_weight.repeat(
                    self.scale_k, 1, self.scale_n_prefill, 1).permute(
                        1, 0, 3, 2).reshape(
                            [n_v * self.scale_k, k_v * self.scale_n_prefill])
                shard_size = self.w2_weight_scale_inv_prefill.data[
                    expert_id].shape[shard_dim]
                if not load_full:
                    loaded_weight0 = loaded_weight0.narrow(
                        shard_dim, shard_size * tp_rank, shard_size)
                self.w2_weight_scale_inv_prefill.data[expert_id] = loaded_weight0

            loaded_weight = loaded_weight.repeat(
                self.scale_k, 1, self.scale_n, 1).permute(1, 0, 3, 2).reshape(
                    [n_v * self.scale_k, k_v * self.scale_n])

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      weight_name: str,
                      shard_id: str,
                      expert_id: int,
                      return_success: bool = True) -> Optional[bool]:

        if self.quant_config and self.quant_config.get_name() == "mxfp4":
            # (FIXME) for gpt-oss all experts are combined
            if "bias" in weight_name:
                dim1 = loaded_weight.shape[1]
                param.data[:, :dim1].copy_(loaded_weight)
            else:
                dim1 = loaded_weight.shape[1]
                dim2 = loaded_weight.shape[2]
                param.data[:, :dim1, :dim2].copy_(loaded_weight)
            return True if return_success else None

        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return False if return_success else None

        quant_method_name = self.quant_method.__class__.__name__
        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        if quant_method_name in ("CompressedTensorsWNA16MarlinMoEMethod",
                                 "CompressedTensorsWNA16MoEMethod"):
            loaded_weight = loaded_weight.t().contiguous()

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size_per_partition is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()
            param.data.copy_(loaded_weight)
            return True if return_success else None

        # Case for BitsAndBytes
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        if use_bitsandbytes_4bit:
            shard_dim = 0

            expert_data = param.data[expert_id]
            if shard_id == "w2":
                expert_data.copy_(loaded_weight)
            elif shard_id in ("w1", "w3"):
                # BNB inflight quantization has already sharded the weights
                full_load = True
                self._load_w13(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full=full_load,
                )
            return True if return_success else None

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size_per_partition is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            shard_dim = int(not shard_dim)

        full_load = len(loaded_weight.shape) == 3
        if full_load:
            shard_dim += 1

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if shard_id in ["w1", "w3"]:
                final_shape[1] *= 2
            final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        expert_data = param.data if full_load else param.data[expert_id]
        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")

            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return True if return_success else None

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(shard_dim=0,
                             shard_id=shard_id,
                             loaded_weight=loaded_weight,
                             expert_data=expert_data,
                             tp_rank=self.tp_rank)
            return True if return_success else None

        # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
        if "ModelOpt" in quant_method_name:
            # Determine per-tensor weight scale patterns based on variant.
            # Use the dedicated method instead of brittle string matching.
            uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
            )

            # Call _load_per_tensor_weight_scale() to load per-tensor
            # (scalar) weight scales.
            # Input scales are always per-tensor.
            # Weight scales: FP4 uses "weight_scale_2" and FP8 uses
            # "weight_scale" for per-tensor scales.
            is_per_tensor = ("weight_scale_2" in weight_name
                             if uses_weight_scale_2 else "weight_scale"
                             in weight_name) or "input_scale" in weight_name
            if is_per_tensor:
                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
                return True if return_success else None

            # If the w13_weight_scales are combined into a single
            # loaded_weight, call _load_combined_w13_weight_scale() to load
            # it. This is checked by comparing the hidden_out dims of the
            # loaded_weight and the param.
            if "w13_weight_scale" in weight_name:
                loaded_weight_hidden_out = loaded_weight.shape[-2]
                param_hidden_out = param.data.shape[-2] * self.tp_size
                if loaded_weight_hidden_out == param_hidden_out:
                    self._load_combined_w13_weight_scale(
                        shard_dim=shard_dim,
                        loaded_weight=loaded_weight,
                        param=param,
                        tp_rank=self.tp_rank,
                    )
                    return True if return_success else None

            # For other weights, call
            # _load_model_weight_or_group_weight_scale() to load them.
            if "weight" in weight_name:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
                return True if return_success else None

        # Case weight scales, zero_points and offset
        if ("scale" in weight_name or "zero" in weight_name
                or "offset" in weight_name):
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
            elif quant_method in [
                    FusedMoeWeightScaleSupported.GROUP.value,
                    FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full_w2=getattr(param, "load_full_w2", False),
                    expert_id=expert_id)
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                WEIGHT_SCALE_SUPPORTED = [
                    e.value for e in FusedMoeWeightScaleSupported
                ]
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return True if return_success else None

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return True if return_success else None

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=self.tp_rank)
            return True if return_success else None

        return False if return_success else None

    def _load_model_weight_or_group_weight_scale(self,
                                                 shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.Tensor,
                                                 tp_rank: int,
                                                 load_full_w2: bool = False,
                                                 expert_id: int = 0):
        """
        Load grouped weight scales for group quantization or model weights.

        :param shard_dim: dimension to shard
        :param expert_data: parameter for a particular expert
        :param shard_id: either w1, w2, or w3
        :param loaded_weight: checkpoint weight to load into the param
        :param tp_rank: tensor parallel rank
        :param load_full_w2: whether or not the w2 loaded should be sharded
        """
        if shard_id == "w2":
            # In the case where we have actorder/g_idx, we do not partition
            # the w2 scales, as indicated by the `load_full` argument, for
            # all tp cases.
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank,
                          load_full=load_full_w2,
                          expert_id=expert_id)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank,
                           expert_id=expert_id)

class UnquantizedFusedMoEMethod:

    @staticmethod
    def forward_vacc(
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:

        hidden_states = x
        orig_shape = hidden_states.shape
        hidden_size = hidden_states.shape[-1]
        num_tokens = hidden_states.shape[:-1].numel()
        dtype = hidden_states.dtype
        intermediate_size = layer.w2_weight.shape[-1]
        gating_output = router_logits

        hidden_states = hidden_states.view(num_tokens, hidden_size)
        gating_output = gating_output.view(num_tokens, global_num_experts)
        topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1,
                                                           keepdim=True)
        topk_weights = topk_weights.to(dtype)

        if expert_map is not None:
            topk_ids = expert_map[topk_ids]

        final_hidden_states = torch.zeros_like(hidden_states)
        sel_experts = topk_ids.shape[1]
        if hidden_states.shape[0] == 1:
            # Single-token path: walk the token's selected experts.
            for slot in range(sel_experts):
                expert_idx = topk_ids[0][slot]
                expert_w1 = layer.w13_weight[expert_idx].contiguous()
                expert_w2 = layer.w2_weight[expert_idx].contiguous()

                expert_weights = topk_weights[0][slot].to(hidden_states.dtype)

                x = hidden_states
                x = F.linear(x, expert_w1)
                gate = F.silu(x[:, :intermediate_size])
                x = x[:, intermediate_size:] * gate
                x = F.linear(x, expert_w2)

                current_hidden_states = x * expert_weights
                current_hidden_states = current_hidden_states.to(x.dtype)
                final_hidden_states += current_hidden_states
        else:
            for expert_idx in range(global_num_experts):
                # topk_ids [tokens, experts] => sample: [10, 8]
                # expert_mask [tokens, experts] => sample: [10, 8]
                expert_mask = topk_ids == expert_idx

                idx = torch.where(expert_mask)[0]
                if idx.numel() == 0:
                    continue

                expert_w1 = layer.w13_weight[expert_idx].contiguous()
                expert_w2 = layer.w2_weight[expert_idx].contiguous()

                # [seq, experts]
                expert_weights = (
                    topk_weights.masked_select(expert_mask)
                    .unsqueeze(1)
                    .to(hidden_states.dtype)
                )

                x = hidden_states[idx]
                x = F.linear(x, expert_w1)
                gate = F.silu(x[:, :intermediate_size])
                x = x[:, intermediate_size:] * gate
                x = F.linear(x, expert_w2)

                current_hidden_states = x * expert_weights
                current_hidden_states = current_hidden_states.to(x.dtype)
                final_hidden_states.index_add_(0, idx, current_hidden_states)
        return final_hidden_states.view(orig_shape)  # type: ignore

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        custom_forward = self.forward
        if x.device.type == "vacc":
            custom_forward = UnquantizedFusedMoEMethod.forward_vacc

        return custom_forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input)
32
vllm_vacc/vllm/model_executor/layers/layernorm.py
Normal file
@@ -0,0 +1,32 @@
from typing import Optional, Tuple, Union

import torch


def RMSNorm_forward_vacc(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    hidden_size = x.shape[-1]
    if hidden_size != self.hidden_size:
        raise ValueError("Expected hidden_size to be "
                         f"{self.hidden_size}, but found: {hidden_size}")
    if self.variance_size_override is None:
        x_var = x
    else:
        if hidden_size < self.variance_size_override:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{self.variance_size_override}, but found: {hidden_size}")
        x_var = x[:, :, :self.variance_size_override]
    out = torch.vacc.fused_residual_rmsnorm(x_var, self.weight, residual,
                                            self.variance_epsilon, x_var,
                                            residual)
    return out
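For readers without the vacc runtime, a plain-PyTorch sketch of what a fused residual-add + RMSNorm computes, assuming `torch.vacc.fused_residual_rmsnorm` follows the usual convention of adding the residual first and writing results back through the tensors passed as its last two arguments:

import torch

def fused_residual_rmsnorm_reference(x, weight, residual, eps):
    # Residual add first (when a residual stream is threaded through),
    # then RMS-normalize in fp32 and scale by the learned weight.
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    out = (x.to(torch.float32) *
           torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual)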
465
vllm_vacc/vllm/model_executor/layers/linear.py
Normal file
@@ -0,0 +1,465 @@
import itertools
import math
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           BlockQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               WEIGHT_LOADER_V2_SUPPORTED,
                                               LinearBase,
                                               RowParallelLinear)

def ReplicatedLinear__init__(self,
                             input_size: int,
                             output_size: int,
                             bias: bool = True,
                             skip_bias_add: bool = False,
                             params_dtype: Optional[torch.dtype] = None,
                             quant_config: Optional[QuantizationConfig] = None,
                             prefix: str = ""):
    super(ReplicatedLinear, self).__init__(input_size,
                                           output_size,
                                           skip_bias_add,
                                           params_dtype,
                                           quant_config,
                                           prefix=prefix)

    # All the linear layers support a quant method.
    assert self.quant_method is not None

    # quant_block_k (128) is divided by scale_k; e.g. scale_k = 2 means the
    # effective quant_block_k is 64.
    self.scale_k = 1
    self.scale_k_slice = 1
    self.scale_n = 1
    self.scale_n_slice = 1
    if quant_config is not None and hasattr(
            quant_config, "weight_block_size"
    ) and quant_config.weight_block_size is not None:
        gcd_value = quant_config.weight_block_size[1]
        if input_size % quant_config.weight_block_size[1]:
            gcd_value = math.gcd(
                input_size % quant_config.weight_block_size[1],
                quant_config.weight_block_size[1])
        self.scale_k = (self.scale_k *
                        quant_config.weight_block_size[1] // gcd_value)
        self.scale_k_slice = input_size // gcd_value
        if output_size % quant_config.weight_block_size[0]:
            gcd_value = math.gcd(
                output_size % quant_config.weight_block_size[0],
                quant_config.weight_block_size[0])
        self.scale_n = (self.scale_n *
                        quant_config.weight_block_size[0] // gcd_value)
        self.scale_n_slice = output_size // gcd_value

    self.quant_method.create_weights(self,
                                     self.input_size, [self.output_size],
                                     self.input_size,
                                     self.output_size,
                                     self.params_dtype,
                                     scale_k=self.scale_k,
                                     scale_n=self.scale_n,
                                     weight_loader=self.weight_loader)

    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=self.params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)

def ReplicatedLinear_weight_loader(self, param: Parameter,
                                   loaded_weight: torch.Tensor):
    # If the weight on disk does not have a shape, give it one
    # (e.g. scales for AutoFp8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)

    if self.quant_method.__class__.__name__ in [
            'Fp8LinearMethod', 'Fp8MoEMethod'
    ] and torch.finfo(loaded_weight.dtype).bits > 8:
        if self.scale_k > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            # [1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
            loaded_weight = loaded_weight.expand(
                self.scale_k, loaded_weight.shape[1],
                loaded_weight.shape[2]).permute(1, 2, 0).reshape(
                    [loaded_weight.shape[1], -1])[:, :self.scale_k_slice]

        if self.scale_n > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
            loaded_weight = loaded_weight.expand(
                self.scale_n, loaded_weight.shape[1],
                loaded_weight.shape[2]).permute(1, 0, 2).reshape(
                    [-1, loaded_weight.shape[2]])[:self.scale_n_slice]

    assert param.size() == loaded_weight.size(), \
        f'{param.size()}, {loaded_weight.size()}'
    param.data.copy_(loaded_weight)

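To make the scale broadcast concrete, a standalone sketch of the k-direction expansion above on a dummy block-scale tensor (shapes are illustrative):

import torch

scale = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # [n=2, k=3] block scales
scale_k = 2                                                  # each k-block split in two

expanded = (scale.unsqueeze(0)          # [1, n, k]
            .expand(scale_k, 2, 3)      # [scale_k, n, k]
            .permute(1, 2, 0)           # [n, k, scale_k]
            .reshape(2, -1))            # [n, k*scale_k]
print(expanded)
# Each column scale is duplicated scale_k times along k:
# tensor([[0., 0., 1., 1., 2., 2.],
#         [3., 3., 4., 4., 5., 5.]])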
def refine_block(block_size: list[int],
                 weight_size: list[int],
                 dim: int = 0,
                 pingpong_size: float = 2.5 * 1024 * 1024,  # bytes
                 core_number: int = 4,
                 data_type: int = 2,  # bfloat16
                 max_iter_number: int = 2):
    '''
    For an uneven split across cores, each core must hold <= 2.5 MB to
    allow ping-pong buffering. The per-core imbalance is
    block_size[dim] * weight_size[1-dim]; shrinking block_size narrows the
    gap between cores until the largest core no longer exceeds the budget.
    If the even split already exceeds the budget, or nothing exceeds it,
    no adjustment is needed.
    '''
    if dim < 0:
        dim = 2 + dim

    pingpong_size = pingpong_size / data_type  # number of elements

    block_size_refine = block_size[dim]
    all_block_number = weight_size[dim] // block_size_refine

    if all_block_number % core_number == 0:
        # Evenly divisible: no adjustment needed either way.
        return block_size_refine

    block_number_tiny = all_block_number // core_number
    block_number_big = all_block_number // core_number + 1
    if block_number_tiny * block_size_refine * weight_size[1 - dim] >= pingpong_size or \
            block_number_big * block_size_refine * weight_size[1 - dim] <= pingpong_size:
        # Either the small cores already exceed the budget (nothing we can
        # do), or the big cores fit (nothing to adjust).
        return block_size_refine

    all_block_number_tmp = all_block_number
    block_size_refine_tmp = block_size_refine
    for iter_index in range(max_iter_number):
        all_block_number_tmp = all_block_number_tmp * 2
        block_size_refine_tmp = block_size_refine_tmp // 2
        if all_block_number_tmp % core_number == 0:
            block_number_tiny = all_block_number // core_number
            if block_number_tiny * block_size_refine_tmp * weight_size[1 - dim] <= pingpong_size:
                return block_size_refine_tmp
            else:
                # Still over budget even with an even split: give up.
                return block_size_refine
        else:
            block_number_big = all_block_number_tmp // core_number + 1
            if block_number_big * block_size_refine_tmp * weight_size[1 - dim] <= pingpong_size:
                return block_size_refine_tmp

    return block_size_refine

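A quick numeric check of the heuristic (values are illustrative): with a 128-wide quant block on a 640 x 7168 bf16 weight over 4 cores, the 5 blocks split unevenly and the larger cores exceed the ~1.31M-element ping-pong budget, so the block is halved twice:

# 5 blocks of 128 cannot be split evenly over 4 cores; the big cores hold
# 2 * 128 * 7168 = 1,835,008 elements > 1,310,720, so the block is refined.
print(refine_block([128, 128], [640, 7168], dim=0))  # -> 32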
def ColumnParallelLinear__init__(self,
                                 input_size: int,
                                 output_size: int,
                                 bias: bool = True,
                                 gather_output: bool = False,
                                 skip_bias_add: bool = False,
                                 params_dtype: Optional[torch.dtype] = None,
                                 quant_config: Optional[QuantizationConfig] = None,
                                 output_sizes: Optional[List[int]] = None,
                                 prefix: str = "",
                                 *,
                                 return_bias: bool = True,
                                 disable_tp: bool = False):
    # Divide the weight matrix along the last dimension.
    self.tp_rank = (get_tensor_model_parallel_rank()
                    if not disable_tp else 0)
    self.tp_size = (get_tensor_model_parallel_world_size()
                    if not disable_tp else 1)
    self.input_size_per_partition = input_size
    self.output_size_per_partition = divide(output_size, self.tp_size)
    self.output_partition_sizes = [self.output_size_per_partition]
    # If QKV or MergedColumn, use output size of each partition.
    if hasattr(self, "output_sizes"):
        self.output_partition_sizes = [
            divide(output_size, self.tp_size)
            for output_size in self.output_sizes
        ]
    super(ColumnParallelLinear, self).__init__(input_size,
                                               output_size,
                                               skip_bias_add,
                                               params_dtype,
                                               quant_config,
                                               prefix,
                                               return_bias=return_bias,
                                               disable_tp=disable_tp)

    self.gather_output = gather_output

    if output_sizes is None:
        output_sizes = [output_size]

    self.scale_n = 1
    if quant_config is not None and hasattr(
            quant_config, "weight_block_size"
    ) and quant_config.weight_block_size is not None:
        gcd_value = quant_config.weight_block_size[0]

        if hasattr(self, "output_sizes"):
            # For merged ColumnParallelLinear layers, the GCD has to be
            # computed from the shape of every partial linear layer.
            output_size_no_merge = self.output_partition_sizes
            block_values = [
                o % quant_config.weight_block_size[0]
                for o in output_size_no_merge
            ]
            is_gcd_recompute = sum(block_values)

            if is_gcd_recompute:
                block_values.append(quant_config.weight_block_size[0])
                gcd_value = math.gcd(*block_values)
            # Notice:
            # The flow for unaligned partial weights may still need
            # validation. For DeepSeek, the merged column linears in
            # MLP & MoE all have identically shaped partial weights.
            # QWen3 has QKV column linears with differently shaped partial
            # weights, but its current partitioning scheme is insensitive
            # to gcd_value, so no recompute is needed and this branch is
            # not taken there.
            if hasattr(self, "output_sizes") and len(
                    output_size_no_merge
            ) == 2 and output_size_no_merge[0] == output_size_no_merge[1]:
                # Only refine the MLP w13 block.
                gcd_value = refine_block(
                    [gcd_value, quant_config.weight_block_size[1]],
                    [output_size_no_merge[0], input_size])
            self.scale_n = (self.scale_n *
                            quant_config.weight_block_size[0] // gcd_value)
        else:
            # For a non-merged ColumnParallelLinear, compute the GCD from
            # the current shape only.
            output_size_no_merge = self.output_size_per_partition
            is_gcd_recompute = output_size_no_merge % quant_config.weight_block_size[0]
            if is_gcd_recompute:
                gcd_value = math.gcd(
                    output_size_no_merge % quant_config.weight_block_size[0],
                    quant_config.weight_block_size[0])
            self.scale_n = (self.scale_n *
                            quant_config.weight_block_size[0] // gcd_value)

    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size,
        output_partition_sizes=self.output_partition_sizes,
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        scale_n=self.scale_n,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size_per_partition,
                        dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)

    self.update_param_tp_status()

def ColumnParallelLinear_weight_loader_v2(self, param: Parameter,
                                          loaded_weight: torch.Tensor):
    # Special case for loading scales off disk, which often do not
    # have a shape (such as in the case of AutoFP8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)
    if self.quant_method.__class__.__name__ in [
            'Fp8LinearMethod', 'Fp8MoEMethod'
    ] and torch.finfo(loaded_weight.dtype).bits > 8:
        if self.scale_n > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
            loaded_weight = loaded_weight.expand(
                self.scale_n, loaded_weight.shape[1],
                loaded_weight.shape[2]).permute(1, 0, 2).reshape(
                    [-1, loaded_weight.shape[-1]])
    param.load_column_parallel_weight(loaded_weight=loaded_weight)


class MergedColumnParallelLinear(ColumnParallelLinear):

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):

        if self.quant_method.__class__.__name__ in [
                'Fp8LinearMethod', 'Fp8MoEMethod'
        ] and torch.finfo(loaded_weight.dtype).bits > 8:
            if self.scale_n > 1 and len(loaded_weight.shape) == 2:
                loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
                # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
                loaded_weight = loaded_weight.expand(
                    self.scale_n, loaded_weight.shape[1],
                    loaded_weight.shape[2]).permute(1, 0, 2).reshape(
                        [-1, loaded_weight.shape[-1]])

        if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
            if self.quant_method.scale_k > 1 and len(
                    loaded_weight.shape) == 2 and loaded_weight.dtype in [
                        torch.float16, torch.bfloat16, torch.float32
                    ]:
                loaded_weight = loaded_weight.unsqueeze(1)  # [k, 1, n]
                # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]
                loaded_weight = loaded_weight.expand(
                    loaded_weight.shape[0], self.quant_method.scale_k,
                    loaded_weight.shape[2]).reshape(
                        [-1, loaded_weight.shape[2]])

        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0] // self.scale_n, weight_block_size[1]
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

def RowParallelLinear__init__(
    self,
    input_size: int,
    output_size: int,
    bias: bool = True,
    input_is_parallel: bool = True,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = True,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
):
    # Divide the weight matrix along the first dimension.
    self.tp_rank = (get_tensor_model_parallel_rank()
                    if not disable_tp else 0)
    self.tp_size = (get_tensor_model_parallel_world_size()
                    if not disable_tp else 1)
    self.input_size_per_partition = divide(input_size, self.tp_size)
    self.output_size_per_partition = output_size
    self.output_partition_sizes = [output_size]
    super(RowParallelLinear, self).__init__(input_size,
                                            output_size,
                                            skip_bias_add,
                                            params_dtype,
                                            quant_config,
                                            prefix,
                                            return_bias=return_bias,
                                            disable_tp=disable_tp)

    self.input_is_parallel = input_is_parallel
    self.reduce_results = reduce_results

    assert self.quant_method is not None

    # quant_block_k (128) is divided by scale_k; e.g. scale_k = 2 means the
    # effective quant_block_k is 64.
    self.scale_k = 1
    self.scale_n = 1
    self.scale_n_slice = 1

    if quant_config is not None and hasattr(
            quant_config, "weight_block_size"
    ) and quant_config.weight_block_size is not None:
        gcd_value = quant_config.weight_block_size[1]
        if self.input_size_per_partition % quant_config.weight_block_size[1]:
            gcd_value = math.gcd(
                self.input_size_per_partition %
                quant_config.weight_block_size[1],
                quant_config.weight_block_size[1])
        self.scale_k = (self.scale_k *
                        quant_config.weight_block_size[1] // gcd_value)
        if output_size % quant_config.weight_block_size[0]:
            gcd_value = math.gcd(
                output_size % quant_config.weight_block_size[0],
                quant_config.weight_block_size[0])
        self.scale_n = (self.scale_n *
                        quant_config.weight_block_size[0] // gcd_value)
        self.scale_n_slice = output_size // gcd_value
        # Example: N = 576, block = 128. Expanding the scales along n needs
        # two pieces of information: how many copies to make (scale_n) and
        # how many entries remain valid after slicing (scale_n_slice).
        # scale = [s0,s1,s2,s3,s4], copied scale_n = 2 times:
        # scale = [s0,s0,s1,s1,s2,s2,s3,s3,s4,s4], sliced to
        # scale_n_slice = 9 entries => [s0,s0,s1,s1,s2,s2,s3,s3,s4]

    if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
        gcd_value = quant_config.group_size
        if self.input_size_per_partition % quant_config.group_size:
            gcd_value = math.gcd(
                self.input_size_per_partition % quant_config.group_size,
                quant_config.group_size)
        self.quant_method.scale_k = (self.quant_method.scale_k *
                                     quant_config.group_size // gcd_value)

    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size_per_partition,
        output_partition_sizes=[self.output_size],
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        scale_k=self.scale_k,
        scale_n=self.scale_n,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if not reduce_results and (bias and not skip_bias_add):
        raise ValueError("When not reduce the results, adding bias to the "
                         "results can lead to incorrect results")

    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)

def RowParallelLinear_weight_loader_v2_vacc(self, param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
if len(loaded_weight.shape) == 0:
|
||||
assert loaded_weight.numel() == 1
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
|
||||
if self.scale_k > 1 and len(loaded_weight.shape) == 2:
|
||||
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
|
||||
loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,2,0).reshape([loaded_weight.shape[1], -1])
|
||||
#[1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
|
||||
|
||||
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
|
||||
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
|
||||
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
|
||||
#[1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n,k]
|
||||
|
||||
elif self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
|
||||
# broadcast scale TODO: broadcast zero
|
||||
if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
|
||||
loaded_weight = loaded_weight.unsqueeze(1) #[k,1,n]
|
||||
loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
|
||||
#[k,1,n] -> [k,scale_k,n]] -> [k*scale_k, n]
|
||||
|
||||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
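

def _scale_k_broadcast_sketch():
    # Hedged toy sketch (not called anywhere) of the k-direction broadcast
    # above, with assumed shapes: each block scale is repeated scale_k times
    # along k.
    w = torch.tensor([[1., 2.], [3., 4.]])  # [n=2, k=2] block scales
    scale_k = 3
    out = w.unsqueeze(0).expand(scale_k, 2, 2).permute(1, 2, 0).reshape(2, -1)
    # out == [[1., 1., 1., 2., 2., 2.], [3., 3., 3., 4., 4., 4.]]
    return out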


class UnquantizedLinearMethod():
    """Linear method without quantization."""

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if bias is not None:
            from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
            return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)

        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
        parallel_embedding_output = None
        if memory_recycler is not None:
            if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == x.size(0):
                parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER
        return torch.mm(x.view(-1, x.shape[-1]), layer.weight.transpose(1, 0), out=parallel_embedding_output).view(*(x.shape[:-1]), layer.weight.shape[0])
81
vllm_vacc/vllm/model_executor/layers/logits_processor.py
Normal file
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that computes logits from hidden_states."""
from typing import Optional

import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.platforms import current_platform


def DumpLogits(logits):
    import os
    import time
    VLLM_VACC_DUMP_LOGITS = os.getenv('VLLM_VACC_DUMP_LOGITS')
    if VLLM_VACC_DUMP_LOGITS:
        logit_arr = logits.cpu().to(torch.float32).numpy()
        timestamp = time.time()
        if not os.path.exists(VLLM_VACC_DUMP_LOGITS):
            os.makedirs(VLLM_VACC_DUMP_LOGITS)
        logit_path = os.path.join(VLLM_VACC_DUMP_LOGITS, f'logit_{timestamp}.bin')
        summary_path = os.path.join(VLLM_VACC_DUMP_LOGITS, 'summary.txt')
        with open(summary_path, 'a') as f:
            f.write(f'{logit_path}\n')
        logit_arr.tofile(logit_path)


class LogitsProcessor(nn.Module):

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:

        from vllm.distributed.parallel_state import get_tp_group
        total_size_all_gather_input = hidden_states.size(0) * lm_head.weight.shape[0] * hidden_states.element_size() * get_tp_group().world_size
        # An earlier heuristic fused matmul and all_gather when the hidden size
        # was <= 7168 and divisible by 32, per the DSP developer:
        # if hidden_states.size(1) <= 7168 and hidden_states.size(1) % 32 == 0:
        # Now fuse matmul and all_gather only if the total size of the
        # all_gather inputs does not exceed 4 MB.
        if total_size_all_gather_input <= 4194304:
            try:
                from torch_vacc.vacc.custom_ops import fused_matmul_allgather

                logits = fused_matmul_allgather(hidden_states, lm_head.weight.T,
                                                get_tp_group().world_size,
                                                get_tp_group().rank_in_group,
                                                get_tp_group().group_id,
                                                get_tp_group().rank_device_infos)
                # Make sure the PP last stage is handled first.
                if (hasattr(current_platform, 'supports_v1') and current_platform.supports_v1(current_platform)) or get_tp_group().rank_in_group == 0:
                    seq, hidden_dims = hidden_states.shape
                    logits = logits.movedim(0, 1)
                    logits = logits.reshape(seq, -1)
                    logits = logits[..., :self.org_vocab_size]
                    if get_tp_group().rank_in_group == 0:
                        DumpLogits(logits)
                else:
                    logits = None
                return logits
            except Exception as e:
                print("Fused Matmul with AllGather failed; falling back to the unfused path.", e)
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head,
                                            hidden_states,
                                            bias=embedding_bias)

        # Gather logits for TP.
        logits = self._gather_logits(logits)

        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[..., :self.org_vocab_size]
        return logits
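
    # Worked example of the 4 MB gate above (hypothetical shapes): 8 tokens,
    # a per-rank vocab shard of 16384, fp16 activations (2 bytes), TP world
    # size 4 gives 8 * 16384 * 2 * 4 = 1,048,576 bytes <= 4,194,304, so the
    # fused matmul+all_gather path is taken.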
51
vllm_vacc/vllm/model_executor/layers/pooler.py
Normal file
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union

import torch

from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolerActivation, get_pooling_params


class ClassifierPooler(Pooler):
    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        if pooled_data.dtype != self.head_dtype:
            pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
            # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled_data -= self.logit_bias

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores


class PoolerNormalize(PoolerActivation):

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return torch.vacc.l2_norm(pooled_data, epsilon=1e-12)
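        # A PyTorch-native sketch equivalent to the vacc op above, assuming
        # torch.vacc.l2_norm performs L2 normalization along the last dim:
        # torch.nn.functional.normalize(pooled_data, p=2, dim=-1, eps=1e-12)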
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Literal, Type, get_args

QuantizationMethods = Literal[
    # "aqlm",
    "awq",
    "deepspeedfp",
    "tpu_int8",
    "fp8",
    "ptpc_fp8",
    "fbgemm_fp8",
    "modelopt",
    "modelopt_fp4",
    "bitblas",
    "gguf",
    "gptq_marlin_24",
    "gptq_marlin",
    "gptq_bitblas",
    "awq_marlin",
    "gptq",
    "compressed-tensors",
    "bitsandbytes",
    "hqq",
    "experts_int8",
    "ipex",
    "quark",
    "moe_wna16",
    "torchao",
    "auto-round",
    "rtn",
    "inc",
    "mxfp4",
    "petit_nvfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
615
vllm_vacc/vllm/model_executor/layers/quantization/fp8.py
Normal file
@@ -0,0 +1,615 @@

import functools
import importlib.util
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_block_linear, check_aiter_fp8_linear_support,
    create_fp8_input_scale, create_fp8_scale_parameter,
    create_fp8_weight_parameter, expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
    validate_fp8_block_shape)
# from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
#     all_close_1d, apply_fp8_linear, convert_to_channelwise,
#     cutlass_block_fp8_supported, cutlass_fp8_supported,
#     normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
#     requantize_with_max_scale)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp, all_close_1d, convert_to_channelwise,
    cutlass_block_fp8_supported, cutlass_fp8_supported,
    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, is_layer_skipped)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.utils import has_deep_gemm

logger = init_logger(__name__)

# has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def Fp8LinearMethod__init(self, quant_config: Fp8Config):
    self.quant_config = quant_config
    self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
    self.out_dtype = torch.get_default_dtype()

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization.
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm.
    if current_platform.is_rocm():
        self.use_marlin = False

    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant = self.quant_config.weight_block_size is not None
    self.act_q_static = self.quant_config.activation_scheme == "static"
    # Use per-token quantization for better perf if dynamic and cutlass.
    if not self.act_q_static and cutlass_fp8_supported():
        self.act_q_group_shape = GroupShape.PER_TOKEN
    else:
        self.act_q_group_shape = GroupShape.PER_TENSOR

    if self.block_quant:
        self.block_size = self.quant_config.weight_block_size
        # Marlin doesn't support block-wise fp8.
        self.use_marlin = False
    self.scale_k = 1
    self.scale_n = 1
    self.scale_n_prefill = 1  # only for fp8 moe

    self.fp8_linear = Fp8LinearOp(
        act_quant_static=self.act_q_static,
        act_quant_group_shape=self.act_q_group_shape)


class Fp8LinearMethod(LinearMethodBase):
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        if self.block_quant:
            scale_n = extra_weight_attrs.get("scale_n")
            scale_k = extra_weight_attrs.get("scale_k")
            if scale_n is not None:
                self.scale_n = scale_n
            if scale_k is not None:
                self.scale_k = scale_k

            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size

            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            block_n, block_k = (
                self.quant_config.weight_block_size[0] // self.scale_n,
                self.quant_config.weight_block_size[1] // self.scale_k,
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)
        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)
            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            if isinstance(layer, QKVParallelLinear):
                # NOTE: for QKVParallelLinear the number of weight_scale rows
                # must be divisible by 8 for the DSPs. Find the smallest
                # power-of-two repeat factor that achieves this (the original
                # loop multiplied the running shape by the repeat factor and
                # could terminate with a repeat that does not make the row
                # count divisible, e.g. 3 rows -> repeat 4 -> 12 rows).
                repeat = 1
                while (weight_scale_inv.shape[0] * repeat) % 8 != 0:
                    repeat *= 2
                weight_scale_inv = torch.repeat_interleave(weight_scale_inv, repeats=repeat, dim=0)

            # weight = self._maybe_pad_weight(weight)
            # if self.block_quant:
            #     maybe_post_process_fp8_weight_block(
            #         layer, self.cutlass_block_fp8_supported)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)
            return

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        # Note: lazy import to avoid triton import error.
        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
            apply_w8a8_block_fp8_linear)
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            return apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=[layer.weight.shape[0] // layer.weight_scale_inv.shape[0], layer.weight.shape[1] // layer.weight_scale_inv.shape[1]],
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
        # return apply_fp8_linear(
        #     input=x,
        #     weight=layer.weight,
        #     weight_scale=layer.weight_scale,
        #     input_scale=layer.input_scale,
        #     bias=bias,
        #     cutlass_fp8_supported=self.cutlass_fp8_supported,
        #     # Default to using per_token quantization if cutlass is supported
        #     use_per_token_if_dynamic=self.cutlass_fp8_supported)


def Fp8MoEMethod_init_(self, quant_config: Fp8Config, layer: torch.nn.Module):
    self.layer = layer
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
    self.quant_config = quant_config
    self.block_quant = self.quant_config.weight_block_size is not None
    self.flashinfer_moe_backend = None

    self.scale_k = 1
    self.scale_n = 1
    self.scale_n_prefill = 1
    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization.
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm and vacc (note: is_vacc must be *called*; the
    # bare method reference was always truthy).
    if current_platform.is_rocm() or current_platform.is_vacc():
        self.use_marlin = False

    # Check for DeepGemm support.
    self.allow_deep_gemm = False
    if envs.VLLM_USE_DEEP_GEMM:
        if not has_deep_gemm():
            logger.warning_once("Failed to import DeepGemm kernels.")
        elif not self.block_quant:
            logger.warning_once("Model is not block quantized. Not using "
                                "DeepGemm kernels")
        elif (current_platform.is_cuda()
              and current_platform.has_device_capability(90)):
            logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
            self.allow_deep_gemm = True
        else:
            logger.warning_once(
                "DeepGemm not supported on the current platform.")

    # Check for CutlassBlockScaledGroupedGemm support.
    self.allow_cutlass_block_scaled_grouped_gemm = False
    if not self.block_quant:
        logger.warning_once("Model is not block quantized. Not using "
                            "CutlassBlockScaledGroupedGemm kernels")
    elif (current_platform.is_cuda()
          and current_platform.has_device_capability(100)):
        logger.info_once(
            "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
        )
        self.allow_cutlass_block_scaled_grouped_gemm = True
    else:
        logger.warning_once(
            "CutlassBlockScaledGroupedGemm not supported on the current "
            "platform.")

    self.topk_indices_dtype = None
    self.fused_experts = functools.partial(  # type: ignore
        fused_experts,
        use_fp8_w8a8=True,
        block_shape=self.quant_config.weight_block_size,
        allow_deep_gemm=self.allow_deep_gemm,
        allow_cutlass_block_scaled_grouped_gemm=(
            self.allow_cutlass_block_scaled_grouped_gemm))

class Fp8MoEMethod(FusedMoEMethodBase):
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            scale_n = extra_weight_attrs.get("scale_n")
            scale_n_prefill = extra_weight_attrs.get("scale_n_prefill")
            scale_k = extra_weight_attrs.get("scale_k")
            if scale_n is not None:
                self.scale_n = scale_n
            if scale_k is not None:
                self.scale_k = scale_k
            if scale_n_prefill is not None:
                self.scale_n_prefill = scale_n_prefill

            if self.quant_config is not None and self.quant_config.weight_block_size is not None:
                self.gcd_value = self.quant_config.weight_block_size[0]

                output_size_no_merge = intermediate_size_per_partition
                # assert isinstance(output_size_no_merge, int), f"merged output size should divide to an int, value is: {output_size_no_merge}"

                if output_size_no_merge % self.quant_config.weight_block_size[0]:
                    import math
                    gcd_value = math.gcd(output_size_no_merge % self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[0])
                    self.scale_n = self.scale_n * self.quant_config.weight_block_size[0] // gcd_value
                    self.scale_n_prefill = self.scale_n_prefill * self.quant_config.weight_block_size[0] // gcd_value
                if hidden_size % self.quant_config.weight_block_size[1]:
                    import math
                    gcd_value = math.gcd(hidden_size % self.quant_config.weight_block_size[1], self.quant_config.weight_block_size[1])
                    self.scale_k = self.scale_k * self.quant_config.weight_block_size[1] // gcd_value
                # self.scale_k = self.scale_n

                # Split the n dimension across cores by block_size:
                # output_size_no_merge = 384
                #   block_size = 128: 384 = 3x128, can only split as 3 cores x 128
                #   block_size = 16:  384 = 24x16 = 8 cores x (3x16), uses all 8 cores
                # output_size_no_merge = 512
                #   block_size = 128: 512 = 4x128, can only split as 4 cores x 128
                #   block_size = 64:  512 = 8x64, uses all 8 cores
                # output_size_no_merge = 768
                #   block_size = 128: 768 = 6x128, can only split as 6 cores x 128
                #   block_size = 32:  768 = 8x(3x32), uses all 8 cores

                core_num = 8
                min_block_size = 4
                block_size_tmp = self.quant_config.weight_block_size[0] // self.scale_n
                if output_size_no_merge > block_size_tmp and \
                        output_size_no_merge % block_size_tmp == 0 and \
                        output_size_no_merge // block_size_tmp < core_num and \
                        output_size_no_merge % core_num == 0:
                    core_num_old = output_size_no_merge // block_size_tmp
                    import math
                    gcd_value = math.gcd(core_num, core_num_old)
                    new_scale = core_num // gcd_value
                    if block_size_tmp // new_scale >= min_block_size:
                        self.scale_n = new_scale * self.scale_n
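                # Worked example (values assumed): output_size_no_merge = 384,
                # block = 128, scale_n = 1 -> block_size_tmp = 128,
                # core_num_old = 3, gcd(8, 3) = 1, new_scale = 8, and
                # 128 // 8 = 16 >= min_block_size, so scale_n becomes 8
                # (the effective block is 16, the 8-core split shown above).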

            tp_size = get_tensor_model_parallel_world_size()
            if self.scale_n != self.scale_n_prefill:
                block_n_prefill = self.quant_config.weight_block_size[0] // self.scale_n_prefill

            block_n, block_k = (
                self.quant_config.weight_block_size[0] // self.scale_n,
                self.quant_config.weight_block_size[1] // self.scale_k,
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1
                    and hidden_size % block_k != 0):
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{hidden_size} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) //
                         block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_k - 1) // block_k,
                    (intermediate_size_per_partition + block_n - 1) // block_n,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )

            if self.scale_n != self.scale_n_prefill:
                w13_weight_scale_prefill = torch.nn.Parameter(
                    torch.ones(
                        num_experts,
                        2 * ((intermediate_size_per_partition + block_n_prefill - 1) //
                             block_n_prefill),
                        (hidden_size + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    requires_grad=False,
                )
                w2_weight_scale_prefill = torch.nn.Parameter(
                    torch.ones(
                        num_experts,
                        (hidden_size + block_k - 1) // block_k,
                        (intermediate_size_per_partition + block_n_prefill - 1) // block_n_prefill,
                        dtype=torch.float32,
                    ),
                    requires_grad=False,
                )
                layer.register_parameter("w13_weight_scale_inv_prefill", w13_weight_scale_prefill)
                layer.register_parameter("w2_weight_scale_inv_prefill", w2_weight_scale_prefill)

            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.
             value} if self.block_quant else
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
            if self.scale_n != self.scale_n_prefill:
                set_weight_attrs(w13_weight_scale_prefill, extra_weight_attrs)
                set_weight_attrs(w2_weight_scale_prefill, extra_weight_attrs)
        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None


def moe_fp8_apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe import FusedMoE
    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )

    try:
        from torch_vacc.vacc.custom_ops import fused_experts
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
        experts_output = None
        if memory_recycler is not None:
            # remove MOE_EXPERT_OUT_BUFFER
            # experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
            experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_fp8_w8a8=True,
            w13_scale=(layer.w13_weight_scale_inv
                       if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
            decode_with_batch=layer.is_decode and x.shape[0] > 1,
            output_opt=experts_output
        )
    except Exception as e:
        print(f"vacc fused_experts failed; falling back to unfused ops: {e}")
        from torch_vacc.vacc.custom_ops_cpu import fused_experts
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_fp8_w8a8=True,
            w13_scale=(layer.w13_weight_scale_inv
                       if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )
282
vllm_vacc/vllm/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,282 @@

# SPDX-License-Identifier: Apache-2.0

import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
import numpy as np

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           RowvLLMParameter)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig as GPTQConfigOrig
from vllm.model_executor.layers.quantization.gptq import ExllamaState
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT
import math


def GPTQLinearMethod__init(self, quant_config: GPTQConfigOrig):
    self.quant_config = quant_config
    self.scale_k = 1
    self.split_num = 4


def int32_to_int4(s0, axis=-2):
    # Flatten to shape [1, n] first; each int32 unpacks into 8 int4 values
    # (held in 8 separate int32 tensors), giving [8, n].

    # x32 (int32) => 32 bits => 8 x 4-bit values x4[8]

    # x32 bits 31-28 => x4[7]
    # x32 bits 27-24 => x4[6]
    # ...
    # x32 bits 3-0   => x4[0]

    # x32[index=0] => x4[7,6,5,4,3,2,1,0]

    # Mapping a 4-bit nibble to its real value:
    # not two's complement; subtract 8 instead.

    # 1111 => 15 => 7    (15 - 8 = 7)

    # 0110 => 6 => -2    (6 - 8 = -2)

    # 0x 6A CB 37 2B (laid out in memory as 2B 37 CB 6A) => B273BCA6
    # => (-8) => int4: 3, -6, -1, -5, 3, 4, 2, -2

    # The actual memory layout is little-endian:
    # int32: 2B 37 CB 6A => 2,11,3,7,12,11,6,10 => (-8) => -6,3, -5,-1, 4,3, -2,2
    # => swapping the two nibbles within each byte gives 3, -6, -1, -5, 3, 4, 2, -2
    # int4: 3, -6, -1, -5, 3, 4, 2, -2

    s = s0.view(torch.uint32)
    parts = []
    for i in range(8):
        x = 15 << (i * 4)
        # s2 = torch.bitwise_and(x, s)
        s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
        s3 = s2 / (2 ** (i * 4))
        s4 = s3.to(torch.int32)
        # Two's complement gives the wrong result:
        # s4[s4 > 7] = s4[s4 > 7] - 16
        # Subtracting 8 directly is correct; range: -8..7
        s4 = s4 - 8
        parts.append(s4.reshape(1, *s4.shape))
    parts = torch.concatenate(parts, 0)
    if axis == -2 or axis == 0:
        # 8,K//8,N => K//8,8,N => K,N
        parts = parts.transpose(-2, 0).reshape(-1, parts.shape[-1]).contiguous()
    else:
        # 8,N,K//8 => N,K//8,8 => N,K
        parts = parts.permute(1, 2, 0).reshape(parts.shape[-2], -1).contiguous()
    return parts
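

def _int32_to_int4_nibble_check():
    # Hedged self-check (illustrative only, not called anywhere) of the nibble
    # decoding documented above: 0x6ACB372B should unpack, low nibble first,
    # to 3, -6, -1, -5, 3, 4, 2, -2.
    x = 0x6ACB372B
    return [((x >> (i * 4)) & 0xF) - 8 for i in range(8)]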


def dequant_weight(qw, scales, group_size=128):
    N = qw.shape[1]
    int4_to_int32_axis = -2
    if TRANSPOSE_GPTQ_WEIGHT:
        N = qw.shape[0]
        int4_to_int32_axis = -1
    qweight = int32_to_int4(qw, int4_to_int32_axis).to(torch.float16)  # int32 => 8 int4 => fp16

    if TRANSPOSE_GPTQ_WEIGHT:
        scales = scales.T.contiguous()
        qweight = qweight.T.contiguous()

    # Expand the scales by group_size: every group_size values share one scale.
    scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)

    dequant_weight = qweight * scales  # dequant
    return dequant_weight


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """
    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]


class GPTQLinearMethod(LinearMethodBase):
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs.get("weight_loader")
        # if input_size_per_partition % self.quant_config.group_size != 0:
        #     raise ValueError(
        #         "The input size is not aligned with the quantized "
        #         "weight shape. This can be caused by too large "
        #         "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if (output_size_per_partition % self.quant_config.pack_factor.numerator
                != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use Exllama for row parallel layer
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        g_idx = RowvLLMParameter(data=torch.tensor(
            [
                i // self.quant_config.group_size
                for i in range(input_size_per_partition)
            ],
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)
        qzeros_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if scale_and_zero_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.exllama_state = exllama_state

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # for torch.compile
        # self.quant_config.weight_bits == 4
        if TRANSPOSE_GPTQ_WEIGHT:
            layer.qzeros = Parameter(layer.qzeros.data.T.contiguous(), requires_grad=False)
            layer.qweight = Parameter(layer.qweight.data.T.contiguous(), requires_grad=False)
            layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
            layer.scales = Parameter(layer.scales.data.T.contiguous(), requires_grad=False)
        else:
            layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
            layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
            layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
            layer.scales = Parameter(layer.scales.data, requires_grad=False)

        # exllama needs to shuffle the weight after the weight is loaded;
        # here we do the shuffle on the first forward pass
        if layer.exllama_state == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
                layer.exllama_state = ExllamaState.READY
                ops.gptq_shuffle(layer.qweight, layer.g_idx,
                                 self.quant_config.weight_bits)
            else:
                layer.g_idx.data = torch.empty((0, ),
                                               dtype=torch.int,
                                               device=layer.g_idx.device)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        out_shape = x.shape[:-1] + (layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], )  # M,N
        reshaped_x = x.reshape(-1, x.shape[-1])

        # CPU dequant fallback, kept for reference:
        # weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(), self.quant_config.group_size // self.scale_k).to(layer.qweight.device)
        # output = torch.matmul(reshaped_x, weight)
        output = torch.vacc.w4a8_block_int4_matmul(
            reshaped_x,
            layer.qweight.transpose(-1, -2),
            layer.scales.transpose(-1, -2),
            [1, self.quant_config.group_size // self.scale_k],
        )
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
372
vllm_vacc/vllm/model_executor/layers/quantization/moe_wna16.py
Normal file
@@ -0,0 +1,372 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Optional, Union

import torch

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform


# [num_experts, N, K//8], int32 ==> [num_experts, N, K], int4 ==> [num_experts, N//8, K], int32
def repack_quant_moe_weight_old(original_packed_tensor):
    num_experts = original_packed_tensor.shape[0]
    N = original_packed_tensor.shape[1]
    K = original_packed_tensor.shape[2] * 8
    if original_packed_tensor.dtype != torch.int32:
        raise ValueError("data type of input tensor should be int32")
    if N % 8 != 0:
        raise ValueError("N of input tensor should be divisible by 8")

    # --- 1. Unpack: expand the int32 tensor into a logical int4 tensor ---
    # Create a temporary tensor to hold all the unpacked int4 values.
    # torch.uint8 serves as the temporary storage for int4 since PyTorch has
    # no native int4 dtype.
    unpacked_int4_tensor = torch.zeros(
        (num_experts, N, K),
        dtype=torch.uint8,
        device=original_packed_tensor.device
    )
    mask = 0b1111
    for i in range(8):
        # Extract the value of the i-th int4 from each int32 word:
        # right-shifting by (i * 4) bits moves the i-th 4-bit integer into the
        # least significant position, and the bitwise AND with the mask then
        # keeps just those 4 bits.
        extracted_int4s = (original_packed_tensor >> (i * 4)) & mask
        # Place the extracted int4 values at the right positions in
        # unpacked_int4_tensor: the slice `i::8` starts at index `i` and
        # fills every 8th position.
        unpacked_int4_tensor[:, :, i::8] = extracted_int4s

    # --- 2. Repack: pack the logical int4 tensor into a new int32 tensor ---
    new_packed_tensor = torch.zeros(
        (num_experts, N // 8, K),
        dtype=torch.int32,
        device=original_packed_tensor.device
    )
    for i in range(8):
        # Take the int4 sequence to pack next from the unpacked tensor, using
        # the slice `i::8` along the N dimension.
        current_int4_segment = unpacked_int4_tensor[:, i::8, :]
        # Convert this int4 sequence to int32 (the packing target), left-shift
        # it into its position within the new int32 word, and merge it into
        # new_packed_tensor with a bitwise OR.
        new_packed_tensor |= (current_int4_segment.to(torch.int32) << (i * 4))

    return new_packed_tensor


def repack_quant_moe_weight(original_packed_tensor):
    if original_packed_tensor.dtype != torch.int32:
        raise ValueError("data type of input tensor should be int32")

    num_experts, N, K_packed = original_packed_tensor.shape
    K = K_packed * 8

    if N % 8 != 0:
        raise ValueError("N of input tensor should be divisible by 8")

    new_packed_tensor = torch.zeros((num_experts, N // 8, K),
                                    dtype=torch.int32,
                                    device=original_packed_tensor.device)
    for i in range(8):
        source_slice = original_packed_tensor[:, i::8, :]
        for j in range(8):
            unpacked_strip = (source_slice >> (j * 4)) & 0b1111
            new_packed_tensor[:, :, j::8] |= (unpacked_strip.to(torch.int32) << (i * 4))

    return new_packed_tensor
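

def _repack_equivalence_sketch():
    # Hedged sanity check (illustrative only, not wired into anything): the
    # vectorized repack above should match the reference implementation on
    # random non-negative packed int4 data.
    t = torch.randint(0, 2**31 - 1, (2, 16, 8), dtype=torch.int32)
    return torch.equal(repack_quant_moe_weight(t),
                       repack_quant_moe_weight_old(t))  # expected: True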


class MoeWNA16Method(FusedMoEMethodBase):

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        self.moe = layer
        layer.quant_config = self.quant_config
        bit8_pack_factor = self.quant_config.bit8_pack_factor
        bit32_pack_factor = 32 // self.quant_config.weight_bits
        group_size = self.quant_config.group_size
        group_size_div_factor = 1
        group_size_w13 = self.quant_config.group_size
        group_size_div_factor_w13 = 1
        group_size_w2 = self.quant_config.group_size
        group_size_div_factor_w2 = 1

        # make intermediate_size and hidden_size divisible by group_size:
        # we reduce the group size to ensure that,
        # and we repeat the loaded_weight later
        while intermediate_size_per_partition % group_size or \
                hidden_size % group_size:
            group_size = group_size // 2
            group_size_div_factor *= 2
            assert group_size >= 32
        layer.group_size = group_size
        layer.group_size_div_factor = group_size_div_factor

        while intermediate_size_per_partition % group_size_w2:
            group_size_w2 = group_size_w2 // 2
            group_size_div_factor_w2 *= 2
            assert group_size_w2 >= 32
        layer.w2_block_size = group_size_w2
        layer.group_size_div_factor_w2 = group_size_div_factor_w2

        while hidden_size % group_size_w13:
            group_size_w13 = group_size_w13 // 2
            group_size_div_factor_w13 *= 2
            assert group_size_w13 >= 32
        layer.w13_block_size = group_size_w13
        layer.group_size_div_factor_w13 = group_size_div_factor_w13

        strategy = FusedMoeWeightScaleSupported.GROUP.value
        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": False
        })

        assert 'weight_loader' in extra_weight_attrs
        weight_loader = extra_weight_attrs['weight_loader']
        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
            layer, weight_loader)
        extra_weight_attrs['weight_loader'] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        # w13_qweight = torch.nn.Parameter(torch.empty(
        #     num_experts,
        #     2 * intermediate_size_per_partition,
        #     hidden_size // bit8_pack_factor,
        #     dtype=torch.uint8),
        #     requires_grad=False)
        w13_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // bit32_pack_factor,
            dtype=torch.int32),
                                         requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        # down_proj (row parallel)
        # w2_qweight = torch.nn.Parameter(torch.empty(
        #     num_experts,
        #     hidden_size,
        #     intermediate_size_per_partition // bit8_pack_factor,
        #     dtype=torch.uint8),
        #     requires_grad=False)
        w2_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // bit32_pack_factor,
            dtype=torch.int32),
                                        requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        w13_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // group_size_w13,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // group_size_w2,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        if self.quant_config.has_zp:
            w13_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition // bit8_pack_factor,
                hidden_size // group_size,
                dtype=torch.uint8),
                                            requires_grad=False)
            layer.register_parameter("w13_qzeros", w13_qzeros)
            set_weight_attrs(w13_qzeros, extra_weight_attrs)

            w2_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                hidden_size // bit8_pack_factor,
                intermediate_size_per_partition // group_size,
                dtype=torch.uint8),
                                           requires_grad=False)
            layer.register_parameter("w2_qzeros", w2_qzeros)
            set_weight_attrs(w2_qzeros, extra_weight_attrs)

        if self.quant_config.linear_quant_method == "gptq":
            # some params are unused, but we need to init them in order to
            # load weights
            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
            if not self.quant_config.has_zp:
                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
            for key in invalid_param_keys:
                param = torch.nn.Parameter(torch.empty((0, ),
                                                       dtype=torch.int32),
                                           requires_grad=False)
                layer.register_parameter(key, param)
                set_weight_attrs(param, extra_weight_attrs)

    @staticmethod
    def get_weight_loader(layer, weight_loader):

        def convert_awq_tensor(tensor, tensor_type):
            # convert awq qweight/qzeros to a standard format (assume int4)
            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
            # qzeros: (k // group_size, n // pack_factor_bit32) ->
            #     (n // pack_factor_bit8, k // group_size)
            # pack_factor_bit32 = 32 // weight_bits
            # pack_factor_bit8 = 8 // weight_bits

            # 0. suppose origin shape (a, b), dtype int32
            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
            size0 = tensor.size(0)
            tensor = tensor.view(torch.uint8)

            # 2. unpack to uint4 (only when weight_bits == 4)
            #    shape (a, 4 * b) -> (a, 4 * b, 2)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF

            # 3. change order, see
            # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
            #    shape -> (a, 4 * b * pack_factor_bit8)
            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
            tensor = tensor.view(size0, -1)

            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
            tensor = tensor.T.contiguous()

            # 5. repack (only when weight_bits == 4)
            # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
            # qzeros shape -> (4 * b, a)

            if tensor_type == "qweight":
                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
            elif tensor_type == "qzeros":
                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
            return tensor

        def convert_gptq_int4_qzeros(tensor):
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor

        def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                    loaded_weight: torch.Tensor,
                                    weight_name: str,
                                    shard_id: str,
                                    expert_id: int,
                                    return_success: bool = False):
            if "g_idx" in weight_name:
                return False if return_success else None
            if not layer.quant_config.has_zp and "qzeros" in weight_name:
                return False if return_success else None

            device = get_tp_group().device
            tp_rank = get_tensor_model_parallel_rank()
            loaded_weight = loaded_weight.to(device)
            shard_size = layer.intermediate_size_per_partition

            # convert gptq and awq weight to a standard format
            if layer.quant_config.linear_quant_method == "awq":
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight,
                                                       "qweight")
                elif "zeros" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
                else:
                    loaded_weight = loaded_weight.T
            elif layer.quant_config.linear_quant_method == "gptq":
                assert layer.quant_config.weight_bits in [4, 8]
                if "weight" in weight_name:
                    # loaded_weight = loaded_weight.T.contiguous().view(
                    #     torch.uint8)
                    loaded_weight = loaded_weight.T.contiguous()
                elif "zeros" in weight_name:
                    # add 1 to gptq qzeros to align with awq
                    loaded_weight = loaded_weight.view(torch.uint8)
                    if layer.quant_config.weight_bits == 4:
                        loaded_weight = convert_gptq_int4_qzeros(
                            loaded_weight).T
                    else:
                        loaded_weight = loaded_weight.T + 1
                else:
                    # loaded_weight = loaded_weight.T
                    loaded_weight = loaded_weight.T.contiguous()

            # repeat the qzeros/scales to fit the new group size
            # (parenthesized: `and` binds tighter than `or`, so the original
            # unparenthesized conditions matched far more than intended)
            if layer.group_size_div_factor_w13 > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name) and \
                    shard_id in ("w1", "w3"):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor_w13, 1)
            elif layer.group_size_div_factor_w2 > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name) and \
                    shard_id == "w2":
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor_w2, 1)
            elif layer.group_size_div_factor > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1)

            if "w13_qzeros" in weight_name:
                tensor = loaded_weight.view(layer.tp_size, -1,
                                            loaded_weight.size(1))[tp_rank]
                if shard_id == "w1":
                    param.data[expert_id, :shard_size // 2] = tensor
                else:
                    param.data[expert_id, shard_size // 2:] = tensor
                return True if return_success else None
            elif "w2_qzeros" in weight_name:
                param.data[expert_id] = loaded_weight.view(
                    loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
                return True if return_success else None
            else:
                # Delegate to the original loader, passing return_success
                return weight_loader(param,
                                     loaded_weight,
                                     weight_name,
                                     shard_id,
                                     expert_id,
                                     return_success=return_success)

        return moe_wna16_weight_loader
    def process_weights_after_loading(self, layer):
        dev_w2 = layer.w2_qweight.device
        # torch.Size([128, 2048, 24]), torch.int32, strides: (49152, 24, 1)
        # ======>
        # torch.Size([128, 256, 192]), torch.int32, strides: (49152, 1, 256)
        layer.w2_qweight = torch.nn.Parameter(
            repack_quant_moe_weight(layer.w2_qweight.cpu())
            .transpose(-1, -2).contiguous().transpose(-1, -2)
            .to(device=dev_w2),
            requires_grad=False)
        # torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 3, 1)
        # ======>
        # torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 1, 2048)
        layer.w2_scales = torch.nn.Parameter(
            layer.w2_scales.transpose(-1, -2).contiguous().transpose(-1, -2),
            requires_grad=False)
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,39 @@
import torch
from typing import Optional


def _apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: list[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_block_fp8_supported: bool = True,
    use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
    assert input_scale is None
    assert len(block_size) == 2, "only 2-dim blocks are supported for now"
    # View input as a 2D matrix for the fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

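    # Shape sketch (assumption, DeepSeek-V3-style block quantization): with
    # block_size == [128, 128], an fp8 `weight` of shape (N, K) carries a
    # `weight_scale` of shape (ceil(N / 128), ceil(K / 128)), one scale per
    # 128x128 tile. The recycled MLA o_proj buffer below is only reused when
    # it already has the exact (num_tokens, N) output shape.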
    try:
        from torch_vacc.vacc.custom_ops import w8a8_block_fp8_linear
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler

        mla_oproj_output = None
        if memory_recycler is not None:
            os1, os2 = memory_recycler.MLA_OPROJ_OUT_BUFFER.shape
            if os1 == input_2d.size(0) and os2 == weight.size(0):
                mla_oproj_output = memory_recycler.MLA_OPROJ_OUT_BUFFER

        output = w8a8_block_fp8_linear(input_2d, weight, input_scale,
                                       weight_scale, block_size,
                                       output=mla_oproj_output)
    except Exception as e:
        print("vacc fused fp8 matmul failed:", e, ", falling back to unfused ops")
        from torch_vacc.vacc.custom_ops_cpu import w8a8_block_fp8_linear
        output = w8a8_block_fp8_linear(input_2d, weight, input_scale,
                                       weight_scale, block_size)

    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)
@@ -0,0 +1,9 @@
from .rotary_embedding import (
    RotaryEmbedding_init_vacc,
    RotaryEmbedding_forward_vacc,
    ScalingRotaryEmbedding_forward_vacc,
    _compute_inv_freq_vacc,
    _deepseek_compute_cos_sin_cache_vacc,
    _yarn_compute_cos_sin_cache_vacc,
    _compute_cos_sin_cache_vacc,
)
Binary file not shown.
Binary file not shown.
Binary file not shown.
101
vllm_vacc/vllm/model_executor/layers/rotary_embedding/mrope.py
Normal file
@@ -0,0 +1,101 @@
import itertools
from typing import Optional, Union

import numpy as np
import torch
from transformers import PretrainedConfig


class MRotaryEmbedding:

    @classmethod
    def _qwen3vl_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value."""

        video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
                          for _ in range(t)]

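        # M-RoPE position sketch: positions form 3 rows (t, h, w). Text
        # tokens share one index across all rows; vision tokens use their
        # grid coordinates offset past the preceding text. E.g. a 1x2x2
        # image after 3 text tokens (spatial merge already applied) yields
        #   t: 0 1 2 | 3 3 3 3
        #   h: 0 1 2 | 3 3 4 4
        #   w: 0 1 2 | 3 4 3 4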
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]
        return llm_positions, mrope_position_delta
@@ -0,0 +1,203 @@
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
# from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb
# from vllm.model_executor.layers.rotary_embedding import _yarn_find_correction_range, _yarn_linear_ramp_mask
from vllm.model_executor.layers.rotary_embedding.common import yarn_find_correction_range as _yarn_find_correction_range
from vllm.model_executor.layers.rotary_embedding.common import yarn_linear_ramp_mask as _yarn_linear_ramp_mask
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from ...ops.mrope_op import get_sin_cos_mrope


def RotaryEmbedding_init_vacc(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
) -> None:
    super(CustomOp, self).__init__()
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype

    # cache = self._compute_cos_sin_cache()
    cos, sin = self._compute_cos_sin_cache()
    cos = cos.to(dtype)
    sin = sin.to(dtype)

    self.register_buffer("cos_cache", cos, persistent=False)
    self.register_buffer("sin_cache", sin, persistent=False)

    # cache = cache.to(dtype)
    # self.cos_sin_cache: torch.Tensor
    # self.register_buffer("cos_sin_cache", cache, persistent=False)

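# RoPE application sketch (assumed semantics of the fused
# torch.vacc.RotaryPosEmbedding call used below): in "neox" mode each pair
# (x1, x2), split at rotary_dim // 2, is rotated as
#   out1 = x1 * cos - x2 * sin
#   out2 = x2 * cos + x1 * sin
# while "gptj" mode applies the same rotation to interleaved pairs along the
# last dimension instead of split halves.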
def RotaryEmbedding_forward_vacc(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """A PyTorch-vacc implementation of forward()."""
    if offsets is not None:
        positions = positions + offsets
    num_tokens = positions.numel()

    if isinstance(self, MRotaryEmbedding):
        # get mrope sin/cos; positions has 3 rows (t, h, w) here, so the real
        # token count is a third of positions.numel()
        cos, sin = get_sin_cos_mrope(self, positions)
        num_tokens = num_tokens // 3
    else:
        positions = positions.flatten()
        cos = self.cos_cache[positions]
        sin = self.sin_cache[positions]

    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    query_pass = query[..., self.rotary_dim:]
    key_shape = key.shape
    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., :self.rotary_dim]
    key_pass = key[..., self.rotary_dim:]
    mode = "neox" if self.is_neox_style else "gptj"
    query_rot, key_rot = torch.vacc.RotaryPosEmbedding(query_rot, key_rot,
                                                       cos, sin, 0, mode)
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
    return query, key

def ScalingRotaryEmbedding_forward_vacc(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward()."""
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]

    if offsets is not None:
        positions = positions + offsets
    positions = positions.flatten()

    cos = self.cos_cache[positions]
    sin = self.sin_cache[positions]

    # TODO: to be removed (requires odsp support); the neox/gptj cos/sin
    # repeat and rotate logic is handled inside the fused op.
    mode = "neox" if self.is_neox_style else "gptj"
    query_rot, key_rot = torch.vacc.RotaryPosEmbedding(query_rot, key_rot,
                                                       cos, sin, 0, mode)

    if self.rotary_dim < self.head_size:
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
    else:
        query = query_rot
        key = key_rot
    return query, key

def _compute_inv_freq_vacc(self, scaling_factor: float) -> torch.Tensor:
    pos_freqs = self.base**(
        torch.arange(0, self.rotary_dim, 2, dtype=torch.float,
                     device=current_platform.device_type) / self.rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                            self.rotary_dim, self.base,
                                            self.max_position_embeddings)
    # Get n-d rotational scaling corrected for extrapolation
    inv_freq_mask = (1 - _yarn_linear_ramp_mask(
        low, high, self.rotary_dim // 2,
        dtype=torch.float)) * self.extrapolation_factor
    inv_freq = inv_freq_interpolation * (
        1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
    return inv_freq

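# YaRN blend, as implemented above: _yarn_linear_ramp_mask ramps 0 -> 1 over
# [low, high), so inv_freq_mask is about extrapolation_factor below `low`
# (the unscaled 1/pos_freqs schedule) and about 0 above `high` (the
# interpolated 1/(scaling_factor * pos_freqs) schedule), mixing the two
# linearly in between.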
def _deepseek_compute_cos_sin_cache_vacc(
        self) -> Tuple[torch.Tensor, torch.Tensor]:
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                     device=current_platform.device_type,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    # NOTE: odsp-friendly layout.
    # Separate cos/sin caches guarantee that cos/sin always have a
    # contiguous layout along dim[-1].
    cos = freqs.cos() * self.mscale
    sin = freqs.sin() * self.mscale
    return cos, sin


def _yarn_compute_cos_sin_cache_vacc(
        self) -> Tuple[torch.Tensor, torch.Tensor]:
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                     device=current_platform.device_type,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos() * self.mscale
    sin = freqs.sin() * self.mscale
    return cos, sin


def _compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
    inv_freq = self._compute_inv_freq(self.base)
    t = torch.arange(self.max_position_embeddings,
                     device=current_platform.device_type,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    # Separate cos/sin caches keep dim[-1] contiguous (odsp-friendly).
    cos = freqs.cos()
    sin = freqs.sin()
    return cos, sin


# import vllm.model_executor.layers.rotary_embedding as rotary_embedding
# rotary_embedding.RotaryEmbedding.forward_vacc = RotaryEmbedding_forward_vacc
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_inv_freq = _compute_inv_freq_vacc
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_cos_sin_cache = _compute_cos_sin_cache_vacc

542
vllm_vacc/vllm/model_executor/layers/sampler.py
Normal file
@@ -0,0 +1,542 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union

import msgspec
import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, Logprob,
                           PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.model_executor.layers.sampler import (
    SamplerOutput,
    _apply_min_tokens_penalty,
    _apply_top_k_top_p,
    _apply_min_p,
    _sample,
    SampleResultArgsType,
    get_logprobs,
    _build_sampler_output,
    SampleReturnType,
    SampleResultsDictType,
    SampleMetadataType,
    MultinomialSamplesType,
    _modify_greedy_probs_inplace,
    _top_k_top_p_multinomial_with_flashinfer,
    _multinomial,
    get_pythonized_sample_results,
)
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER as use_ds3_sampler
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER_OP as use_ds3_sampler_op

if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None

class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For each sequence group, we generate a list of SequenceOutput objects,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # On-device tensor containing the sampled token embeddings (embeddings
    # corresponding to the sampled token ids). Used when prompt embeddings are
    # specified in lieu of prompt token ids or text.
    sampled_token_embeds: Optional[torch.Tensor] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model
    # forward, block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
        return iter(self.outputs)

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None
                                  else self.sampled_token_ids.shape)
        return (f"SamplerOutput(outputs={self.outputs}, "
                f"sampled_token_probs={sampled_token_probs_repr}, "
                f"sampled_token_ids={sampled_token_ids_repr})")


def Sampler_forward(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """
    Single-step scheduling:
    * Perform GPU-side sampling computation & compute
      GPU-side logprobs tensor
    * Pythonize sampling result & logprobs tensor

    Multi-step scheduling:
    * Perform GPU-side sampling computation & compute
      GPU-side logprobs tensor
    * Defer Pythonization of sampling result & logprobs
      tensor
    * Encapsulate arguments required for deferred Pythonization
      in the :class:`SamplerOutput` structure

    Args:
        logits: (num_tokens, vocab_size).
        sampling_metadata: Metadata for sampling.
    """

    assert logits is not None
    # Prepare sampling tensors with pinned memory to avoid blocking.
    if not sampling_metadata.reuse_sampling_tensors:
        self._init_sampling_tensors(logits, sampling_metadata)
    elif self._do_penalties:
        # In this case, the sampling tensors logic depends on
        # "output_tokens" of a sequence. As a result, we cannot
        # reuse sampling tensors, since "output_tokens" changes
        # between decode runs.
        self._init_sampling_tensors(logits, sampling_metadata)

    assert self._sampling_tensors is not None
    sampling_tensors = self._sampling_tensors
    do_penalties = self._do_penalties
    do_top_p_top_k = self._do_top_p_top_k
    do_min_p = self._do_min_p

    is_greedy = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.GREEDY]) == logits.shape[0])
    is_random = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.RANDOM]) == logits.shape[0])
    is_random_seed = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.RANDOM_SEED]) == logits.shape[0])

    # NOTE: only the first sequence group is consulted here, which assumes a
    # batch with homogeneous sampling params.
    max_n_in_batch = sampling_metadata.seq_groups[0].sampling_params.n
    generator = sampling_metadata.seq_groups[0].generator
    min_tokens = sampling_metadata.seq_groups[0].sampling_params.min_tokens
    if use_ds3_sampler and (is_greedy or
                            ((is_random or is_random_seed)
                             and not do_penalties
                             and flashinfer_top_k_top_p_sampling is None
                             and min_tokens <= 0
                             and not do_min_p
                             and max_n_in_batch == 1
                             # and not self._should_modify_greedy_probs_inplace
                             # and not self.include_gpu_probs_tensor
                             )):
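        # Fast path: the fused torch.vacc.ds3_sampler covers pure-greedy
        # batches, and pure-random(-seeded) batches only when there are no
        # penalties, no min_p, no min_tokens, no flashinfer sampler, and
        # n == 1; anything else falls through to the generic path below.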
        sampling_type = SamplingType.GREEDY
        sample_metadata: SampleMetadataType = {}
        multinomial_samples: MultinomialSamplesType = {}
        greedy_samples: Optional[torch.Tensor] = None
        multinomial_out: Optional[torch.Tensor] = None
        vacc_device = logits.device
        # Create output tensor for sampled token ids.
        if self.include_gpu_probs_tensor:
            sampled_token_ids_tensor = torch.full((logits.shape[0], 1),
                                                  VLLM_INVALID_TOKEN_ID,
                                                  dtype=torch.long,
                                                  device=vacc_device)
            probs_out = torch.empty_like(logits)
            logprobs_out = torch.empty_like(logits)
        else:
            probs_out = None
            logprobs_out = None
            sampled_token_ids_tensor = None
        if is_greedy:
            greedy_samples, _ = torch.vacc.ds3_sampler(
                logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                sampling_tensors.temperatures, 0)
            sampling_type = SamplingType.GREEDY
            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor = greedy_samples.unsqueeze(-1).to(
                    torch.long)
            if probs_out is not None:
                probs_out = torch.softmax(logits, dim=-1)
                if self._should_modify_greedy_probs_inplace:
                    sample_indices = sampling_metadata.\
                        categorized_sample_indices[SamplingType.GREEDY].long()
                    probs_out[sample_indices, :] = 0
                    probs_out[sample_indices, greedy_samples] = 1.0
        elif is_random and do_top_p_top_k:
            if use_ds3_sampler_op:
                logits = logits.to(torch.float)
                multinomial_out, probs_out = torch.vacc.ds3_sampler(
                    logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                    sampling_tensors.temperatures, 2)
                multinomial_out = multinomial_out.view(-1, max_n_in_batch)
            else:
                logits = logits.to(torch.float)
                logits.div_(sampling_tensors.temperatures.to(
                    logits.device).to(logits.dtype).unsqueeze(dim=1))
                logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps,
                                              sampling_tensors.top_ks)
                probs = torch.softmax(logits, dim=-1, dtype=torch.float)
                probs_out = probs
                # multinomial_out = torch.multinomial(probs, 1)
                q = torch.empty_like(probs)
                q.exponential_()
                multinomial_out = probs.div_(q).argmax(dim=1).view(
                    -1, max_n_in_batch)
            sampling_type = SamplingType.RANDOM
        elif is_random_seed and generator is not None and do_top_p_top_k:
            if use_ds3_sampler_op:
                logits = logits.to(torch.float)
                multinomial_out, probs_out = torch.vacc.ds3_sampler(
                    logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                    sampling_tensors.temperatures, 1, generator)
                multinomial_out = multinomial_out.view(-1, max_n_in_batch)
            else:
                logits = logits.to(torch.float)
                logits.div_(sampling_tensors.temperatures.to(
                    logits.device).to(logits.dtype).unsqueeze(dim=1))
                logits = torch.vacc.topk_topp(
                    logits, sampling_tensors.top_ps,
                    sampling_tensors.top_ks).to(torch.float)
                probs = torch.softmax(logits, dim=-1, dtype=torch.float)
                probs_out = probs
                # multinomial_out = torch.multinomial(probs, 1)
                q = torch.empty_like(probs)
                q.exponential_(generator=generator)
                multinomial_out = probs.div_(q).argmax(dim=1).view(
                    -1, max_n_in_batch)
            sampling_type = SamplingType.RANDOM_SEED
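        # The q.exponential_() / div_ / argmax sequence above is the
        # exponential-race form of the Gumbel-max trick: with E_i ~ Exp(1),
        # argmax(p_i / E_i) selects index i with probability p_i, replacing
        # torch.multinomial with cheaper pointwise ops plus an argmax.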
        multinomial_samples[sampling_type] = multinomial_out

        if sampled_token_ids_tensor is not None:
            if sampling_type != SamplingType.GREEDY:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor = multinomial_samples[
                    sampling_type].to(torch.long)

        categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
            t: []
            for t in SamplingType
        }
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            sampling_params = seq_group.sampling_params
            sampling_type = sampling_params.sampling_type
            categorized_seq_group_ids[sampling_type].append(i)

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        sample_results_dict: SampleResultsDictType = {}

        maybe_deferred_args = SampleResultArgsType(
            sampling_metadata=sampling_metadata,
            sample_metadata=sample_metadata,
            multinomial_samples=multinomial_samples,
            greedy_samples=greedy_samples,
            # beam_search_logprobs=None,
            sample_results_dict=sample_results_dict)

        if not sampling_metadata.skip_sampler_cpu_output:
            # GPU<->CPU sync happens here.
            # This also converts the sampler output to a Python object.
            # Return Pythonized sampler result & sampled token ids
            maybe_deferred_sample_results = get_pythonized_sample_results(
                maybe_deferred_args)
            maybe_sampled_tokens_tensor = sampled_token_ids_tensor
        else:
            # Defer sampler result Pythonization; return deferred
            # Pythonization args & sampled token ids
            maybe_deferred_sample_results, maybe_sampled_tokens_tensor = (
                maybe_deferred_args,
                sampled_token_ids_tensor,
            )

        if self.include_gpu_probs_tensor:
            on_device_tensors = (probs_out, logprobs_out,
                                 maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None
        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            logprobs = logits
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    logits = _apply_min_tokens_penalty(logits, sampling_metadata)

    # Apply presence and frequency penalties.
    # if do_penalties:
    #     logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
    #                              sampling_tensors.output_tokens,
    #                              sampling_tensors.presence_penalties.to(logits.device),
    #                              sampling_tensors.frequency_penalties.to(logits.device),
    #                              sampling_tensors.repetition_penalties.to(logits.device))

    # Use float32 to apply temperature scaling.
    # Use in-place division to avoid creating a new tensor.
    logits = logits.to(torch.float)
    logits.div_(sampling_tensors.temperatures.to(logits.device).to(
        logits.dtype).unsqueeze(dim=1))

    if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
        logits = _apply_top_k_top_p(logits,
                                    sampling_tensors.top_ps.to(logits.device),
                                    sampling_tensors.top_ks.to(logits.device))

    if do_min_p:
        logits = _apply_min_p(logits, sampling_tensors.min_ps)

    # We use float32 for probabilities and log probabilities.
    # Compute the probabilities.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    # Compute the log probabilities.
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    # Sample the next tokens.
    maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=self.include_gpu_probs_tensor,
        modify_greedy_probs=self._should_modify_greedy_probs_inplace,
    )

    if self.include_gpu_probs_tensor:
        # Since we will defer sampler result Pythonization,
        # preserve GPU-side tensors in support of later
        # deferred pythonization of logprobs
        assert maybe_sampled_tokens_tensor is not None
        on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
    else:
        # Since Pythonization has already happened, don't preserve
        # GPU-side tensors.
        on_device_tensors = None

    # Get the logprobs query results.
    prompt_logprobs = None
    sample_logprobs = None
    if not sampling_metadata.skip_sampler_cpu_output:
        # Pythonize logprobs now (GPU -> CPU); do not defer.
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        prompt_logprobs, sample_logprobs = get_logprobs(
            logprobs, sampling_metadata, maybe_deferred_sample_results)

    return _build_sampler_output(
        maybe_deferred_sample_results,
        sampling_metadata,
        prompt_logprobs,
        sample_logprobs,
        on_device_tensors=on_device_tensors,
        skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

def rejection_forward(
    self,
    target_with_bonus_probs: torch.Tensor,
    bonus_token_ids: torch.Tensor,
    draft_probs: torch.Tensor,
    draft_token_ids: torch.Tensor,
    seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
    if seeded_seqs is None:
        out, index = torch.vacc.rejection_sampler(target_with_bonus_probs,
                                                  bonus_token_ids,
                                                  draft_probs,
                                                  draft_token_ids, 1)
    else:
        out, index = torch.vacc.rejection_sampler(target_with_bonus_probs,
                                                  bonus_token_ids,
                                                  draft_probs,
                                                  draft_token_ids, 0,
                                                  seeded_seqs[0])
    return out

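# Assumed semantics of torch.vacc.rejection_sampler (standard speculative
# decoding): draft token t is accepted with probability
# min(1, p_target(t) / p_draft(t)); on rejection a replacement is drawn from
# the normalized residual max(0, p_target - p_draft), and the bonus token is
# emitted only when every draft token in the sequence is accepted.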
class Sampler(nn.Module):

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(
            sampling_tensors.temperatures.unsqueeze(dim=1).to(logits.device))

        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

def _apply_top_k_top_p_vacc(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1).to(probs_sum.device)
    # Keep at least one token.
    top_p_mask[:, -1] = False
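    # With the ascending sort, probs_sum accumulates from the least likely
    # token, so entries with cumulative mass <= 1 - p are exactly those
    # outside the top-p nucleus; forcing the last (most likely) column to
    # False guarantees at least one token survives.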
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = torch.empty_like(logits_sort).scatter_(dim=-1,
                                                    index=logits_idx,
                                                    src=logits_sort)
    return logits
@@ -0,0 +1,69 @@
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.platforms import current_platform
from typing import Tuple


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # torch.compile will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid_offset = (org_vocab_start_index *
                    org_vocab_mask) + (added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    input_ = vocab_mask * (input_ - valid_offset)
    return input_, ~vocab_mask

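# Worked example (illustrative values): with an org-vocab shard [0, 4), no
# padding, and an added-vocab shard [100, 102) mapped right after it, input
# [2, 7, 100] becomes masked_input [2, 0, 4] with mask [False, True, False];
# out-of-shard rows are zeroed after the lookup and recovered by the
# all-reduce across ranks.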
def VocabParallelEmbedding_forward(self, input_):

    try:
        if self.tp_size > 1:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
            parallel_embedding_output = None
            if memory_recycler is not None:
                if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == input_.size(0):
                    parallel_embedding_output = \
                        memory_recycler.EMBEDDING_OUT_BUFFER.to(self.weight.dtype)

            output_parallel = torch.vacc.parallel_embedding(
                input_,
                self.weight,
                self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index,
                output=parallel_embedding_output)
        else:
            raise ValueError("non-TP is not supported")
    except Exception:
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.quant_method.embedding(self,
                                                      masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)

    # TODO: fuse all_reduce
    return tensor_model_parallel_all_reduce(output_parallel)
0
vllm_vacc/vllm/model_executor/models/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
133
vllm_vacc/vllm/model_executor/models/bert.py
Normal file
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from torch import nn

from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)

from vllm.forward_context import ForwardContext, get_forward_context

from .vars import *


class BertLayer(nn.Module):

    def forward(self, hidden_states: torch.Tensor):
        if USE_FUSED_BERT_ATTENTION:
            tp_group = get_tp_group()
            world_size = tp_group.world_size
            rank = tp_group.rank_in_group
            total_bytes = (hidden_states.numel() *
                           hidden_states.element_size() * world_size)

            forward_context: ForwardContext = get_forward_context()
            attn_metadata_all = forward_context.attn_metadata

            if isinstance(attn_metadata_all, dict):
                attn_metadata = next(iter(attn_metadata_all.values()))
            else:
                attn_metadata = attn_metadata_all

            # For (matmul + bias_add) layers under TP followed by an
            # all_reduce, the bias must not be added once per rank, so only
            # rank 0 is given a bias for the bias_add. BertSelfOutput and
            # BertOutput inside the BERT layer have this structure; the
            # corresponding bias arguments below are self_bias and
            # output_bias.
            if total_bytes < 4194304 or world_size == 1:
                # 1. Under TP, when the all_reduce input is smaller than 4MB,
                #    the fused op below performs the all_reduce on the DSP;
                #    at 4MB and above, a vccl all_reduce has to be issued
                #    outside the op because of that limit.
                # 2. The non-TP case also goes through the fused op below.
                output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=self.attention.self.qkv_proj.weight,
                    qkv_bias=self.attention.self.qkv_proj.bias,
                    self_weight=self.attention.output.dense.weight,
                    self_bias=self.attention.output.dense.bias
                    if rank == 0 else torch.Tensor(),
                    self_norm_weight=self.attention.output.LayerNorm.weight,
                    self_norm_bias=self.attention.output.LayerNorm.bias,
                    intermediate_weight=self.intermediate.dense.weight,
                    intermediate_bias=self.intermediate.dense.bias,
                    output_weight=self.output.dense.weight,
                    output_bias=self.output.dense.bias
                    if rank == 0 else torch.Tensor(),
                    output_norm_weight=self.output.LayerNorm.weight,
                    output_norm_bias=self.output.LayerNorm.bias,
                    dense_out=torch.Tensor(),
                    seqs=attn_metadata.seq_lens,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.FullStage,
                    sm_scale=self.attention.self.scaling,
                    num_q_heads=self.attention.self.num_heads * world_size,
                    num_kv_heads=self.attention.self.num_kv_heads * world_size,
                    flash_attention=False,
                    reduce_result=world_size > 1,
                    world_size=world_size,
                    rank=rank,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
            else:
                attn_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=self.attention.self.qkv_proj.weight,
                    qkv_bias=self.attention.self.qkv_proj.bias,
                    self_weight=self.attention.output.dense.weight,
                    self_bias=self.attention.output.dense.bias
                    if rank == 0 else torch.Tensor(),
                    self_norm_weight=torch.Tensor(),
                    self_norm_bias=torch.Tensor(),
                    intermediate_weight=torch.Tensor(),
                    intermediate_bias=torch.Tensor(),
                    output_weight=torch.Tensor(),
                    output_bias=torch.Tensor(),
                    output_norm_weight=torch.Tensor(),
                    output_norm_bias=torch.Tensor(),
                    dense_out=torch.Tensor(),
                    seqs=attn_metadata.seq_lens,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.AttnOutStage,
                    sm_scale=self.attention.self.scaling,
                    num_q_heads=self.attention.self.num_heads * world_size,
                    num_kv_heads=self.attention.self.num_kv_heads * world_size,
                    flash_attention=False,
                    reduce_result=False,
                    world_size=world_size,
                    rank=rank,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
                if world_size > 1:
                    attn_out_stage_output = tensor_model_parallel_all_reduce(
                        attn_out_stage_output)
                if USE_FUSED_MLP_VISION:
                    attn_output = self.attention.output.LayerNorm(
                        attn_out_stage_output + hidden_states)
                    inter_out_stage_output = torch.vacc.fuse_mlp_vision(
                        src=attn_output,
                        weights_13=self.intermediate.dense.weight,
                        weights_2=self.output.dense.weight,
                        weights_13_bias=self.intermediate.dense.bias,
                        weights_2_bias=self.output.dense.bias
                        if rank == 0 else torch.Tensor(),
                        act_type=0  # gelu
                    )
                else:
                    inter_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                        hidden_states=hidden_states,
                        qkv_weight=torch.Tensor(),
                        qkv_bias=torch.Tensor(),
                        self_weight=torch.Tensor(),
                        self_bias=torch.Tensor(),
                        self_norm_weight=self.attention.output.LayerNorm.weight,
                        self_norm_bias=self.attention.output.LayerNorm.bias,
                        intermediate_weight=self.intermediate.dense.weight,
                        intermediate_bias=self.intermediate.dense.bias,
                        output_weight=self.output.dense.weight,
                        output_bias=self.output.dense.bias
                        if rank == 0 else torch.Tensor(),
                        output_norm_weight=torch.Tensor(),
                        output_norm_bias=torch.Tensor(),
                        dense_out=attn_out_stage_output,
                        seqs=attn_metadata.seq_lens,
                        vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.InterOutStage,
                        sm_scale=self.attention.self.scaling,
                        num_q_heads=self.attention.self.num_heads * world_size,
                        num_kv_heads=self.attention.self.num_kv_heads * world_size,
                        flash_attention=False,
                        reduce_result=False,
                        world_size=world_size,
                        rank=rank,
                        group_id=tp_group.group_id,
                        dev_info=tp_group.rank_device_infos)
                if world_size > 1:
                    inter_out_stage_output = tensor_model_parallel_all_reduce(
                        inter_out_stage_output)
                if USE_FUSED_MLP_VISION:
                    output = self.output.LayerNorm(inter_out_stage_output +
                                                   attn_output)
                else:
                    output = self.output.LayerNorm(inter_out_stage_output +
                                                   attn_out_stage_output)
        else:
            attn_output = self.attention(hidden_states)
            intermediate_output = self.intermediate(attn_output)
            output = self.output(intermediate_output, attn_output)
        return output
292
vllm_vacc/vllm/model_executor/models/deepseek_mtp.py
Normal file
@@ -0,0 +1,292 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple, Optional

import torch
import torch.nn as nn
from vllm.model_executor.layers.fused_moe import FusedMoE

from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.forward_context import ForwardContext, get_forward_context

from .vars import *

from transformers import PretrainedConfig

from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictorLayer as DeepSeekMultiTokenPredictorLayerOrig

from vllm.distributed import get_tp_group

from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer


def DeepSeekMultiTokenPredictorLayer__init__(self, vllm_config: VllmConfig,
                                             prefix: str) -> None:
    super(DeepSeekMultiTokenPredictorLayerOrig, self).__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    self.embed_tokens = VocabParallelEmbedding(
        config.vocab_size,
        config.hidden_size,
    )

    self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    if USE_PARALLEL_MTP_EH_PROJ:
        self.eh_proj = RowParallelLinear(config.hidden_size * 2,
                                         config.hidden_size,
                                         bias=False,
                                         return_bias=False)
    else:
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)

    from vllm.model_executor.models.deepseek_mtp import SharedHead
    self.is_v32 = hasattr(config, "index_topk")
    if self.is_v32:
        topk_tokens = config.index_topk
        topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            topk_tokens,
            dtype=torch.int32,
            device="cuda")
    else:
        topk_indices_buffer = None
    self.shared_head = SharedHead(config=config, quant_config=quant_config)
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                            topk_indices_buffer)

class DeepSeekMultiTokenPredictorLayer(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))

        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekMTPWegitCapture
            self.weight_capture = DeepseekMTPWegitCapture(self.mtp_block)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None

        if inputs_embeds.shape[0] > 256:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
                memory_recycler, DeepseekMTPMemoryRecycler)
            deepseek_mtp_layer_input_buffer = None
            if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
                deepseek_mtp_layer_input_buffer = \
                    memory_recycler.DEEPSEEK_MTP_LAYER_INPUT

            from torch_vacc.vacc.custom_ops import fuse_mtp_stage0
            hidden_states = fuse_mtp_stage0(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output=deepseek_mtp_layer_input_buffer,
            )

            # if USE_PARALLEL_MTP_EH_PROJ:
            #     tp_size = get_tensor_model_parallel_world_size()
            #     rank_id = get_tensor_model_parallel_rank()
            #     last_dim = hidden_states.shape[-1]
            #     if tp_size > 1:
            #         hiddens_tp = last_dim // tp_size
            #         hidden_states = hidden_states[
            #             ..., rank_id * hiddens_tp:(rank_id + 1) * hiddens_tp]

            hidden_states = self.eh_proj(hidden_states)
        else:
            hidden_states = torch.vacc.fuse_mtp_allreduce(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.eh_proj.weight,
                self.enorm.variance_epsilon,
                world_size=self.weight_capture.layer_moe.dist_args._0_world_size,
                rank=self.weight_capture.layer_moe.dist_args._1_rank,
                group_id=self.weight_capture.layer_moe.dist_args._2_group_id,
                dev_info=self.weight_capture.layer_moe.dist_args._3_dev_info)

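        # MTP stage-0 sketch (DeepSeek-V3-style, inferred from the arguments
        # above): both fused paths compute
        #   eh_proj(concat(RMSNorm(inputs_embeds), RMSNorm(previous_hidden)))
        # fuse_mtp_stage0 stops before eh_proj (applied separately above),
        # while fuse_mtp_allreduce also folds in eh_proj.weight and the
        # all-reduce.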
        if attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE:
            hidden_states, residual = self.mtp_block(positions=positions,
                                                     hidden_states=hidden_states,
                                                     residual=None)
        else:
            from torch_vacc.vacc.custom_ops import fuse_mla_moe_v2_allreduce_decode
            layer = self.mtp_block
            layer_id = 0

            kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
            positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
            cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
            sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]

            # For the MTP layer, residual starts as None, and the fused op must
            # return a residual.
            hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                hidden_states=hidden_states,
                residual=None,
                hidden_states_norm_weight=self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[layer_id],
                q_a_proj_weight=self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[layer_id],
                q_a_proj_weight_scale_inv=self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[layer_id],
                q_a_layernorm_weight=self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[layer_id],
                w_q=self.weight_capture.layer_moe.attn_args._3_W_Q[layer_id],
                w_q_scale=self.weight_capture.layer_moe.attn_args._4_W_Q_scales[layer_id],
                w_uk=self.weight_capture.layer_moe.attn_args._5_W_UK[layer_id],
                w_uk_scale=self.weight_capture.layer_moe.attn_args._6_W_UK_scales[layer_id],
                w_qr=self.weight_capture.layer_moe.attn_args._7_W_QR[layer_id],
                w_qr_scale=self.weight_capture.layer_moe.attn_args._8_W_QR_scales[layer_id],
                kv_a_layernorm_weight=self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[layer_id],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.decode_metadata.block_tables,
                block_group_size=self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
                w_uv=self.weight_capture.layer_moe.attn_args._16_W_UV[layer_id],
                w_uv_scale=self.weight_capture.layer_moe.attn_args._17_W_UV_scales[layer_id],
                o_proj_weight=self.weight_capture.layer_moe.attn_args._18_o_proj_weight[layer_id],
                o_proj_weight_scale_inv=self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[layer_id],
                # mla params
                seq_lens=attn_metadata.decode_metadata.seq_lens,
                sm_scale=self.weight_capture.layer_moe.attn_args._21_sm_scale,
                head_num=self.weight_capture.layer_moe.attn_args._22_head_num,
                # flash attention
                flash_attention=(USE_FLASH_ATTENTION == 1),
                # moe weight
                rms_weight=self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[layer_id],
                mlp_weight_13=self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[layer_id],
                mlp_weight_2=self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[layer_id],
                mlp_weight_scale_13=self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[layer_id],
                mlp_weight_scale_2=self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[layer_id],
                moe_weight_13=self.weight_capture.layer_moe.moe_args._5_moe_w13[layer_id],
                moe_weight_2=self.weight_capture.layer_moe.moe_args._6_moe_w2[layer_id],
                moe_weight_scale_13=self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[layer_id],
                moe_weight_scale_2=self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[layer_id],
                mm_weight=self.weight_capture.layer_moe.moe_args._9_gate_weight[layer_id],
                moe_bias=self.weight_capture.layer_moe.moe_args._10_moe_bias[layer_id],
                # moe params
                mlp_block_size_w13=self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
                mlp_block_size_w2=self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
                moe_block_size_w13=self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
                moe_block_size_w2=self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
                # vccl info
                world_size=self.weight_capture.layer_moe.dist_args._0_world_size,
                rank=self.weight_capture.layer_moe.dist_args._1_rank,
                group_id=self.weight_capture.layer_moe.dist_args._2_group_id,
                dev_info=self.weight_capture.layer_moe.dist_args._3_dev_info)

        # hidden_states = residual + hidden_states
        hidden_states = residual.add_(hidden_states)
        return hidden_states
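# An unfused reference for the fused MTP pre-projection ops used above, kept
# purely for documentation: fuse_mtp_stage0 / fuse_mtp_allreduce are assumed
# to RMS-normalize the token embedding and the previous hidden state,
# concatenate the two, and hand the result to eh_proj (the allreduce variant
# also folds in the eh_proj matmul and the TP all-reduce). Names below are
# illustrative sketches, not torch_vacc APIs.
def _mtp_stage0_reference(inputs_embeds: torch.Tensor,
                          previous_hidden_states: torch.Tensor,
                          enorm_weight: torch.Tensor,
                          hnorm_weight: torch.Tensor,
                          eps: float) -> torch.Tensor:
    def _rms_norm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * w
    # Result shape: [num_tokens, 2 * hidden_size]; the caller applies eh_proj.
    return torch.cat([_rms_norm(inputs_embeds, enorm_weight),
                      _rms_norm(previous_hidden_states, hnorm_weight)], dim=-1)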
class DeepSeekMTP(nn.Module):

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            from vllm.model_executor.models.utils import PPMissingLayer
            for layer_id in self.model.layers:
                layer = self.model.layers[layer_id]
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.mtp_block.self_attn.merge_qkv_weights()
        return loaded_params
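# Note: the weight-loading loops above rely on Python's for/else semantics:
# the else branch runs only when the loop finishes without hitting `break`,
# i.e. when no stacked/expert mapping matched the weight name. A minimal
# standalone illustration (names are for exposition only):
def _loader_kind(name: str, mappings: Tuple[str, ...]) -> str:
    for m in mappings:
        if m in name:
            break  # a mapping matched; its dedicated loader was used
    else:
        return "default"  # no break occurred: fall back to default_weight_loader
    return "mapped"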
class SharedHead(nn.Module):

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        try:
            from torch_vacc.vacc.custom_ops import rms_norm
            # In-place fused RMSNorm: output reuses the input buffer.
            return rms_norm(hidden_states, self.norm.weight, output=hidden_states)
        except Exception as e:
            print(f"fused rms_norm failed, falling back to the unfused op: {e}")

        return self.norm(hidden_states)
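# The try/except above follows the fallback convention used throughout this
# package: attempt the fused vacc kernel first and, on any failure, fall back
# to the equivalent unfused module. A minimal sketch of the pattern
# (illustrative helper, not part of the original file):
def _fused_or_fallback(fused_fn, fallback_fn, *args):
    try:
        return fused_fn(*args)
    except Exception as e:  # deliberate broad catch: any kernel failure falls back
        print(f"fused op failed, using fallback: {e}")
        return fallback_fn(*args)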
658
vllm_vacc/vllm/model_executor/models/deepseek_v2.py
Normal file
@@ -0,0 +1,658 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tp_group,
                              get_tensor_model_parallel_world_size,
                              get_tensor_model_parallel_rank,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter,
                                              make_empty_intermediate_tensors_factory,
                                              make_layers, maybe_prefix)

from vllm.model_executor.models.deepseek_v2 import yarn_get_mscale, DeepseekV2MLAAttention, Indexer

from vllm.logger import init_logger
logger = init_logger(__name__)

from .vars import *
from ..ops.deepseek_fused_mlp_moe import (vacc_fused_decode_moe_fp8,
                                          vacc_fused_prefill_moe_fp8,
                                          vacc_fused_mlp_fp8)

from .fused_forward import *
import os
test_layer_en = os.getenv("test_layer_en", "0")
# class DeepseekV2MLAAttention(nn.Module):
#     def __init__(
#         self,
#         vllm_config: VllmConfig,
#         config: Union[DeepseekV2Config, DeepseekV3Config],
#         hidden_size: int,
#         num_heads: int,
#         qk_nope_head_dim: int,
#         qk_rope_head_dim: int,
#         v_head_dim: int,
#         q_lora_rank: Optional[int],
#         kv_lora_rank: int,
#         rope_theta: float = 10000,
#         rope_scaling: Optional[dict[str, Any]] = None,
#         max_position_embeddings: int = 8192,
#         cache_config: Optional[CacheConfig] = None,
#         quant_config: Optional[QuantizationConfig] = None,
#         prefix: str = "",
#         topk_indices_buffer: Optional[torch.Tensor] = None,
#     ) -> None:
#         super(DeepseekV2MLAAttention, self).__init__()
#         self.hidden_size = hidden_size
#         self.qk_nope_head_dim = qk_nope_head_dim
#         self.qk_rope_head_dim = qk_rope_head_dim
#         self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
#         self.v_head_dim = v_head_dim
#
#         self.q_lora_rank = q_lora_rank
#         self.kv_lora_rank = kv_lora_rank
#
#         self.num_heads = num_heads
#         tp_size = get_tensor_model_parallel_world_size()
#         assert num_heads % tp_size == 0
#         self.num_local_heads = num_heads // tp_size
#
#         self.scaling = self.qk_head_dim**-0.5
#         self.rope_theta = rope_theta
#         self.max_position_embeddings = max_position_embeddings
#
#         if self.q_lora_rank is not None:
#             if USE_PARALLEL_Q_KV_GEN:
#                 self.q_a_proj = RowParallelLinear(self.hidden_size,
#                                                   self.q_lora_rank,
#                                                   bias=False,
#                                                   quant_config=quant_config,
#                                                   prefix=f"{prefix}.q_a_proj")
#             else:
#                 self.q_a_proj = ReplicatedLinear(self.hidden_size,
#                                                  self.q_lora_rank,
#                                                  bias=False,
#                                                  quant_config=quant_config,
#                                                  prefix=f"{prefix}.q_a_proj")
#             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
#                                          eps=config.rms_norm_eps)
#             self.q_b_proj = ColumnParallelLinear(q_lora_rank,
#                                                  self.num_heads *
#                                                  self.qk_head_dim,
#                                                  bias=False,
#                                                  quant_config=quant_config,
#                                                  prefix=f"{prefix}.q_b_proj")
#         else:
#             self.q_proj = ColumnParallelLinear(self.hidden_size,
#                                                self.num_heads *
#                                                self.qk_head_dim,
#                                                bias=False,
#                                                quant_config=quant_config,
#                                                prefix=f"{prefix}.q_proj")
#         if USE_PARALLEL_Q_KV_GEN:
#             self.kv_a_proj_with_mqa = RowParallelLinear(
#                 self.hidden_size,
#                 self.kv_lora_rank + self.qk_rope_head_dim,
#                 bias=False,
#                 quant_config=quant_config,
#                 prefix=f"{prefix}.kv_a_proj_with_mqa")
#         else:
#             self.kv_a_proj_with_mqa = ReplicatedLinear(
#                 self.hidden_size,
#                 self.kv_lora_rank + self.qk_rope_head_dim,
#                 bias=False,
#                 quant_config=quant_config,
#                 prefix=f"{prefix}.kv_a_proj_with_mqa")
#         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
#                                       eps=config.rms_norm_eps)
#         self.kv_b_proj = ColumnParallelLinear(
#             self.kv_lora_rank,
#             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
#             bias=False,
#             quant_config=quant_config,
#             prefix=f"{prefix}.kv_b_proj")
#         self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
#                                         self.hidden_size,
#                                         bias=False,
#                                         quant_config=quant_config,
#                                         prefix=f"{prefix}.o_proj")
#
#         rope_scaling["rope_type"] = 'deepseek_yarn'
#         self.rotary_emb = get_rope(qk_rope_head_dim,
#                                    rotary_dim=qk_rope_head_dim,
#                                    max_position=max_position_embeddings,
#                                    base=rope_theta,
#                                    rope_scaling=rope_scaling,
#                                    is_neox_style=False)
#         if rope_scaling:
#             mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
#             scaling_factor = rope_scaling["factor"]
#             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
#             self.scaling = self.scaling * mscale * mscale
#
#         self.is_v32 = hasattr(config, "index_topk")
#
#         if self.is_v32:
#             self.indexer = Indexer(vllm_config, config, hidden_size,
#                                    q_lora_rank, quant_config, cache_config,
#                                    topk_indices_buffer, f"{prefix}.indexer")
#         else:
#             self.indexer = None
#
#         self.mla_attn = Attention(
#             num_heads=self.num_local_heads,
#             head_size=self.kv_lora_rank,
#             scale=self.scaling,
#             num_kv_heads=1,
#             cache_config=cache_config,
#             quant_config=quant_config,
#             prefix=f"{prefix}.attn",
#             use_mla=True,
#             # MLA Args
#             q_lora_rank=self.q_lora_rank,
#             kv_lora_rank=self.kv_lora_rank,
#             qk_nope_head_dim=self.qk_nope_head_dim,
#             qk_rope_head_dim=self.qk_rope_head_dim,
#             qk_head_dim=self.qk_head_dim,
#             v_head_dim=self.v_head_dim,
#             rotary_emb=self.rotary_emb,
#             q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
#             kv_b_proj=self.kv_b_proj,
#             o_proj=self.o_proj,
#         )
#
#         self.prefix = prefix
#         self.debug_layer_idx = int(self.prefix.split(".")[-2])
#
#     def forward(
#         self,
#         positions: torch.Tensor,
#         hidden_states: torch.Tensor,
#         kv_cache: torch.Tensor,
#         attn_metadata: AttentionMetadata,
#     ) -> torch.Tensor:
#
#         tp_size = get_tensor_model_parallel_world_size()
#         rank_id = get_tensor_model_parallel_rank()
#         last_dim = hidden_states.shape[-1]
#
#         if USE_PARALLEL_Q_KV_GEN:  # tp qa and kva
#             hidden_states_split = hidden_states
#             if tp_size > 1:
#                 hiddens_tp = last_dim // tp_size
#                 hidden_states_split = hidden_states[..., rank_id * hiddens_tp:(rank_id + 1) * hiddens_tp].contiguous()
#             if self.q_lora_rank is not None:
#                 ckq = self.q_a_proj(hidden_states_split)[0]
#                 hidden_states_or_q_c = self.q_a_layernorm(ckq)
#             else:
#                 hidden_states_or_q_c = hidden_states
#             kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states_split)[0].split(
#                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
#             kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
#             return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
#                                  attn_metadata)
#
#         if self.q_lora_rank is not None:
#             ckq = self.q_a_proj(hidden_states)[0]
#             hidden_states_or_q_c = self.q_a_layernorm(ckq)
#         else:
#             hidden_states_or_q_c = hidden_states
#         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
#             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
#         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
#         return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
#                              attn_metadata)
class DeepseekV2MoE(nn.Module):

    def forward(self, hidden_states: torch.Tensor, residual=None, rms_norm=None):
        # The MoE layer supports fused vacc ops for both prefill and decode.
        if residual is not None:
            try:
                reduce_result = self.tp_size > 1
                # decode MoE path
                if self.is_decode:
                    hidden_states, residual = vacc_fused_decode_moe_fp8(self, self.shared_experts,
                                                                        hidden_states, residual,
                                                                        rms_norm, self.gate, self.experts,
                                                                        self.routed_scaling_factor,
                                                                        reduce_result)
                    return hidden_states, residual
                # prefill MoE path
                else:
                    hidden_states, residual = vacc_fused_prefill_moe_fp8(self, self.shared_experts,
                                                                         hidden_states, residual,
                                                                         rms_norm, self.gate, self.experts,
                                                                         self.routed_scaling_factor,
                                                                         reduce_result)
                    return hidden_states, residual
            except Exception as e:
                logger.warning("vacc fused moe failed, falling back to unfused ops: %s", e)
                hidden_states, residual = rms_norm(hidden_states, residual)

        self.experts.is_decode = self.is_decode

        # 1. fuse_prefill_pre_moe
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.n_shared_experts is not None:
            try:
                shared_output = vacc_fused_mlp_fp8(self.shared_experts, hidden_states, moe_share=True)
            except Exception as e:
                logger.warning("fused mlp failed, falling back to the default path: %s", e)
                shared_output = self.shared_experts(hidden_states)

        router_logits, _ = self.gate(hidden_states)

        # 2. fused_moe
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits)

        # 3. add_reduce
        # The shared-expert add is now fused into the experts kernel.
        # if shared_output is not None:
        #     # out = input + other * alpha
        #     final_hidden_states = shared_output.add_(final_hidden_states, alpha=self.routed_scaling_factor)
        # else:
        #     final_hidden_states = final_hidden_states * self.routed_scaling_factor

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        if residual is not None:
            return final_hidden_states.view(num_tokens, hidden_dim), residual
        return final_hidden_states.view(num_tokens, hidden_dim)
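# The commented-out combine step above documents what the fused experts kernel
# is assumed to absorb: add the shared-expert output and apply the routed
# scaling factor before the TP all-reduce. An unfused reference matching that
# commented logic (illustrative helper only):
def _combine_moe_outputs(routed_out: torch.Tensor,
                         shared_out: Optional[torch.Tensor],
                         routed_scaling_factor: float) -> torch.Tensor:
    if shared_out is not None:
        # out = shared + routed * alpha, computed in place on the shared buffer
        return shared_out.add_(routed_out, alpha=routed_scaling_factor)
    return routed_out * routed_scaling_factor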
class DeepseekV2MLP(nn.Module):

    def forward(self, x, residual=None, rms_norm=None):
        # Fully fused path: residual-add + norm + MLP in one op.
        if residual is not None:
            reduce_result = self.down_proj.reduce_results and self.down_proj.tp_size > 1
            hidden_states, residual = vacc_fused_mlp_fp8(self,
                                                         x, residual,
                                                         rms_norm,
                                                         reduce_result)
            return hidden_states, residual
        # Fused MLP with an unfused fallback.
        try:
            output_parallel = vacc_fused_mlp_fp8(self, x, residual, rms_norm)
            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
                x = tensor_model_parallel_all_reduce(output_parallel)
            else:
                x = output_parallel
        except Exception as e:
            logger.warning("fuse_mlp failed, falling back to the default path: %s", e)
            gate_up, _ = self.gate_up_proj(x)
            x = self.act_fn(gate_up)
            x, _ = self.down_proj(x)
        return x
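# vacc_fused_mlp_fp8 above is assumed to fuse the whole dense-MLP tail:
# optional residual-add + RMSNorm, the merged gate/up projection, SiLU-and-mul,
# and the down projection. An unfused reference matching the except branch
# (illustrative helper, mirrors the fallback exactly):
def _mlp_reference(gate_up_proj, act_fn, down_proj, x: torch.Tensor) -> torch.Tensor:
    gate_up, _ = gate_up_proj(x)   # merged gate and up projections
    x = act_fn(gate_up)            # SiluAndMul: silu(gate) * up
    x, _ = down_proj(x)
    return x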
class DeepseekV2Model(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
        first_k_dense_replace = getattr(self.config, "first_k_dense_replace", 3)

        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekWeightCapture
            self.weight_capture = DeepseekWeightCapture(self.layers, self.start_layer, self.end_layer)
            self.cached_weights_state = True
            self.cached_batch = 1
            self.layer_nums = self.end_layer - self.start_layer
            self.is_pipeline_first = get_pp_group().is_first_rank

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        if attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE:
            for i in range(self.start_layer, self.end_layer):
                layer = self.layers[i]
                hidden_states, residual = layer(positions, hidden_states, residual)
        else:
            # Update the global sequence lengths, used for serving info.
            # update_seqence_length(attn_metadata.decode_metadata.seq_lens)

            if FUSE_ALL_DECODER_LAYERS:
                self.weight_capture.update_attn_args(attn_metadata.decode_metadata.seq_lens,
                                                     attn_metadata.slot_mapping,
                                                     [self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(self.start_layer, first_k_dense_replace)],
                                                     [self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(first_k_dense_replace, self.end_layer)],
                                                     attn_metadata.decode_metadata.block_tables)

                hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 0)
                hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 1)
                hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 2)

                if hidden_states.shape[0] != self.cached_batch:
                    # The batch size changed, so the kernel-side weight cache
                    # must be rebuilt.
                    self.cached_weights_state = True
                    self.cached_batch = hidden_states.shape[0]

                if self.cached_weights_state:
                    self.cached_weights_state = False
                    hidden_states, residual = forward_mla_moe_layers_with_weights(hidden_states, residual, self.weight_capture)
                else:
                    hidden_states, residual = forward_mla_moe_layers_without_weights(hidden_states, residual, self.weight_capture)
            else:
                from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode, fuse_mla_moe_v2_allreduce_decode
                for i in range(0, self.layer_nums):
                    layer_id = i + self.start_layer
                    layer = self.layers[layer_id]
                    kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
                    positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
                    cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
                    sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]

                    if layer_id < first_k_dense_replace:
                        hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
                            hidden_states=hidden_states,
                            residual=residual,
                            hidden_states_norm_weight=self.weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[i],
                            q_a_proj_weight=self.weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[i],
                            q_a_proj_weight_scale_inv=self.weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[i],
                            q_a_layernorm_weight=self.weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[i],
                            w_q=self.weight_capture.layer_mlp.attn_args._3_W_Q[i],
                            w_q_scale=self.weight_capture.layer_mlp.attn_args._4_W_Q_scales[i],
                            w_uk=self.weight_capture.layer_mlp.attn_args._5_W_UK[i],
                            w_uk_scale=self.weight_capture.layer_mlp.attn_args._6_W_UK_scales[i],
                            w_qr=self.weight_capture.layer_mlp.attn_args._7_W_QR[i],
                            w_qr_scale=self.weight_capture.layer_mlp.attn_args._8_W_QR_scales[i],
                            kv_a_layernorm_weight=self.weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[i],
                            sin_cache=sin_cache,  # self.weight_capture.layer_mlp.attn_args._10_sin_cache
                            cos_cache=cos_cache,  # self.weight_capture.layer_mlp.attn_args._11_cos_cache
                            slot_mapping=attn_metadata.slot_mapping,  # self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i]
                            kv_cache=kv_cache,  # self.weight_capture.layer_mlp.attn_args._13_kv_cache[i]
                            block_tables=attn_metadata.decode_metadata.block_tables,  # self.weight_capture.layer_mlp.attn_args._14_block_tables[i]
                            block_group_size=self.weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
                            w_uv=self.weight_capture.layer_mlp.attn_args._16_W_UV[i],
                            w_uv_scale=self.weight_capture.layer_mlp.attn_args._17_W_UV_scales[i],
                            o_proj_weight=self.weight_capture.layer_mlp.attn_args._18_o_proj_weight[i],
                            o_proj_weight_scale_inv=self.weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[i],
                            # mla params
                            seq_lens=attn_metadata.decode_metadata.seq_lens,
                            sm_scale=self.weight_capture.layer_mlp.attn_args._21_sm_scale,
                            head_num=self.weight_capture.layer_mlp.attn_args._22_head_num,
                            # flash attention
                            flash_attention=(USE_FLASH_ATTENTION == 1),
                            # mlp weight
                            rms_weight=self.weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[i],
                            mlp_weight_13=self.weight_capture.layer_mlp.mlp_args._1_mlp_w13[i],
                            mlp_weight_2=self.weight_capture.layer_mlp.mlp_args._2_mlp_w2[i],
                            mlp_weight_scale_13=self.weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[i],
                            mlp_weight_scale_2=self.weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[i],
                            # mlp params
                            mlp_block_size_w13=self.weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
                            mlp_block_size_w2=self.weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
                            # vccl info
                            world_size=self.weight_capture.layer_mlp.dist_args._0_world_size,
                            rank=self.weight_capture.layer_mlp.dist_args._1_rank,
                            group_id=self.weight_capture.layer_mlp.dist_args._2_group_id,
                            dev_info=self.weight_capture.layer_mlp.dist_args._3_dev_info)
                    else:
                        wid = i - first_k_dense_replace if self.is_pipeline_first else i
                        hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                            hidden_states=hidden_states,
                            residual=residual,
                            hidden_states_norm_weight=self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[wid],
                            q_a_proj_weight=self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[wid],
                            q_a_proj_weight_scale_inv=self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[wid],
                            q_a_layernorm_weight=self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[wid],
                            w_q=self.weight_capture.layer_moe.attn_args._3_W_Q[wid],
                            w_q_scale=self.weight_capture.layer_moe.attn_args._4_W_Q_scales[wid],
                            w_uk=self.weight_capture.layer_moe.attn_args._5_W_UK[wid],
                            w_uk_scale=self.weight_capture.layer_moe.attn_args._6_W_UK_scales[wid],
                            w_qr=self.weight_capture.layer_moe.attn_args._7_W_QR[wid],
                            w_qr_scale=self.weight_capture.layer_moe.attn_args._8_W_QR_scales[wid],
                            kv_a_layernorm_weight=self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[wid],
                            sin_cache=sin_cache,  # self.weight_capture.layer_mlp.attn_args._10_sin_cache
                            cos_cache=cos_cache,  # self.weight_capture.layer_mlp.attn_args._11_cos_cache
                            slot_mapping=attn_metadata.slot_mapping,  # self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i]
                            kv_cache=kv_cache,  # self.weight_capture.layer_mlp.attn_args._13_kv_cache[i]
                            block_tables=attn_metadata.decode_metadata.block_tables,
                            block_group_size=self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
                            w_uv=self.weight_capture.layer_moe.attn_args._16_W_UV[wid],
                            w_uv_scale=self.weight_capture.layer_moe.attn_args._17_W_UV_scales[wid],
                            o_proj_weight=self.weight_capture.layer_moe.attn_args._18_o_proj_weight[wid],
                            o_proj_weight_scale_inv=self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[wid],
                            # mla params
                            seq_lens=attn_metadata.decode_metadata.seq_lens,
                            sm_scale=self.weight_capture.layer_moe.attn_args._21_sm_scale,
                            head_num=self.weight_capture.layer_moe.attn_args._22_head_num,
                            # flash attention
                            flash_attention=(USE_FLASH_ATTENTION == 1),
                            # moe weight
                            rms_weight=self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[wid],
                            mlp_weight_13=self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[wid],
                            mlp_weight_2=self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[wid],
                            mlp_weight_scale_13=self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[wid],
                            mlp_weight_scale_2=self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[wid],
                            moe_weight_13=self.weight_capture.layer_moe.moe_args._5_moe_w13[wid],
                            moe_weight_2=self.weight_capture.layer_moe.moe_args._6_moe_w2[wid],
                            moe_weight_scale_13=self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[wid],
                            moe_weight_scale_2=self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[wid],
                            mm_weight=self.weight_capture.layer_moe.moe_args._9_gate_weight[wid],
                            moe_bias=self.weight_capture.layer_moe.moe_args._10_moe_bias[wid],
                            # moe params
                            mlp_block_size_w13=self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
                            mlp_block_size_w2=self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
                            moe_block_size_w13=self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
                            moe_block_size_w2=self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
                            # vccl info
                            world_size=self.weight_capture.layer_moe.dist_args._0_world_size,
                            rank=self.weight_capture.layer_moe.dist_args._1_rank,
                            group_id=self.weight_capture.layer_moe.dist_args._2_group_id,
                            dev_info=self.weight_capture.layer_moe.dist_args._3_dev_info)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
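# The decode fast path above keeps a small weight-cache state machine: the
# full weight set is re-sent to the fused kernel (the *_with_weights variant)
# on the first decode step and whenever the batch size changes; otherwise the
# cached-weights variant is used. A minimal sketch of the decision
# (illustrative helper, not called by the model):
def _needs_weight_resend(cached_batch: int, batch: int, first_call: bool) -> bool:
    # Re-send weights on the first decode step and on any batch-size change;
    # otherwise the kernel may keep reusing its cached copies.
    return first_call or batch != cached_batch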
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        from .memory.memory_recycling import init_huge_memory_allocator
        from .vars import LLM_MAX_PREFILL_SEQ_LEN
        from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

        # Default is "deepseek"; the config can also select ['deepseek_mtp', ...].
        model_name = "deepseek"
        config_infos = vllm_vacc_config_manager().get_model_infos()
        if config_infos != "default":
            if config_infos in ['mtp']:
                model_name = "deepseek_mtp"
            else:
                model_name = config_infos

        if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
            logger.warning("failed to init the huge memory allocator; prefill memory recycling will be disabled")

        from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if test_layer_en == "1":
                test_layer = 5
                if name not in ['model.embed_tokens.weight', 'model.norm.weight', 'lm_head.weight']:
                    if int(name.split(".")[2]) > test_layer:
                        continue
            # TODO(simon): support nextn predict layers
            if hasattr(self.config, "num_nextn_predict_layers"
                       ) and self.config.num_nextn_predict_layers > 0:
                assert self.config.num_nextn_predict_layers == 1
                layer_idx = self.config.num_hidden_layers
                if name.startswith(f"model.layers.{layer_idx}"):
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for layer in self.model.layers:
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.self_attn.merge_qkv_weights()

        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
        if attn_metadata.prefill_metadata is not None:
            from .memory.memory_recycling import alloc_memory_recycler
            from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
            if hasattr(attn_metadata, 'num_prefill_tokens'):
                tokens = attn_metadata.num_prefill_tokens
            else:
                tokens = attn_metadata.prefill_metadata.num_prefill_tokens

            vllm_model_mode = "deepseek"
            config_infos = vllm_vacc_config_manager().get_model_infos()
            if config_infos != "default":
                if config_infos in ['mtp']:
                    vllm_model_mode = "deepseek_mtp"
                else:
                    vllm_model_mode = config_infos

            if get_tp_group().rank_in_group == 0:
                memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
                logger.info(memory_infos)

            if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size):
                logger.warning("deepseek memory recycler alloc failed; the current request may run inefficiently: %s", tokens)

        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states
1367
vllm_vacc/vllm/model_executor/models/deepseek_v2_fused.py
Normal file
File diff suppressed because it is too large
216
vllm_vacc/vllm/model_executor/models/fused_forward.py
Normal file
@@ -0,0 +1,216 @@
import torch
from torch_vacc.vacc.custom_ops import (fuse_mla_mlp_v2_allreduce_decode_layers,
                                        fuse_mla_moe_v2_allreduce_decode_layers,
                                        fuse_mla_moe_v2_allreduce_decode_layers_v2)

from typing import Optional

from .weight_capture.deepseek_weight_capture import DeepseekWeightCapture
import time

from .vars import *

# Single-layer MLA + MLP.
def forward_mla_mlp_single_layer(hidden_states: torch.Tensor,
                                 residual: Optional[torch.Tensor],
                                 weight_capture: DeepseekWeightCapture,
                                 layer_id: int):
    from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode

    hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
        hidden_states=hidden_states,
        residual=residual,
        hidden_states_norm_weight=weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[layer_id],
        q_a_proj_weight=weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[layer_id],
        q_a_proj_weight_scale_inv=weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[layer_id],
        q_a_layernorm_weight=weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[layer_id],
        w_q=weight_capture.layer_mlp.attn_args._3_W_Q[layer_id],
        w_q_scale=weight_capture.layer_mlp.attn_args._4_W_Q_scales[layer_id],
        w_uk=weight_capture.layer_mlp.attn_args._5_W_UK[layer_id],
        w_uk_scale=weight_capture.layer_mlp.attn_args._6_W_UK_scales[layer_id],
        w_qr=weight_capture.layer_mlp.attn_args._7_W_QR[layer_id],
        w_qr_scale=weight_capture.layer_mlp.attn_args._8_W_QR_scales[layer_id],
        kv_a_layernorm_weight=weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[layer_id],
        sin_cache=weight_capture.layer_mlp.attn_args._10_sin_cache,
        cos_cache=weight_capture.layer_mlp.attn_args._11_cos_cache,
        slot_mapping=weight_capture.layer_mlp.attn_args._12_slot_mapping,
        kv_cache=weight_capture.layer_mlp.attn_args._13_kv_cache[layer_id],
        block_tables=weight_capture.layer_mlp.attn_args._14_block_tables,
        block_group_size=weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
        w_uv=weight_capture.layer_mlp.attn_args._16_W_UV[layer_id],
        w_uv_scale=weight_capture.layer_mlp.attn_args._17_W_UV_scales[layer_id],
        o_proj_weight=weight_capture.layer_mlp.attn_args._18_o_proj_weight[layer_id],
        o_proj_weight_scale_inv=weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[layer_id],
        # mla params
        seq_lens=weight_capture.layer_mlp.attn_args._20_seq_lens,
        sm_scale=weight_capture.layer_mlp.attn_args._21_sm_scale,
        head_num=weight_capture.layer_mlp.attn_args._22_head_num,
        # flash attention
        flash_attention=(USE_FLASH_ATTENTION == 1),
        # mlp weight
        rms_weight=weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[layer_id],
        mlp_weight_13=weight_capture.layer_mlp.mlp_args._1_mlp_w13[layer_id],
        mlp_weight_2=weight_capture.layer_mlp.mlp_args._2_mlp_w2[layer_id],
        mlp_weight_scale_13=weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[layer_id],
        mlp_weight_scale_2=weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[layer_id],
        # mlp params
        mlp_block_size_w13=weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
        mlp_block_size_w2=weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
        # vccl info
        world_size=weight_capture.layer_mlp.dist_args._0_world_size,
        rank=weight_capture.layer_mlp.dist_args._1_rank,
        group_id=weight_capture.layer_mlp.dist_args._2_group_id,
        dev_info=weight_capture.layer_mlp.dist_args._3_dev_info)
    return hidden_states, residual

# Multi-layer MLA + MLP.
def forward_mla_mlp_layers(hidden_states: torch.Tensor,
                           residual: Optional[torch.Tensor],
                           weight_capture: DeepseekWeightCapture):
    if residual is None:
        residual = torch.zeros_like(hidden_states)
    hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode_layers(
        hidden_states=hidden_states,
        residual=residual,
        hidden_states_norm_weight=weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight,
        q_a_proj_weight=weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights,
        q_a_proj_weight_scale_inv=weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv,
        q_a_layernorm_weight=weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight,
        w_q=weight_capture.layer_mlp.attn_args._3_W_Q,
        w_q_scale=weight_capture.layer_mlp.attn_args._4_W_Q_scales,
        w_uk=weight_capture.layer_mlp.attn_args._5_W_UK,
        w_uk_scale=weight_capture.layer_mlp.attn_args._6_W_UK_scales,
        w_qr=weight_capture.layer_mlp.attn_args._7_W_QR,
        w_qr_scale=weight_capture.layer_mlp.attn_args._8_W_QR_scales,
        kv_a_layernorm_weight=weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight,
        sin_cache=weight_capture.layer_mlp.attn_args._10_sin_cache,
        cos_cache=weight_capture.layer_mlp.attn_args._11_cos_cache,
        slot_mapping=weight_capture.layer_mlp.attn_args._12_slot_mapping,
        kv_cache=weight_capture.layer_mlp.attn_args._13_kv_cache,
        block_tables=weight_capture.layer_mlp.attn_args._14_block_tables,
        block_group_size=weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
        w_uv=weight_capture.layer_mlp.attn_args._16_W_UV,
        w_uv_scale=weight_capture.layer_mlp.attn_args._17_W_UV_scales,
        o_proj_weight=weight_capture.layer_mlp.attn_args._18_o_proj_weight,
        o_proj_weight_scale_inv=weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv,
        # mla params
        seq_lens=weight_capture.layer_mlp.attn_args._20_seq_lens,
        sm_scale=weight_capture.layer_mlp.attn_args._21_sm_scale,
        head_num=weight_capture.layer_mlp.attn_args._22_head_num,
        # flash attention
        flash_attention=(USE_FLASH_ATTENTION == 1),
        # mlp weight
        rms_weight=weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight,
        mlp_weight_13=weight_capture.layer_mlp.mlp_args._1_mlp_w13,
        mlp_weight_2=weight_capture.layer_mlp.mlp_args._2_mlp_w2,
        mlp_weight_scale_13=weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale,
        mlp_weight_scale_2=weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale,
        # mlp params
        mlp_block_size_w13=weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
        mlp_block_size_w2=weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
        # vccl info
        world_size=weight_capture.layer_mlp.dist_args._0_world_size,
        rank=weight_capture.layer_mlp.dist_args._1_rank,
        group_id=weight_capture.layer_mlp.dist_args._2_group_id,
        dev_info=weight_capture.layer_mlp.dist_args._3_dev_info)
    return hidden_states, residual

# Multi-layer MLA + MoE; weights are not cached yet, so the full weight set is passed in.
def forward_mla_moe_layers_with_weights(hidden_states: torch.Tensor,
                                        residual: Optional[torch.Tensor],
                                        weight_capture: DeepseekWeightCapture):

    hidden_states, residual = fuse_mla_moe_v2_allreduce_decode_layers(
        hidden_states=hidden_states,
        residual=residual,
        hidden_states_norm_weight=weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight,
        q_a_proj_weight=weight_capture.layer_moe.attn_args._0_merge_q_kv_weights,
        q_a_proj_weight_scale_inv=weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv,
        q_a_layernorm_weight=weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight,
        w_q=weight_capture.layer_moe.attn_args._3_W_Q,
        w_q_scale=weight_capture.layer_moe.attn_args._4_W_Q_scales,
        w_uk=weight_capture.layer_moe.attn_args._5_W_UK,
        w_uk_scale=weight_capture.layer_moe.attn_args._6_W_UK_scales,
        w_qr=weight_capture.layer_moe.attn_args._7_W_QR,
        w_qr_scale=weight_capture.layer_moe.attn_args._8_W_QR_scales,
        kv_a_layernorm_weight=weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight,
        sin_cache=weight_capture.layer_moe.attn_args._10_sin_cache,
        cos_cache=weight_capture.layer_moe.attn_args._11_cos_cache,
        slot_mapping=weight_capture.layer_moe.attn_args._12_slot_mapping,
        kv_cache=weight_capture.layer_moe.attn_args._13_kv_cache,
        block_tables=weight_capture.layer_moe.attn_args._14_block_tables,
        block_group_size=weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
        w_uv=weight_capture.layer_moe.attn_args._16_W_UV,
        w_uv_scale=weight_capture.layer_moe.attn_args._17_W_UV_scales,
        o_proj_weight=weight_capture.layer_moe.attn_args._18_o_proj_weight,
        o_proj_weight_scale_inv=weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv,
        # mla params
        seq_lens=weight_capture.layer_moe.attn_args._20_seq_lens,
        sm_scale=weight_capture.layer_moe.attn_args._21_sm_scale,
        head_num=weight_capture.layer_moe.attn_args._22_head_num,
        # flash attention
        flash_attention=(USE_FLASH_ATTENTION == 1),
        # moe weight
        rms_weight=weight_capture.layer_moe.moe_args._0_moe_rms_weight,
        mlp_weight_13=weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13,
        mlp_weight_2=weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2,
        mlp_weight_scale_13=weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale,
        mlp_weight_scale_2=weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale,
        moe_weight_13=weight_capture.layer_moe.moe_args._5_moe_w13,
        moe_weight_2=weight_capture.layer_moe.moe_args._6_moe_w2,
        moe_weight_scale_13=weight_capture.layer_moe.moe_args._7_moe_w13_scale,
        moe_weight_scale_2=weight_capture.layer_moe.moe_args._8_moe_w2_scale,
        mm_weight=weight_capture.layer_moe.moe_args._9_gate_weight,
        moe_bias=weight_capture.layer_moe.moe_args._10_moe_bias,
        # moe params
        mlp_block_size_w13=weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
        mlp_block_size_w2=weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
        moe_block_size_w13=weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
        moe_block_size_w2=weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
        # vccl info
        world_size=weight_capture.layer_moe.dist_args._0_world_size,
        rank=weight_capture.layer_moe.dist_args._1_rank,
        group_id=weight_capture.layer_moe.dist_args._2_group_id,
        dev_info=weight_capture.layer_moe.dist_args._3_dev_info)

    return hidden_states, residual

# Multi-layer MLA + MoE using cached weights; may only be called after the
# weight-passing variant above has executed at least once.
def forward_mla_moe_layers_without_weights(hidden_states: torch.Tensor,
                                           residual: Optional[torch.Tensor],
                                           weight_capture: DeepseekWeightCapture):
    hidden_states, residual = fuse_mla_moe_v2_allreduce_decode_layers_v2(
        hidden_states=hidden_states,
        residual=residual,
        sin_cache=weight_capture.layer_moe.attn_args._10_sin_cache,
        cos_cache=weight_capture.layer_moe.attn_args._11_cos_cache,
        slot_mapping=weight_capture.layer_moe.attn_args._12_slot_mapping,
        kv_cache=weight_capture.layer_moe.attn_args._13_kv_cache,
        block_tables=weight_capture.layer_moe.attn_args._14_block_tables,
        block_group_size=weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
        # mla params
        seq_lens=weight_capture.layer_moe.attn_args._20_seq_lens,
        sm_scale=weight_capture.layer_moe.attn_args._21_sm_scale,
        head_num=weight_capture.layer_moe.attn_args._22_head_num,
        # flash attention
        flash_attention=(USE_FLASH_ATTENTION == 1),
        # moe weights are reused from the kernel-side cache
        # moe params
        mlp_block_size_w13=weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
        mlp_block_size_w2=weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
        moe_block_size_w13=weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
        moe_block_size_w2=weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
        # vccl info
        world_size=weight_capture.layer_moe.dist_args._0_world_size,
        rank=weight_capture.layer_moe.dist_args._1_rank,
        group_id=weight_capture.layer_moe.dist_args._2_group_id,
        dev_info=weight_capture.layer_moe.dist_args._3_dev_info)

    return hidden_states, residual

def forward_deepseekv3(model: torch.nn.Module,
                       hidden_states: torch.Tensor,
                       residual: Optional[torch.Tensor],
                       weight_capture: DeepseekWeightCapture):

    hidden_states, residual = forward_mla_mlp_layers(hidden_states, residual, weight_capture)
    hidden_states, residual = forward_mla_moe_layers_with_weights(hidden_states, residual, weight_capture)
    return hidden_states, residual
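# Illustrative first fused decode step (assumed usage, not called here):
# residual starts as None (forward_mla_mlp_layers zero-fills it) and the MoE
# layers run with the full weight set, satisfying the caching contract above.
def _decode_step_example(model: torch.nn.Module,
                         hidden_states: torch.Tensor,
                         weight_capture: DeepseekWeightCapture):
    return forward_deepseekv3(model, hidden_states, None, weight_capture)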
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,239 @@
import warnings

from transformers.utils import (
    CONFIG_NAME,
    IMAGE_PROCESSOR_NAME,
    cached_file,
    is_timm_config_dict,
    is_timm_local_checkpoint,
    is_torchvision_available,
    is_vision_available,
    logging,
)

from transformers.models.auto.configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)

from transformers.image_processing_utils import ImageProcessingMixin
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.image_processing_auto import (
    AutoImageProcessor,
    logger,
    FORCE_FAST_IMAGE_PROCESSOR,
    IMAGE_PROCESSOR_MAPPING_NAMES,
    IMAGE_PROCESSOR_MAPPING,
    get_image_processor_class_from_name,
    resolve_trust_remote_code,
    _warning_fast_image_processor_available,
    get_class_from_dynamic_module
)

def check_vacc_support_module(module_class):
    if module_class.__name__ == "Qwen2VLImageProcessorFast":
        from .qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
        return Qwen2VLImageProcessorFastWithVacc
    return module_class
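# check_vacc_support_module is the swap point for vacc-accelerated
# preprocessors; today it only redirects Qwen2VLImageProcessorFast. Intended
# use (illustrative; `config_dict` and `kwargs` come from the caller):
#
#     image_processor_class = check_vacc_support_module(image_processor_class)
#     processor = image_processor_class.from_dict(config_dict, **kwargs)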
"""AutoImageProcessor class."""
|
||||
class AutoImageProcessorWithVacc(AutoImageProcessor):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if kwargs.get("token") is not None:
|
||||
raise ValueError(
|
||||
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        # TODO: @yoni, change in v4.48 (use_fast set to True by default)
        use_fast = kwargs.pop("use_fast", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True

        # Resolve the image processor config filename
        if "image_processor_filename" in kwargs:
            image_processor_filename = kwargs.pop("image_processor_filename")
        elif is_timm_local_checkpoint(pretrained_model_name_or_path):
            image_processor_filename = CONFIG_NAME
        else:
            image_processor_filename = IMAGE_PROCESSOR_NAME

        # Load the image processor config
        try:
            # Main path for all transformers models and local TimmWrapper checkpoints
            config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
                pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
            )
        except Exception as initial_exception:
            # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
            # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
            # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
            # load `config.json` and if it fails with some error, we raise the initial exception.
            try:
                config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
                    pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
                )
            except Exception:
                raise initial_exception

            # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
            # because only timm models have image processing in `config.json`.
            if not is_timm_config_dict(config_dict):
                raise initial_exception

        image_processor_type = config_dict.get("image_processor_type", None)

        # Hook for switching to the vacc preprocessing operators (kept for reference)
        # if image_processor_type == "Qwen2VLImageProcessorFast":
        #     from .qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
        #     return Qwen2VLImageProcessorFastWithVacc.from_dict(config_dict, **kwargs)

        image_processor_auto_map = None
        if "AutoImageProcessor" in config_dict.get("auto_map", {}):
            image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]

        # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
        # and if so, infer the image processor class from there.
        if image_processor_type is None and image_processor_auto_map is None:
            feature_extractor_class = config_dict.pop("feature_extractor_type", None)
            if feature_extractor_class is not None:
                image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
            if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
                feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
                image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")

        # If we don't find the image processor class in the image processor config, let's try the model config.
        if image_processor_type is None and image_processor_auto_map is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    **kwargs,
                )
            # It could be in `config.image_processor_type`
            image_processor_type = getattr(config, "image_processor_type", None)
            if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
                image_processor_auto_map = config.auto_map["AutoImageProcessor"]

        image_processor_class = None
        # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
        if image_processor_type is not None:
            # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
            if use_fast is None:
                use_fast = image_processor_type.endswith("Fast")
                if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
                    use_fast = True
                    logger.warning_once(
                        f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
                        "This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
                        "Note that this behavior will be extended to all models in a future release."
                    )
                if not use_fast:
                    logger.warning_once(
                        "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
                        "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
                        "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
                    )
            if use_fast and not image_processor_type.endswith("Fast"):
                image_processor_type += "Fast"
            if use_fast and not is_torchvision_available():
                # check if there is a slow image processor class to fall back to
                image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
                if image_processor_class is None:
                    raise ValueError(
                        f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
                    )
                logger.warning_once(
                    "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
                )
                use_fast = False
            if use_fast:
                for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
                    if image_processor_type in image_processors:
                        break
                else:
                    image_processor_type = image_processor_type[:-4]
                    use_fast = False
                    logger.warning_once(
                        "`use_fast` is set to `True` but the image processor class does not have a fast version. "
                        "Falling back to the slow version."
                    )
                image_processor_class = get_image_processor_class_from_name(image_processor_type)
            else:
                image_processor_type_slow = image_processor_type.removesuffix("Fast")
                image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
                if image_processor_class is None and image_processor_type.endswith("Fast"):
                    raise ValueError(
                        f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
                    )

        has_remote_code = image_processor_auto_map is not None
        has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
        if has_remote_code:
            if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
                # In some configs, only the slow image processor class is stored
                image_processor_auto_map = (image_processor_auto_map, None)
            if use_fast and image_processor_auto_map[1] is not None:
                class_ref = image_processor_auto_map[1]
            else:
                class_ref = image_processor_auto_map[0]
            if "--" in class_ref:
                upstream_repo = class_ref.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
            )

        if has_remote_code and trust_remote_code:
            if not use_fast and image_processor_auto_map[1] is not None:
                _warning_fast_image_processor_available(image_processor_auto_map[1])

            image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
            _ = kwargs.pop("code_revision", None)
            image_processor_class.register_for_auto_class()
            # check whether the preprocess module is supported by vacc
            image_processor_class = check_vacc_support_module(image_processor_class)
            return image_processor_class.from_dict(config_dict, **kwargs)
        elif image_processor_class is not None:
            # check whether the preprocess module is supported by vacc
            image_processor_class = check_vacc_support_module(image_processor_class)
            return image_processor_class.from_dict(config_dict, **kwargs)
        # Last try: we use the IMAGE_PROCESSOR_MAPPING.
        elif type(config) in IMAGE_PROCESSOR_MAPPING:
            image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]

            image_processor_class_py, image_processor_class_fast = image_processor_tuple

            if not use_fast and image_processor_class_fast is not None:
                _warning_fast_image_processor_available(image_processor_class_fast)

            if image_processor_class_fast and (use_fast or image_processor_class_py is None):
                # check whether the preprocess module is supported by vacc
                image_processor_class_fast = check_vacc_support_module(image_processor_class_fast)
                return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if image_processor_class_py is not None:
                    # check whether the preprocess module is supported by vacc
                    image_processor_class_py = check_vacc_support_module(image_processor_class_py)
                    return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
                    )

        raise ValueError(
            f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have an "
            f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} or {CONFIG_NAME}, or one of the following "
            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
        )
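# Usage sketch (illustrative, not part of this commit): the vacc-aware auto
# class resolves a checkpoint exactly like AutoImageProcessor, then routes the
# resolved class through check_vacc_support_module before instantiation. The
# model id below is hypothetical.
#
#   processor = AutoImageProcessorWithVacc.from_pretrained(
#       "Qwen/Qwen2-VL-7B-Instruct", use_fast=True
#   )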
@@ -0,0 +1,402 @@
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Qwen2-VL."""

from typing import Optional, Union

import torch
from torchvision.transforms.v2 import functional as F

from transformers.image_processing_utils import BatchFeature
from transformers.image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
    TensorType,
    auto_docstring,
    logging,
)
from transformers.video_utils import VideoInput, make_batched_videos

from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import (
    Qwen2VLFastImageProcessorKwargs,
    logger,
)


# fused resize -> normalize -> repeat -> transpose -> reshape
def fuse_qwen2_vl_preprocess_img_cpu(
    image: "torch.Tensor",
    do_resize: bool,
    min_pixels: int,
    max_pixels: int,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    resized_height: int,
    resized_width: int,
    interpolation: Optional["F.InterpolationMode"],
    patch_size: int,
    temporal_patch_size: int,
    merge_size: int,
    image_mean_0: float,
    image_mean_1: float,
    image_mean_2: float,
    image_std_0: float,
    image_std_1: float,
    image_std_2: float,
    batch_size: int = 1,
    grid_t: int = 1,
    channel: int = 3,
) -> "torch.Tensor":
    def resize(
        image: "torch.Tensor",
        size_h: int,
        size_w: int,
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
    ) -> "torch.Tensor":
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if size_h and size_w:
            new_size = (size_h, size_w)
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                f" {size_h} and {size_w}."
            )
        return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)

    def fuse_mean_std_and_rescale_factor(
        image_mean_0: float,
        image_mean_1: float,
        image_mean_2: float,
        image_std_0: float,
        image_std_1: float,
        image_std_2: float,
        do_normalize: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        device: Optional["torch.device"] = None,
    ) -> tuple:
        image_mean, image_std = None, None
        if do_rescale and do_normalize:
            # Fold the rescale factor into mean/std so that rescale + normalize
            # collapse into a single normalize call.
            image_mean = torch.tensor([image_mean_0, image_mean_1, image_mean_2], device=device) * (1.0 / rescale_factor)
            image_std = torch.tensor([image_std_0, image_std_1, image_std_2], device=device) * (1.0 / rescale_factor)
        return image_mean, image_std

    def rescale_and_normalize(
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean_0: float,
        image_mean_1: float,
        image_mean_2: float,
        image_std_0: float,
        image_std_1: float,
        image_std_2: float,
    ) -> "torch.Tensor":
        image_mean, image_std = fuse_mean_std_and_rescale_factor(
            image_mean_0, image_mean_1, image_mean_2,
            image_std_0, image_std_1, image_std_2,
            do_normalize,
            do_rescale,
            rescale_factor,
            device=images.device,
        )
        if do_normalize:
            images = F.normalize(images.to(dtype=torch.float32), image_mean, image_std)
        return images

    if image.dim() == 3:
        image = image.unsqueeze(0)
    stacked_images = resize(
        image=image,
        size_h=resized_height,
        size_w=resized_width,
        interpolation=interpolation,  # BICUBIC interpolation
    )
    patches = rescale_and_normalize(
        stacked_images,
        do_rescale,
        rescale_factor,
        do_normalize,
        image_mean_0, image_mean_1, image_mean_2,
        image_std_0, image_std_1, image_std_2,
    )
    if patches.ndim == 4:
        patches = patches.unsqueeze(1)
    if patches.shape[1] % temporal_patch_size != 0:
        repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
        patches = torch.cat([patches, repeats], dim=1)

    grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
    patches = patches.view(
        batch_size,            # 1  - batch size
        grid_t,                # 1  - temporal grid count
        temporal_patch_size,   # 2  - temporal patch size
        channel,               # 3  - channels (RGB)
        grid_h // merge_size,  # 12 // 2 = 6  - merged grid count along height
        merge_size,            # 2  - merge size along height
        patch_size,            # 14 - patch size along height
        grid_w // merge_size,  # 38 // 2 = 19 - merged grid count along width
        merge_size,            # 2  - merge size along width
        patch_size,            # 14 - patch size along width
    )
    patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
    flatten_patches = patches.reshape(
        # batch_size,  # 1
        grid_t * grid_h * grid_w,  # 1 * 12 * 38 = 456 (total grid count)
        channel * temporal_patch_size * patch_size * patch_size,  # 3 * 2 * 14 * 14 = 1176 (feature dim per grid)
    )

    return flatten_patches

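# Worked shape example for the patchify step above (numbers match the inline
# comments; illustrative only): a 168x532 input with patch_size=14,
# merge_size=2, temporal_patch_size=2 gives grid_h=12, grid_w=38, so
#   view    -> (1, 1, 2, 3, 6, 2, 14, 19, 2, 14)
#   permute -> (1, 1, 6, 19, 2, 2, 3, 2, 14, 14)
#   reshape -> (456, 1176)  # 456 = 1*12*38 patches, 1176 = 3*2*14*14 features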
class Qwen2VLImageProcessorFastWithVacc(BaseImageProcessorFast):
    do_resize = True
    resample = PILImageResampling.BICUBIC
    size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
    do_rescale = True
    do_normalize = True
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD
    do_convert_rgb = True
    patch_size = 14
    temporal_patch_size = 2
    merge_size = 2
    min_pixels = None
    max_pixels = None
    valid_kwargs = Qwen2VLFastImageProcessorKwargs
    model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

    def __init__(self, **kwargs: Unpack[Qwen2VLFastImageProcessorKwargs]):
        size = kwargs.pop("size", None)
        min_pixels = kwargs.pop("min_pixels", None)
        max_pixels = kwargs.pop("max_pixels", None)
        # backward compatibility: override size with min_pixels and max_pixels if they are provided
        size = self.size if size is None else size
        if min_pixels is not None:
            size["shortest_edge"] = min_pixels
            size.pop("min_pixels", None)
        if max_pixels is not None:
            size["longest_edge"] = max_pixels
            size.pop("max_pixels", None)
        if "shortest_edge" not in size or "longest_edge" not in size:
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")

        super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated.
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        if min_pixels is not None and max_pixels is not None:
            size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
        elif size is not None:
            if "shortest_edge" not in size or "longest_edge" not in size:
                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
            min_pixels = size["shortest_edge"]
            max_pixels = size["longest_edge"]
        else:
            size = {**self.size}

        return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        videos: Optional[VideoInput] = None,
        **kwargs: Unpack[Qwen2VLFastImageProcessorKwargs],
    ) -> BatchFeature:
        return super().preprocess(images, videos, **kwargs)

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        videos: VideoInput,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Optional[Union[str, "torch.device"]] = None,
        **kwargs: Unpack[DefaultFastImageProcessorKwargs],
    ) -> BatchFeature:
        """
        Preprocess image-like inputs.
        To be overridden by subclasses when image-like inputs other than images should be processed.
        It can be used for segmentation maps, depth maps, etc.
        """
        # Prepare input images
        batch_feature = BatchFeature()
        if images is not None:
            images = self._prepare_image_like_inputs(
                images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
            )
            batch_feature = self._preprocess(images, **kwargs)
        if videos is not None:
            logger.warning(
                "`Qwen2VLImageProcessorFast` works only with image inputs and doesn't process videos anymore. "
                "This is a deprecated behavior and will be removed in v5.0. "
                "Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
            )
            # Can't change _prepare_images_structure to work with videos because it also needs to work with images.
            videos = make_batched_videos(videos)
            videos = [
                torch.stack(self._prepare_image_like_inputs(video, do_convert_rgb, input_data_format, device))
                for video in videos
            ]
            video_outputs = self._preprocess(videos, **kwargs)
            batch_feature.update(
                {"pixel_values_videos": video_outputs.pixel_values, "video_grid_thw": video_outputs.image_grid_thw}
            )
        return batch_feature

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ):
        min_pixels = size["shortest_edge"]
        max_pixels = size["longest_edge"]
        processed_images = []
        processed_grids = []
        for img in images:
            height, width = img.shape[-2:]
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=min_pixels,
                    max_pixels=max_pixels,
                )
            # resize_normalize_repeat = fuse_qwen2_vl_preprocess_img_cpu(
            #     img,
            #     do_resize,
            #     min_pixels,
            #     max_pixels,
            #     do_rescale,
            #     rescale_factor,
            #     do_normalize,
            #     resized_height,
            #     resized_width,
            #     interpolation,
            #     patch_size,
            #     temporal_patch_size,
            #     merge_size,
            #     image_mean[0], image_mean[1], image_mean[2],
            #     image_std[0], image_std[1], image_std[2],
            #     1, 1, 3,
            # )
            import torch_vacc
            if img.device.type != "vacc":
                img = img.to("vacc")
            resize_normalize_repeat = torch_vacc.vacc.custom_qwen3_ops.qwen2vl_img_preprocess(
                img,
                do_resize,
                min_pixels,
                max_pixels,
                do_rescale,
                rescale_factor,
                do_normalize,
                resized_height,
                resized_width,
                0x1003,  # interpolation
                patch_size,
                temporal_patch_size,
                merge_size,
                image_mean[0], image_mean[1], image_mean[2],
                image_std[0], image_std[1], image_std[2],
            )
            processed_images.append(resize_normalize_repeat)
            grid_t = 1
            grid_h = resized_height // patch_size
            grid_w = resized_width // patch_size
            grid_thw_ = torch.tensor([[grid_t, grid_h, grid_w]])
            processed_grids.append(grid_thw_)

        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.cat(processed_grids, dim=0)

        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
        )

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """
        A utility that returns number of image patches for a given image size.

        Note: Do not remove this method! It is used by vLLM to infer the number of patches and placeholders
        without an image input.

        Args:
            height (`int`):
                Height of the input image.
            width (`int`):
                Width of the input image.
            images_kwargs (`dict`, *optional*):
                Any kwargs to override defaults of the image processor.
        Returns:
            `int`: Number of image patches per image.
        """
        images_kwargs = images_kwargs or {}
        min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
        max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
        patch_size = images_kwargs.get("patch_size", self.patch_size)
        merge_size = images_kwargs.get("merge_size", self.merge_size)

        factor = patch_size * merge_size
        resized_height, resized_width = smart_resize(
            height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
        )
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
        return grid_h * grid_w


__all__ = ["Qwen2VLImageProcessorFastWithVacc"]
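# Worked example (illustrative): with the defaults patch_size=14, merge_size=2
# (factor 28), a 644x476 image is already a multiple of 28 and inside the
# pixel bounds, so smart_resize leaves it unchanged:
#   grid_h = 644 // 14 = 46, grid_w = 476 // 14 = 34
#   get_number_of_image_patches(644, 476) -> 46 * 34 = 1564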
@@ -0,0 +1,91 @@
from transformers.models.qwen3_vl import Qwen3VLProcessor
# from transformers.models.auto.image_processing_auto import AutoImageProcessor
from transformers.models.qwen2_vl import Qwen2VLProcessor


class Qwen3VLProcessorWithVacc(Qwen3VLProcessor):
    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Identify and instantiate the subcomponents of Processor classes, like image processors and
        tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
        subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
        the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
        via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
        will be unable to find the relevant subcomponent class and will raise an error.
        """
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            if isinstance(class_name, tuple):
                classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
                if attribute_name == "image_processor":
                    # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
                    use_fast = kwargs.get("use_fast")
                    if use_fast is None:
                        print(
                            "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
                            "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
                            "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
                        )
                else:
                    use_fast = kwargs.get("use_fast", True)
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                attribute_class = cls.get_possibly_dynamic_module(class_name)

            if attribute_class.__name__ == "AutoImageProcessor":
                from .auto_image_preprocessor import AutoImageProcessorWithVacc
                attribute_class = AutoImageProcessorWithVacc
            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))

        return args


class Qwen2VLProcessorWithVacc(Qwen2VLProcessor):
    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Identify and instantiate the subcomponents of Processor classes, like image processors and
        tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
        subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
        the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
        via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
        will be unable to find the relevant subcomponent class and will raise an error.
        """
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            if isinstance(class_name, tuple):
                classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
                if attribute_name == "image_processor":
                    # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
                    use_fast = kwargs.get("use_fast")
                    if use_fast is None:
                        print(
                            "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
                            "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
                            "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
                        )
                else:
                    use_fast = kwargs.get("use_fast", True)
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                attribute_class = cls.get_possibly_dynamic_module(class_name)

            if attribute_class.__name__ == "AutoImageProcessor":
                from .auto_image_preprocessor import AutoImageProcessorWithVacc
                attribute_class = AutoImageProcessorWithVacc
            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))

        return args
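# Usage sketch (illustrative): loading either processor routes its image
# processor through AutoImageProcessorWithVacc, so vacc-supported classes are
# substituted transparently. The model id below is hypothetical.
#
#   processor = Qwen2VLProcessorWithVacc.from_pretrained(
#       "Qwen/Qwen2-VL-7B-Instruct", use_fast=True
#   )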
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
235
vllm_vacc/vllm/model_executor/models/memory/allocator.py
Normal file
@@ -0,0 +1,235 @@
import torch
import torch.nn as nn
from typing import List


class VaccHugeMemoryAllocator(nn.Module):
    # self._active_bytes is the number of bytes the live tensor buffers actually use;
    # you can use this value to slice the src buffer.
    # self._src_buffer_array must not be freed, because the src buffer is the max buffer.
    # self._block_bytes is the max buffer size of one partition.
    def __init__(self, blocks, dtype=torch.bfloat16, use_contiguous=False):
        super().__init__()
        self._total_blocks = blocks
        self._dtype = dtype
        self._enable = False
        self._src_buffer_array = None
        self._block_bytes = 0
        self._max_tokens = 0
        self._hiddens = 0
        self._active_bytes = 0
        self._use_contiguous_buffer = use_contiguous
        # max tokens for the dynamic buffer size;
        # the dynamic buffer is usually bigger than the normal one
        self._dynamic_max_tokens = 0
        self._dynamic_block_bytes = 0

    # allocate the max buffer once; it is never freed
    def init_buffers(self, max_tokens, hiddens):
        self._max_tokens = max_tokens
        self._hiddens = hiddens

        try:
            import torch_vacc
            self._block_bytes = self._max_tokens * self._hiddens \
                * self.get_dtype_bytes(self._dtype)
            self._all_bytes = self._block_bytes * self._total_blocks

            if self._use_contiguous_buffer:
                # allocate one [N * blocks]-sized byte buffer in a single call
                self._src_buffer = torch.zeros(self._all_bytes,
                                               dtype=torch.uint8,
                                               device="vacc")
                tmp_buffer_array = self._src_buffer.view(self._total_blocks, -1)
                self._src_buffer_array = [tmp_buffer_array[i]
                                          for i in range(self._total_blocks)]
            else:
                # allocate `blocks` separate [N]-sized byte buffers
                self._src_buffer_array = [torch.zeros(self._block_bytes,
                                                      dtype=torch.uint8,
                                                      device="vacc")
                                          for i in range(self._total_blocks)]
            self._enable = True
        except Exception as e:
            print(f"vacc huge buffer alloc fail: {e}")

    # Designed for networks that need a dynamic buffer; the dynamic buffer may need to be
    # somewhat larger than the normal max_tokens buffers.
    def init_buffers_with_dynamic(self, max_tokens, dynamic_tokens, hiddens, dynamic_buffers_mask: List):
        self._max_tokens = max_tokens
        self._dynamic_max_tokens = dynamic_tokens
        self._hiddens = hiddens

        dynamic_buffers_count = sum(dynamic_buffers_mask)
        normal_buffers_count = self._total_blocks - dynamic_buffers_count

        self._block_bytes = self._max_tokens * self._hiddens \
            * self.get_dtype_bytes(self._dtype)
        self._dynamic_block_bytes = self._dynamic_max_tokens * self._hiddens \
            * self.get_dtype_bytes(self._dtype)

        self._all_bytes = self._block_bytes * normal_buffers_count + \
            self._dynamic_block_bytes * dynamic_buffers_count

        # print("creating reusable buffers: dynamic count ->", dynamic_buffers_count,
        #       " normal count ->", normal_buffers_count,
        #       " dynamic buffer bytes ->", self._dynamic_block_bytes,
        #       " normal buffer bytes ->", self._block_bytes)

        try:
            assert self._use_contiguous_buffer is False, "malloc dynamic recycle memory buffers only supports separate buffers for now"

            self._src_buffer_array = []
            for i in range(self._total_blocks):
                if dynamic_buffers_mask[i]:
                    self._src_buffer_array.append(torch.zeros(self._dynamic_block_bytes,
                                                              dtype=torch.uint8,
                                                              device="vacc"))
                else:
                    # allocate a separate [N]-sized byte buffer for this block
                    self._src_buffer_array.append(torch.zeros(self._block_bytes,
                                                              dtype=torch.uint8,
                                                              device="vacc"))
            self._enable = True
        except Exception as e:
            print("vacc huge buffer alloc fail.", e)

    # slice buffers from the src buffer (48K * blocks) for the target tensors.
    # you should work out the tensor positions from the model; e.g. deepseek has 4 buffers.
    # you need to alloc the buffer with the real input token count when a new request comes in.
    # note: mind the dtype, because the buffer is created as uint8.
    def alloc_memory_buffers(self, tokens, dtype=torch.bfloat16):
        if tokens > self._max_tokens:
            print("alloc memory buffer fail, tokens is larger than max_tokens.", self._max_tokens)
            return None

        self._active_bytes = tokens * self._hiddens * self.get_dtype_bytes(dtype)
        return [sub_array[:self._active_bytes]
                for sub_array in self._src_buffer_array]

    @property
    def memory_buffers(self):
        return self._src_buffer_array

    # alloc 1/N buffers
    # @param part                the total number of partitions
    # @param return_buffer_list  the list of partitions to return; if empty, return all of them
    # creates 1/N buffers and returns the requested parts
    def alloc_1_div_N_buffers(self, part=2,
                              return_buffer_list=[]):

        if not hasattr(self, "_src_buffer"):
            print("1 div N allocator needs a contiguous buffer")
            return None

        assert isinstance(return_buffer_list, list), "return_buffer_list needs a List object"

        # partition the data into `part` pieces, viewed as int8
        tmp_buffer_array = self._src_buffer.view(part, -1)
        # if return_buffer_list is not given, return all part buffers
        if len(return_buffer_list) == 0:
            return [tmp_buffer_array[i] for i in range(part)]

        return [tmp_buffer_array[i] for i in return_buffer_list]

    def get_dtype_bytes(self, dtype):
        if isinstance(dtype, torch.dtype):
            if dtype in [torch.float16, torch.bfloat16, torch.half]:
                return 2
            elif dtype in [torch.float32, torch.float, torch.int32]:
                return 4
            elif dtype in [torch.float64, torch.double, torch.int64]:
                return 8
            elif dtype in [torch.int8, torch.uint8, torch.bool]:
                return 1
            else:
                return 1
        elif dtype == int:
            return 8
        elif dtype == float:
            return 8
        elif dtype == bool:
            return 1
        return 0

    @property
    def enable(self):
        return self._enable

    @property
    def blocks(self):
        return self._total_blocks

    @property
    def max_tokens(self):
        return self._max_tokens

    @property
    def hiddens(self):
        return self._hiddens

    @property
    def active_bytes(self):
        return self._active_bytes

    @property
    def block_bytes(self):
        return self._block_bytes

    @property
    def dynamic_block_bytes(self):
        return self._dynamic_block_bytes


class LLMMemoryRecycler:
    def __init__(self):
        self.count = 3
        self.embedding_output = None
        self.moe_shared_mlp_output = None
        self.mla_oproj_output = None
        # self.moe_expert_output = None

    def clear(self):
        self.embedding_output = None
        self.moe_shared_mlp_output = None
        self.mla_oproj_output = None
        # self.moe_expert_output = None

    @property
    def EMBEDDING_OUT_BUFFER(self):
        return self.embedding_output

    @property
    def MOE_SHARED_MLP_OUT_BUFFER(self):
        return self.moe_shared_mlp_output

    @property
    def MLA_OPROJ_OUT_BUFFER(self):
        return self.mla_oproj_output


def alloc_memory_recycler_llm(tokens,
                              alloctor: VaccHugeMemoryAllocator,
                              recycler: LLMMemoryRecycler,
                              dtype: torch.dtype = torch.bfloat16):

    if not alloctor.enable:
        print("memory allocator is not initialized.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("llm memory recycler buffers alloc fail. now disable it")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match the llm.")
        return False

    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    # recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
    return True
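# Reuse sketch (illustrative): one uint8 arena per block is sliced by active
# bytes and re-viewed as the working dtype, so a single up-front allocation
# serves any request up to max_tokens.
#
#   alloc = VaccHugeMemoryAllocator(3)
#   alloc.init_buffers(max_tokens=48 * 1024, hiddens=7168)
#   bufs = alloc.alloc_memory_buffers(tokens=4096, dtype=torch.bfloat16)
#   hidden = bufs[0].view(torch.bfloat16).view(4096, alloc.hiddens)  # [4096, 7168]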
@@ -0,0 +1,116 @@
import torch
from .allocator import VaccHugeMemoryAllocator, LLMMemoryRecycler

'''
DeepSeek supports 48K input tokens; there are 3 buffers that can be recycled:
1. parallel_embedding output buffer
2. mla_oproj output buffer
3. moe_shared_mlp output buffer
# 4. moe_expert output buffer
each buffer size is 48K * 7168 * 2 bytes
'''


class DeepseekV3MemoryRecycler(LLMMemoryRecycler):
    def __init__(self):
        super().__init__()
        # self.moe_expert_output = None

    # @property
    # def MOE_EXPERT_OUT_BUFFER(self):
    #     return self.moe_expert_output


class DeepseekMTPMemoryRecycler(LLMMemoryRecycler):
    def __init__(self):
        super().__init__()
        self.dynamic_output = None
        self.deepseek_mtp_layer_input = None

    @property
    def DYNAMIC_OUTPUT_BUFFER(self):
        return self.dynamic_output

    @property
    def DEEPSEEK_MTP_LAYER_INPUT(self):
        return self.deepseek_mtp_layer_input


def alloc_memory_recycler_deepseek_v3(tokens,
                                      alloctor: VaccHugeMemoryAllocator,
                                      recycler: DeepseekV3MemoryRecycler,
                                      dtype: torch.dtype = torch.bfloat16):

    if not alloctor.enable:
        print("memory allocator is not initialized.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("deepseek_v3 memory recycler buffers alloc fail. now disable it")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match deepseek_v3.")
        return False

    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    # recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
    return True


def alloc_memory_recycler_deepseek_mtp(tokens,
                                       alloctor: VaccHugeMemoryAllocator,
                                       world_size: int,
                                       recycler: DeepseekMTPMemoryRecycler,
                                       dtype: torch.dtype = torch.bfloat16):

    if not alloctor.enable:
        print("memory allocator is not initialized.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("deepseek_mtp memory recycler buffers alloc fail. now disable it")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match deepseek_mtp.")
        return False
    # MTP memory layout:
    # 1. decoder layers of the DeepSeek main model and the draft model:
    #    a. embedding_output
    #    b. mla_oproj_output
    #    c. moe_shared_mlp_output
    #    Because the MoE output is reorganized once as previous_hidden_states, the reorganized
    #    buffer is placed in buffer 0 [a.]. Putting embedding_output first would also work, but
    #    the related addresses would all have to shift backwards; for readability the embedding
    #    is kept last and does not take part in buffer re-partitioning/reuse.
    # 2. deepseek_mtp draft model (this strategy is not enabled):
    #    a. dynamic_output takes 1/2
    #    b. mtp_input takes 1/6 and sits right after the dynamic_output buffer
    # dynamic_buffer = alloctor.alloc_1_div_N_buffers(2, [0,])
    # mtp_input_buffer = alloctor.alloc_1_div_N_buffers(6, [3,])
    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)

    # 1/2 is freely allocated for broadcast: MTP broadcasts previous_hidden_states
    memory_buffers = alloctor.memory_buffers
    recycler.dynamic_output = memory_buffers[1]  # shared with mla_oproj
    # 1/6 is used for the MTP decoder layer input; that op is special-cased and only needs 1/tp of the output buffer
    # recycler.deepseek_mtp_layer_input = mtp_input_buffer[0]

    mtp_layer_input_dims = alloctor.hiddens * 2 // world_size
    mtp_layer_input_numels = tokens * mtp_layer_input_dims

    from vllm import envs
    if envs.VLLM_USE_V1:
        # v1 does not re-cache previous_hidden_states, so the MoE buffer is still occupied;
        # use the attention o-buffer to stage the MTP preprocessing instead
        recycler.deepseek_mtp_layer_input = memory_buffers[1].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
    else:
        # shared with the MoE output
        recycler.deepseek_mtp_layer_input = memory_buffers[2].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
    return True
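# Worked sizing example (illustrative): with hiddens=7168 and world_size=8,
#   mtp_layer_input_dims   = 7168 * 2 // 8 = 1792
#   mtp_layer_input_numels = tokens * 1792
# so for 4096 tokens the staged MTP input view is [4096, 1792] bf16, i.e.
# 4096 * 1792 * 2 bytes carved out of an already-allocated block.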
132
vllm_vacc/vllm/model_executor/models/memory/memory_recycling.py
Normal file
@@ -0,0 +1,132 @@
import os
import torch
from .allocator import VaccHugeMemoryAllocator
from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler

VLLM_MODEL_MODE = os.environ.get("VLLM_MODEL_MODE", "deepseek")

global huge_memory_alloctor
global memory_recycler
huge_memory_alloctor = None
memory_recycler = None


# you should call this function when a new request comes in
def alloc_memory_recycler(tokens, dtype=torch.bfloat16, **argv):
    global huge_memory_alloctor
    global memory_recycler

    vllm_model = argv.get('vllm_model')
    if vllm_model is None:
        print("model info is empty, now using VLLM_MODEL_MODE")
        vllm_model = VLLM_MODEL_MODE

    # TODO: use the default memory-recycle schedule
    # if vllm_model in ['xxx']:
    #     vllm_model = "llm_default"

    memory_recycler = None
    if vllm_model == "deepseek":
        from .deepseek_v3_memory_recycler import DeepseekV3MemoryRecycler, alloc_memory_recycler_deepseek_v3
        memory_recycler = DeepseekV3MemoryRecycler()
        state = alloc_memory_recycler_deepseek_v3(tokens, huge_memory_alloctor, memory_recycler, dtype)
        if not state:
            del memory_recycler
            memory_recycler = None
        return state

    if vllm_model == "deepseek_mtp":
        from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler, alloc_memory_recycler_deepseek_mtp
        memory_recycler = DeepseekMTPMemoryRecycler()
        if argv.get('world_size') is None:
            print("mtp should have a TP world size, memory recycler alloc fail")
            return False

        state = alloc_memory_recycler_deepseek_mtp(tokens, huge_memory_alloctor, argv['world_size'], memory_recycler, dtype)
        if not state:
            del memory_recycler
            memory_recycler = None
        return state

    if vllm_model == "qwen3_moe":
        from .qwen3_moe_memory_recycler import QWen3MoeMemoryRecycler, alloc_memory_recycler_qwen3_moe
        memory_recycler = QWen3MoeMemoryRecycler()
        state = alloc_memory_recycler_qwen3_moe(tokens, huge_memory_alloctor, memory_recycler, dtype)
        if not state:
            del memory_recycler
            memory_recycler = None
        return state

    if vllm_model == "llm_default":
        from .allocator import LLMMemoryRecycler, alloc_memory_recycler_llm
        memory_recycler = LLMMemoryRecycler()
        state = alloc_memory_recycler_llm(tokens, huge_memory_alloctor, memory_recycler, dtype)
        if not state:
            del memory_recycler
            memory_recycler = None
        return state

    return False


# Under the LLM pipeline-parallel scheme, for every part other than stage 0,
# receiving the hiddens/residual from the PREVIOUS part also follows the
# memory-reuse rules. That path is independent of the llm forward pass and has
# to be maintained separately:
#   hidden_states maps to the moe mlp buffer of the llm forward
#   residual      maps to the embedding buffer of the llm forward
def alloc_pipeline_parallel_recycler_buffer(size: torch.Size, dtype: torch.dtype, key: str):
    global huge_memory_alloctor
    if huge_memory_alloctor is None:
        return None

    intermize_tensor_dict = {
        "hidden_states": 2,
        "attention": 1,
        "residual": 0
    }

    if key not in intermize_tensor_dict:
        return None

    src_tensors = huge_memory_alloctor.memory_buffers[intermize_tensor_dict[key]]

    all_bytes = size.numel() * huge_memory_alloctor.get_dtype_bytes(dtype)
    return src_tensors[:all_bytes].view(dtype).view(size)


# you should call this function when a new server and all of its workers start
def init_huge_memory_allocator(max_tokens, hidden_size, vllm_model=None):
    global huge_memory_alloctor
    if vllm_model is None:
        print("model info is empty, now using VLLM_MODEL_MODE")
        vllm_model = VLLM_MODEL_MODE

    if huge_memory_alloctor is not None:
        del huge_memory_alloctor
        torch.vacc.empty_cache()
        huge_memory_alloctor = None

    if vllm_model == "deepseek":
        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
        return True

    # deepseek_mtp sets use_contiguous = True
    # deepseek_mtp buffer recycling:
    # buffer[0]: normal_buffer  -> embedding_output
    # buffer[1]: dynamic_buffer -> mla_oproj_output, dynamic_output
    # buffer[2]: normal_buffer  -> moe_shared_mlp_output, deepseek_mtp_layer_input
    if vllm_model == "deepseek_mtp":
        # the last block of the deepseek dynamic tokens is used for the mtp weights:
        # dynamic_buffer_max_tokens = max_tokens + 128; we let mtp only support 48K for now
        dynamic_buffer_max_tokens = max_tokens
        # the dynamic buffer uses 128 extra tokens for broadcasting positions and input tokens
        deepseek_mtp_max_tokens = max_tokens

        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        # huge_memory_alloctor.init_buffers(max_tokens, 7168)
        huge_memory_alloctor.init_buffers_with_dynamic(deepseek_mtp_max_tokens, dynamic_buffer_max_tokens, hidden_size, [False, True, False])
        return True

    if vllm_model == "qwen3_moe":
        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
        return True
    return False
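# Wiring sketch (illustrative): how server startup and per-request allocation
# are expected to use this module; the numbers are hypothetical.
#
#   # once, when the server and all workers start:
#   init_huge_memory_allocator(max_tokens=48 * 1024, hidden_size=7168, vllm_model="deepseek")
#   # per incoming request:
#   ok = alloc_memory_recycler(tokens=4096, dtype=torch.bfloat16, vllm_model="deepseek")
#   if ok:
#       hidden = memory_recycler.EMBEDDING_OUT_BUFFER  # [4096, 7168] bf16 view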
@@ -0,0 +1,38 @@
import torch
from .allocator import VaccHugeMemoryAllocator, LLMMemoryRecycler

'''
QWen3-Moe supports 56K input tokens; there are 3 buffers that can be recycled:
1. parallel_embedding output buffer
2. mla_oproj output buffer
3. moe_shared_mlp output buffer
'''


class QWen3MoeMemoryRecycler(LLMMemoryRecycler):
    def __init__(self):
        super().__init__()


def alloc_memory_recycler_qwen3_moe(tokens,
                                    alloctor: VaccHugeMemoryAllocator,
                                    recycler: QWen3MoeMemoryRecycler,
                                    dtype: torch.dtype = torch.bfloat16):

    if not alloctor.enable:
        print("memory allocator is not initialized.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("qwen3_moe memory recycler buffers alloc fail. now disable it")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match qwen3_moe.")
        return False

    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    return True
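# Sizing example (illustrative, assuming a hidden size of 4096 in bf16): for
# the 56K-token budget above, each recycled block is 57344 * 4096 * 2 bytes,
# exactly 448 MiB, and the allocator reserves three of them up front.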
1457
vllm_vacc/vllm/model_executor/models/qwen2.py
Normal file
File diff suppressed because it is too large
33
vllm_vacc/vllm/model_executor/models/qwen2_5_vl.py
Normal file
@@ -0,0 +1,33 @@
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""

import torch
import torch.nn as nn

from vllm.logger import init_logger


from vacc_tools.trace_logger import get_trace_api
trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
    get_trace_api("deepseek")
)
logger = init_logger(__name__)


class Qwen2_5_VisionAttention(nn.Module):

    # @trace_time('Qwen2_5_VisionAttention_vacc_split_qkv')
    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q1, k1, v1 = qkv.chunk(3, dim=-1)
        q1, k1, v1 = (x.view(*new_shape) for x in (q1, k1, v1))
        return q1, k1, v1


class Qwen2_5_VisionPatchEmbed(nn.Module):

    # convert conv3d to matmul
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, self.proj.weight.view(self.hidden_size, -1).T)
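# Why the matmul is equivalent (illustrative): in the original module,
# self.proj is a Conv3d whose kernel_size equals its stride, so each output
# element is a dot product over one non-overlapping patch. When the
# preprocessor already delivers x flattened to [num_patches, C * T * pH * pW]
# (as the fused vacc path does), the convolution reduces to
#   x @ self.proj.weight.view(hidden_size, -1).T
# which is the single matmul used above.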
285
vllm_vacc/vllm/model_executor/models/qwen2_vl.py
Normal file
@@ -0,0 +1,285 @@
|
||||
|
||||
"""Inference-only Qwen2VL model compatible with HuggingFace weights."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce, parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention as Qwen2VisionAttentionOrg
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.logger import init_logger
|
||||
from .hf_processor.qwenvl_processor import Qwen2VLProcessorWithVacc
|
||||
from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
|
||||
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm_vacc.vllm.model_executor.models.vars import USE_FUSED_QWEN_ATTENTION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class Qwen2VisionPatchEmbed(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self.proj, 'bias') and self.proj.bias is not None:
|
||||
return torch.nn.functional.linear(x, self.proj.weight.view(self.hidden_size, -1), self.proj.bias)
|
||||
return torch.matmul(x, self.proj.weight.view(self.embed_dim, -1).T)
|
||||
|
||||
|
||||
class Qwen2VLProcessingInfo():
|
||||
def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessorWithVacc:
|
||||
return self.ctx.get_hf_processor(
|
||||
Qwen2VLProcessorWithVacc,
|
||||
use_fast=kwargs.pop("use_fast", True),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFastWithVacc:
|
||||
return self.get_hf_processor(**kwargs).image_processor
|
||||
|
||||
import torch.nn.functional as F
|
||||
class Qwen2VisionTransformer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
# patchify
|
||||
x = x.to(device=self.device, dtype=self.dtype)
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
|
||||
if USE_FUSED_QWEN_ATTENTION:
|
||||
try:
|
||||
from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl
|
||||
sin_cache, cos_cache = rot_pos_emb_qwenvl(grid_thw, self.embed_dim, self.num_heads, self.spatial_merge_size, self.dtype, self.device)
|
||||
except Exception as e:
|
||||
logger.error(f"rot_pos_emb fused ops run fail, e:{e}")
|
||||
rotary_pos_emb = None
|
||||
else:
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
sin_cache, cos_cache = None, None
|
||||
|
||||
# tmp_rotary_pos_emb = self.transformer_rot_pos_emb(grid_thw)
|
||||
# qwen3_rotary_pos_emb = self.qwen3_rot_pos_emb(grid_thw)
|
||||
|
||||
# compute cu_seqlens
|
||||
grid_thw_ = torch.tensor(grid_thw)
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
|
||||
grid_thw_[:, 0]).cumsum(
|
||||
dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
|
||||
|
||||
if USE_FUSED_QWEN_ATTENTION:
|
||||
cu_seqlens = cu_seqlens.tolist()
|
||||
max_seqlen, seqlens = None, None
|
||||
else:
|
||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(
|
||||
x,
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
sin_cache=sin_cache,
|
||||
cos_cache=cos_cache,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Qwen2VisionAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional["QuantizationConfig"] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super(Qwen2VisionAttentionOrg, self).__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        # self.qkv = ColumnParallelLinear(input_size=embed_dim,
        #                                 output_size=3 * projection_size,
        #                                 quant_config=quant_config,
        #                                 prefix=f"{prefix}.qkv",
        #                                 disable_tp=use_data_parallel)
        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel)

        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj",
                                      disable_tp=use_data_parallel)

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
                check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support the {self.attn_backend} backend yet.")
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = qkv.chunk(3, dim=-1)
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v


class Qwen2VisionBlock(nn.Module):

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        sin_cache: torch.Tensor,
        cos_cache: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        if USE_FUSED_QWEN_ATTENTION:
            # Reduce inside the fused op only when the tensor is small (< 4 MiB).
            total_bytes = x.numel() * x.element_size() * get_tp_group().world_size
            reduce_result = get_tp_group().world_size > 1 and total_bytes < 4194304

            # norm1 is fused into fuse_atten_vit, so it is not applied here.
            # hidden_states = self.norm1(x)
            attn_outs = torch.vacc.fuse_atten_vit(
                hidden_states=x.view(-1, x.shape[-1]),
                hidden_states_norm_weight=self.norm1.weight,
                hidden_states_norm_bias=self.norm1.bias,
                # hidden_states_norm_weight = torch.Tensor(),
                # hidden_states_norm_bias = torch.Tensor(),
                qkv_proj_weight=self.attn.qkv.weight,
                qkv_proj_bias=self.attn.qkv.bias,
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                o_proj_weight=self.attn.proj.weight,
                o_proj_bias=self.attn.proj.bias if self.attn.proj.tp_rank == 0 else torch.Tensor(),
                seq_lens=cu_seqlens,
                sm_scale=-1,
                num_attention_heads=self.attn.num_attention_heads_per_partition * get_tp_group().world_size,
                flash_attention=True,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos)
            attn_out = attn_outs[0] if reduce_result else tensor_model_parallel_all_reduce(attn_outs[0])
            attn_out = attn_out.view(x.shape)

            x = x + attn_out
        else:
            x = x + self.attn(
                self.norm1(x),
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionMLP():

    def forward(self, x: torch.Tensor):
        try:
            from torch_vacc.vacc import fuse_mlp_vision
            hiddens_shape = x.shape
            tp_rank_id = get_tp_group().rank_in_group
            # Only rank 0 supplies the fc2 bias so that the all-reduce does
            # not accumulate it world_size times.
            fc2_bias = None if tp_rank_id > 0 else self.fc2.bias
            hidden_states = fuse_mlp_vision(x.view(-1, hiddens_shape[-1]),
                                            self.fc1.weight,  # nk
                                            self.fc2.weight,  # nk
                                            self.fc1.bias,
                                            fc2_bias,
                                            2)  # 0 is gelu, 1 is relu, 2 is quick_gelu
            vacc_res = tensor_model_parallel_all_reduce(hidden_states).view(hiddens_shape)
            return vacc_res
        except Exception as e:
            logger.error(f"fused vision MLP op failed, falling back to the unfused path: {e}")

        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


class Qwen2VisionPatchMerger():

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)
        mlp_fc1, mlp_act, mlp_fc2 = self.mlp

        try:
            from torch_vacc.vacc import patch_merger_vision
            tp_rank_id = get_tp_group().rank_in_group
            fc2_bias = None if tp_rank_id > 0 else mlp_fc2.bias

            hidden_states = patch_merger_vision(x,
                                                mlp_fc1.weight,
                                                mlp_fc2.weight,
                                                mlp_fc1.bias,
                                                fc2_bias,
                                                0)  # 0 is gelu, 1 is silu
            vacc_res = tensor_model_parallel_all_reduce(hidden_states)
            return vacc_res
        except Exception as e:
            logger.error(f"fused patch-merger vision MLP op failed, caused by: {e}; falling back to the unfused path")

        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out
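Both vision patches above share one idiom: try the vendor's fused kernel first and fall back to the stock layers if the import or the call fails. The sketch below distills that idiom; it is illustrative only. The torch_vacc import and the activation id 2 (quick_gelu) are taken from the code above, while the helper name and the plain-PyTorch fallback are hypothetical stand-ins, not part of this repository.

import torch
import torch.nn.functional as F

def mlp_with_fused_fallback(x, fc1_w, fc1_b, fc2_w, fc2_b):
    # Prefer the fused vendor kernel; any ImportError or runtime failure
    # drops us into the reference implementation below.
    try:
        from torch_vacc.vacc import fuse_mlp_vision  # vendor op, may be absent
        return fuse_mlp_vision(x, fc1_w, fc2_w, fc1_b, fc2_b, 2)  # 2 = quick_gelu
    except Exception:
        h = F.linear(x, fc1_w, fc1_b)
        h = h * torch.sigmoid(1.702 * h)  # quick_gelu, matching activation id 2
        return F.linear(h, fc2_w, fc2_b)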
194
vllm_vacc/vllm/model_executor/models/qwen3.py
Normal file
@@ -0,0 +1,194 @@

"""Inference-only Qwen3 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union, Any, Dict

import torch
from torch import nn

from vllm.logger import init_logger
from .vars import *
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size

logger = init_logger(__name__)


# Unify the parameter names across the different quantization methods.
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    if isinstance(quant_method, UnquantizedLinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = torch.Tensor()
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, GPTQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, AWQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")


class Qwen3Attention(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None  # newly added parameter
    ) -> torch.Tensor:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]

        # Reshape the KV cache to [2, num_blocks, block_size(16), num_kv_heads, head_dim].
        num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)

        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        reduce_result = is_decode
        # total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
        # # only supports 4 MiB now
        # if total_bytes < 4194304:
        #     reduce_result = True

        if USE_FUSED_QWEN_ATTENTION:
            if is_decode:
                positions = [i - 1 for i in attn_metadata.seq_lens]
                cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
                sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
            else:
                cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
                sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
            if residual is None:
                res_out = hidden_states
            # from torch_vacc.vacc import fuse_atten_qwen3
            attn_outs = None
            if not is_decode:
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER

            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size

            attn_outs = torch.vacc.fuse_atten_qwen3(
                # attn_outs = vacc_fused_attn_qwen3_naive(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,  # tensor
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode uses flash attention by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)

            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]

            return attn_out, res_out
        else:
            # orig code
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            # Add qk-norm
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output


class Qwen3DecoderLayer(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        # NOTE: input_layernorm is fused into vacc_fused_attn_qwen3.
        if USE_FUSED_QWEN_ATTENTION:
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')

            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual)
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
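The decoder layer above builds self_attn.fused_params lazily on the first forward pass and reuses it afterwards, so the per-layer weight lookup cost is paid only once. A minimal, self-contained sketch of that pattern with a toy module standing in for the real attention layer (the class and key names here are hypothetical):

import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(8)

    def forward(self, x):
        # Built once on the first call, then reused on every later step.
        if not hasattr(self, "fused_params"):
            self.fused_params = {
                "input_layernorm_weight": self.input_layernorm.weight,
            }
        return x * self.fused_params["input_layernorm_weight"]

print(TinyLayer()(torch.ones(2, 8)).shape)  # torch.Size([2, 8])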
790
vllm_vacc/vllm/model_executor/models/qwen3_moe.py
Normal file
@@ -0,0 +1,790 @@

"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List
import itertools
import os

import torch
from torch import nn
from transformers import PretrainedConfig
from torch_vacc.vacc.custom_ops_cpu import (
    w8a8_block_fp8_linear as w8a8_block_fp8_linear_cpu,
)

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
                              get_tensor_model_parallel_world_size,
                              get_tensor_model_parallel_rank,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock, Qwen3MoeMLP
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method

from ..ops.mrope_op import get_sin_cos_mrope
from ..ops.qwen3_fused_moe import vacc_fused_prefill_moe_fp8, vacc_fused_decode_moe_fp8, recompute_moe_layer_blocksize
from .vars import *
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size

logger = init_logger(__name__)


# Unify the parameter names across the different quantization methods.
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    if isinstance(quant_method, UnquantizedLinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = None
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, GPTQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, AWQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")


def apply_w8a8_block_fp8_linear_v2(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    input_scale = None
    # View the input as a 2D matrix for the fp8 methods.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    # The quantization block size is the ratio of weight shape to scale shape.
    block_size = [
        weight.shape[-2] // weight_scale.shape[-2],
        weight.shape[-1] // weight_scale.shape[-1],
    ]

    if input.device.type == "vacc":
        output = torch.vacc.w8a8_block_fp8_linear(
            input_2d, weight, input_scale, weight_scale, block_size
        )
    else:
        output = w8a8_block_fp8_linear_cpu(
            input_2d, weight, input_scale, weight_scale, block_size
        )

    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)
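apply_w8a8_block_fp8_linear_v2 infers the quantization block size from the ratio of the weight shape to the scale shape. A quick worked instance of that derivation (the shapes below are made up for illustration):

# A (256, 512) weight quantized with a (2, 4) grid of per-block scales
# implies 128x128 quantization blocks, matching the block_size computation above.
weight_shape, scale_shape = (256, 512), (2, 4)
block_size = [
    weight_shape[-2] // scale_shape[-2],
    weight_shape[-1] // scale_shape[-1],
]
print(block_size)  # [128, 128]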
def vacc_fused_attn_qwen3_naive(
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        hidden_states_norm_weight: torch.Tensor,
        qkv_proj_weight: torch.Tensor,
        qkv_proj_weight_scale: torch.Tensor,
        qkv_proj_bias: Optional[torch.Tensor],
        qkv_proj_qzeros: Optional[torch.Tensor],
        q_layernorm_weight: torch.Tensor,
        k_layernorm_weight: torch.Tensor,
        sin_cache: List[torch.Tensor],
        cos_cache: List[torch.Tensor],
        slot_mapping: torch.Tensor,
        kv_cache: torch.Tensor,
        block_tables: torch.Tensor,
        block_group_size: int,
        o_proj_weight: torch.Tensor,
        o_proj_weight_scale: torch.Tensor,
        o_proj_bias: Optional[torch.Tensor],
        o_proj_qzeros: Optional[torch.Tensor],
        seq_lens: List[int],
        sm_scale: float,
        num_attention_heads: int,
        num_key_value_heads: int,
        flash_attention: bool,
        is_decode: bool,
        reduce_result: bool,
        world_size: int,
        rank: int,
        group_id: int,
        dev_info: Union[List[int], Tuple[int, ...]],
        block_size: int = 16):
    if residual is not None:
        hidden_states = hidden_states + residual
        residual_out = hidden_states

    hidden_states = torch.vacc.rms_norm(
        hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)

    # NOTE: for qwen3 and qwen2.5, head_dim is always 128.
    head_dim = 128

    # qkv generation
    qkv = apply_w8a8_block_fp8_linear_v2(
        input=hidden_states,
        weight=qkv_proj_weight,
        weight_scale=qkv_proj_weight_scale)

    num_q_heads = num_attention_heads // world_size
    num_kv_heads = num_key_value_heads // world_size

    q_size = head_dim * num_q_heads
    kv_size = head_dim * num_kv_heads

    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # Add qk-norm
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    # q_by_head = self.q_norm.forward_native(q_by_head)
    q_norm = torch.vacc.rms_norm(q_by_head, q_layernorm_weight, 1e-6)
    # q = q_by_head.view(q.shape)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    # k_by_head = k_norm.forward_native(k_by_head)
    k_norm = torch.vacc.rms_norm(k_by_head, k_layernorm_weight, 1e-6)
    # k = k_by_head.view(k.shape)

    v = v.view(-1, num_kv_heads, head_dim)

    # q, k = self.rotary_emb(positions, q, k)
    start = 0
    attn_outs = []

    if is_decode:
        # Convert block_tables from 16-token block indices to block-group indices.
        block_per_group = block_group_size // block_size
        block_tables = (block_tables // block_per_group).to(torch.int32)
        # logger.warning(f"decode block table: {block_tables}")

    num_blocks = kv_cache.shape[1]
    key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
    value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)

    # batch loop
    for i in range(len(seq_lens)):
        if not is_decode:
            # prefill
            end = start + seq_lens[i]
        else:
            # decode
            end = start + 1

        cos = cos_cache[i].unsqueeze(-2)
        sin = sin_cache[i].unsqueeze(-2)

        q, k = torch.vacc.RotaryPosEmbedding(
            q_norm[start:end, ...], k_norm[start:end, ...], cos, sin, 0, "neox")

        # cache concat
        torch.vacc.reshape_and_cache_attention(k, key_cache_split, slot_mapping[start:end, ...])
        torch.vacc.reshape_and_cache_attention(v[start:end, ...], value_cache_split, slot_mapping[start:end, ...])

        # attn_output = self.attn(q, k, v)
        if not is_decode:
            # prefill
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v[start:end, ...],
                attn_mask=None,
                dropout_p=0.0,
                is_causal=True,  # causal_attn and not self.need_mask
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=sm_scale)
        else:
            # decode
            key_cache = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
            value_cache = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)

            k_slices = key_cache[block_tables[i], ...]
            k_cached = torch.cat(
                [k_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
                dim=0,
            )
            k_cached = k_cached.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]

            v_slices = value_cache[block_tables[i], ...]
            v_cached = torch.cat(
                [v_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
                dim=0,
            )
            v_cached = v_cached.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k_cached,
                value=v_cached,
                attn_mask=None,
                dropout_p=0,
                is_causal=False,
                is_train=False,
                recompute=False,
                flash_attention=False,  # flash_attention
                sm_scale=sm_scale)

        attn_outs.append(attn_out)
        # update start
        start = end
    attn_out = torch.cat(attn_outs, dim=0)

    # output, _ = self.o_proj(attn_output)
    o_proj = apply_w8a8_block_fp8_linear_v2(
        input=attn_out.reshape(hidden_states.shape[0], -1),
        weight=o_proj_weight,
        weight_scale=o_proj_weight_scale,
    )

    if reduce_result:
        o_proj = tensor_model_parallel_all_reduce(o_proj)

    if residual is not None:
        return o_proj, residual_out
    return o_proj
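Inside the naive reference, the decode path converts 16-token block indices into coarser block-group indices before gathering from the cache. A small sanity check of that arithmetic with illustrative values (the block_group_size of 8192 is just an example; the real value comes from BLOCK_GROUP_SIZE):

import torch

block_size, block_group_size = 16, 8192
block_per_group = block_group_size // block_size            # 512 blocks per group
block_tables = torch.tensor([[0, 511, 512, 1024]])
print((block_tables // block_per_group).to(torch.int32))    # tensor([[0, 0, 1, 2]], dtype=torch.int32)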
def Qwen3MoeSparseMoeBlock__init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
):
    super(Qwen3MoeSparseMoeBlock, self).__init__()
    config = vllm_config.model_config.hf_text_config
    parallel_config = vllm_config.parallel_config
    quant_config = vllm_config.quant_config

    self.tp_size = get_tensor_model_parallel_world_size()

    self.ep_group = get_ep_group().device_group
    self.ep_rank = self.ep_group.rank()
    self.ep_size = self.ep_group.size()
    self.n_routed_experts = config.num_experts

    self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

    if self.tp_size > config.num_experts:
        raise ValueError(
            f"Tensor parallel size {self.tp_size} is greater than "
            f"the number of experts {config.num_experts}.")

    # Load balancing settings.
    vllm_config = get_current_vllm_config()
    eplb_config = vllm_config.parallel_config.eplb_config
    self.enable_eplb = parallel_config.enable_eplb

    self.n_logical_experts = self.n_routed_experts
    self.n_redundant_experts = eplb_config.num_redundant_experts
    self.n_physical_experts = (self.n_logical_experts +
                               self.n_redundant_experts)
    self.n_local_physical_experts = self.n_physical_experts // self.ep_size

    self.physical_expert_start = (self.ep_rank *
                                  self.n_local_physical_experts)
    self.physical_expert_end = (self.physical_expert_start +
                                self.n_local_physical_experts)

    self.experts = FusedMoE(num_experts=self.n_routed_experts,
                            top_k=config.num_experts_per_tok,
                            hidden_size=config.hidden_size,
                            intermediate_size=config.moe_intermediate_size,
                            reduce_results=True,
                            renormalize=config.norm_topk_prob,
                            quant_config=quant_config,
                            prefix=f"{prefix}.experts",
                            enable_eplb=self.enable_eplb,
                            num_redundant_experts=self.n_redundant_experts,
                            is_sequence_parallel=self.is_sequence_parallel)

    self.gate = ReplicatedLinear(config.hidden_size,
                                 config.num_experts,
                                 bias=False,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.gate")

    # Patch here to transpose the data layout of w2/w2_scale; only for block quant.
    if hasattr(self.experts.quant_method, 'quant_config') and hasattr(self.experts.quant_method.quant_config, 'weight_block_size'):
        self.experts.w2_weight.data = self.experts.w2_weight.data.transpose(-1, -2).contiguous().transpose(-1, -2)
        self.experts.w2_weight_scale_inv.data = self.experts.w2_weight_scale_inv.data.transpose(-1, -2).contiguous().transpose(-1, -2)
        if hasattr(self.experts, 'w2_weight_scale_inv_prefill'):
            self.experts.w2_weight_scale_inv_prefill.data = self.experts.w2_weight_scale_inv_prefill.data.transpose(-1, -2).contiguous().transpose(-1, -2)
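The w2 patch at the end of the constructor looks like a no-op, but transpose -> contiguous -> transpose back rewrites the storage layout (the strides) while keeping the logical tensor unchanged. A small check of that property:

import torch

w = torch.arange(6.).view(2, 3)
w2 = w.transpose(-1, -2).contiguous().transpose(-1, -2)
assert torch.equal(w, w2)          # same values and shape
assert w2.stride() != w.stride()   # but column-major storage: (1, 2) vs (3, 1)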
def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
                      attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
                      positions: Union[torch.Tensor, list],
                      is_decode: bool):
    if isinstance(rotary_emb, MRotaryEmbedding):
        # get mrope sin/cos
        cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
        if len(attn_metadata.seq_lens) > 1:
            if is_decode:
                cos_cache = torch.chunk(cos_cache, len(attn_metadata.seq_lens))
                sin_cache = torch.chunk(sin_cache, len(attn_metadata.seq_lens))
            else:
                cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
                sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
        else:
            cos_cache = [cos_cache]
            sin_cache = [sin_cache]
    else:
        if is_decode:
            positions = [i - 1 for i in attn_metadata.seq_lens]
            cos_cache = [rotary_emb.cos_cache[i:i+1, ...] for i in positions]
            sin_cache = [rotary_emb.sin_cache[i:i+1, ...] for i in positions]
        else:
            cos_cache = [rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
            sin_cache = [rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
    return cos_cache, sin_cache
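get_cos_sin_cache uses chunk() for decode but split() for prefill: decode contributes exactly one token per sequence, so equal-sized chunks are valid, while prefill lengths differ per sequence and need the explicit seq_lens list. Illustrative values:

import torch

cos = torch.arange(5.).unsqueeze(-1)          # 5 positions, e.g. seq_lens = [2, 3]
prefill = torch.split(cos, [2, 3])            # uneven pieces, one per sequence
decode = torch.chunk(torch.arange(2.), 2)     # one token per sequence -> equal chunks
print([p.shape for p in prefill], [d.shape for d in decode])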
class Qwen3MoeDecoderLayer(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        cos_cache: list[torch.Tensor],
        sin_cache: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        # NOTE: input_layernorm is fused into vacc_fused_attn_qwen3.
        if USE_FUSED_QWEN_ATTENTION:
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')

            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                cos_cache=cos_cache,
                sin_cache=sin_cache)
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                cos_cache=cos_cache,
                sin_cache=sin_cache
            )

        # # Fully Connected
        # hidden_states, residual = self.post_attention_layernorm(
        #     hidden_states, residual)
        # hidden_states = self.mlp(hidden_states)
        # return hidden_states, residual

        # TODO: unquantized and non-block-quant models still use the default MLP path.
        if not hasattr(self.mlp.experts.quant_method, 'quant_config') or \
                not hasattr(self.mlp.experts.quant_method.quant_config, 'weight_block_size'):
            if not isinstance(self.mlp.experts.quant_method, MoeWNA16Method):
                logger.warning('unquantized or unsupported quant method; falling back to the default MoE path')
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual)
                hidden_states = self.mlp(hidden_states)
                return hidden_states, residual

        if isinstance(attn_metadata, dict):
            attn_metadata_0 = next(iter(get_forward_context().attn_metadata.values()))
            is_prefill = attn_metadata_0.prefill_metadata
        else:
            is_prefill = get_forward_context().attn_metadata.prefill_metadata

        quant_method = self.mlp.experts.quant_method if isinstance(self.mlp, Qwen3MoeSparseMoeBlock) \
            else self.mlp.down_proj.quant_method

        if is_prefill is not None:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_prefill_moe_gptq_int4
                    return vacc_fused_prefill_moe_gptq_int4(hidden_states,
                                                            residual,
                                                            self.post_attention_layernorm,
                                                            self.mlp.gate,
                                                            self.mlp.experts)
                except Exception as e:
                    logger.error(f'vacc_fused_prefill_moe_gptq_int4 failed: {e}')
            else:
                recompute_moe_layer_blocksize(self.mlp.experts)
                try:
                    return vacc_fused_prefill_moe_fp8(hidden_states,
                                                      residual,
                                                      self.post_attention_layernorm,
                                                      self.mlp.gate,
                                                      self.mlp.experts)
                except Exception as e:
                    logger.error(f'vacc_fused_prefill_moe_fp8 failed: {e}')
        else:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_decode_moe_gptq_int4
                    return vacc_fused_decode_moe_gptq_int4(hidden_states,
                                                           residual,
                                                           self.post_attention_layernorm,
                                                           self.mlp.gate,
                                                           self.mlp.experts)
                except Exception as e:
                    logger.error(f'vacc_fused_decode_moe_gptq_int4 failed: {e}')
            else:
                try:
                    return vacc_fused_decode_moe_fp8(hidden_states,
                                                     residual,
                                                     self.post_attention_layernorm,
                                                     self.mlp.gate,
                                                     self.mlp.experts)
                except Exception as e:
                    logger.error(f'vacc_fused_decode_moe_fp8 failed: {e}')

        # Fallback: unfused Fully Connected path.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class Qwen3MoeAttention(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,  # newly added parameter
        cos_cache: Optional[list[torch.Tensor]] = None,
        sin_cache: Optional[list[torch.Tensor]] = None,
    ) -> torch.Tensor:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]

        # Reshape the KV cache to [2, num_blocks, block_size(16), num_kv_heads, head_dim].
        num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)

        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        reduce_result = is_decode
        # total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
        # # only supports 4 MiB now
        # if total_bytes < 4194304:
        #     reduce_result = True

        if USE_FUSED_QWEN_ATTENTION:
            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)

            if residual is None:
                res_out = hidden_states
            # from torch_vacc.vacc import fuse_atten_qwen3
            attn_outs = None
            if not is_decode:
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER

            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size
            attn_outs = torch.vacc.fuse_atten_qwen3(
                # attn_outs = vacc_fused_attn_qwen3_naive(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode uses flash attention by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)
            # debug_qwen3_moe_attention_prefill(hidden_states=hidden_states,
            #                                   residual=residual,
            #                                   attn_outs=attn_outs,
            #                                   fused_params=self.fused_params,
            #                                   attn_metadata=attn_metadata,
            #                                   is_decode=is_decode,
            #                                   sin_cache=sin_cache,
            #                                   cos_cache=cos_cache,
            #                                   kv_cache=kv_cache,
            #                                   env_blk_grp_size=env_blk_grp_size,
            #                                   scaling=self.scaling,
            #                                   total_num_heads=self.total_num_heads,
            #                                   total_num_kv_heads=self.total_num_kv_heads,
            #                                   world_size=get_tp_group().world_size,
            #                                   rank=get_tp_group().rank_in_group,
            #                                   group_id=get_tp_group().group_id,
            #                                   dev_info=get_tp_group().rank_device_infos)

            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]

            return attn_out, res_out
        else:
            # orig code
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # Add qk-norm
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm.forward_native(q_by_head)
            q = q_by_head.view(q.shape)

            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm.forward_native(k_by_head)
            k = k_by_head.view(k.shape)

            q, k = self.rotary_emb(positions, q, k)

            attn_output = self.attn(q, k, v)

            output, _ = self.o_proj(attn_output)
            return output
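Both attention patches reshape the flat KV cache into [2, num_blocks, block_size, num_kv_heads, head_dim] with a hard-coded block size of 16. A shape check with toy dimensions:

import torch

num_blocks, num_kv_heads, head_dim = 4, 2, 8
kv = torch.zeros(2 * num_blocks * 16 * num_kv_heads * head_dim)
print(kv.view(2, -1, 16, num_kv_heads, head_dim).shape)  # torch.Size([2, 4, 16, 2, 8])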
class Qwen3MoeModel(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.qwen3_moe_weight_capture import Qwen3Moe_WeightCapture
            self.weight_capture = Qwen3Moe_WeightCapture(self.layers, self.start_layer, self.end_layer)
            self.layer_nums = self.end_layer - self.start_layer

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # The fused decoder layer only supports fp8-quantized models for now.
        use_fused_layer = self.weight_capture.support_fused_weights and USE_DECODER_LAYER_FUSE_MODE
        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        if use_fused_layer and is_decode:
            from torch_vacc.vacc.custom_ops import qwen3_fuse_attention_moe_decode

            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode=True)

            for i in range(0, self.layer_nums):
                layer = self.layers[i + self.start_layer]
                kv_cache = layer.self_attn.attn.kv_cache[forward_context.virtual_engine]
                num_kv_heads = max(1, layer.self_attn.total_num_kv_heads // get_tp_group().world_size)
                kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, layer.self_attn.head_dim)
                total_num_kv_heads = layer.self_attn.total_num_kv_heads
                if layer.self_attn.total_num_kv_heads < get_tp_group().world_size:
                    assert get_tp_group().world_size % layer.self_attn.total_num_kv_heads == 0
                    total_num_kv_heads = get_tp_group().world_size

                hidden_states, residual = qwen3_fuse_attention_moe_decode(
                    hidden_states, residual,
                    hidden_states_norm_weight=self.weight_capture.layer_mapper.attn_args._0_input_layernorm_weight[i],
                    qkv_proj_weight=self.weight_capture.layer_mapper.attn_args._1_qkv_proj_weight[i],
                    qkv_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._2_qkv_proj_weight_scale[i],
                    qkv_proj_bias=self.weight_capture.layer_mapper.attn_args._3_qkv_proj_bias[i],
                    qkv_proj_qzeros=self.weight_capture.layer_mapper.attn_args._4_qkv_proj_qzeros[i],
                    q_layernorm_weight=self.weight_capture.layer_mapper.attn_args._5_q_norm_weight[i],
                    k_layernorm_weight=self.weight_capture.layer_mapper.attn_args._6_k_norm_weight[i],
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.weight_capture.layer_mapper.attn_args._13_o_proj_weight[i],
                    o_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._14_o_proj_weight_scale[i],
                    o_proj_bias=self.weight_capture.layer_mapper.attn_args._15_o_proj_bias[i],
                    o_proj_qzeros=self.weight_capture.layer_mapper.attn_args._16_o_proj_qzeros[i],
                    seq_lens_num=attn_metadata.seq_lens,
                    sm_scale=layer.self_attn.scaling,
                    num_attention_heads=layer.self_attn.total_num_heads,
                    num_key_value_heads=total_num_kv_heads,
                    flash_attentiton=True,
                    is_decode=True,
                    reduce_result=True,
                    # moe
                    rms_weight=self.weight_capture.layer_mapper.moe_args._0_rms_norm_weight[i],
                    moe_weight_13=self.weight_capture.layer_mapper.moe_args._1_w13_weight[i],
                    moe_weight_2=self.weight_capture.layer_mapper.moe_args._2_w2_weight[i],
                    moe_weight_13_dequat=self.weight_capture.layer_mapper.moe_args._3_w13_weight_scale_inv[i],
                    moe_weight_2_dequant=self.weight_capture.layer_mapper.moe_args._4_w2_weight_scale_inv[i],
                    gate_weight=self.weight_capture.layer_mapper.moe_args._5_gate_weight[i],
                    block_size_13=self.weight_capture.layer_mapper.moe_args._6_w13_block_size,
                    block_size_2=self.weight_capture.layer_mapper.moe_args._7_w2_block_size,
                    # dist
                    world_size=self.weight_capture.layer_mapper.dist_args._0_world_size,
                    rank=self.weight_capture.layer_mapper.dist_args._1_rank,
                    group_id=self.weight_capture.layer_mapper.dist_args._2_group_id,
                    dev_info=self.weight_capture.layer_mapper.dist_args._3_dev_info)
        else:
            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode)

            for i in range(self.start_layer, self.end_layer):
                layer = self.layers[i]
                hidden_states, residual = layer(positions, hidden_states, residual, cos_cache, sin_cache)
                if deepstack_input_embeds is not None and i in range(0, len(deepstack_input_embeds)):
                    if isinstance(deepstack_input_embeds, IntermediateTensors):
                        hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{i}"]
                    elif isinstance(deepstack_input_embeds, torch.Tensor):
                        hidden_states = hidden_states + deepstack_input_embeds[i]
                    else:
                        raise ValueError(f'unsupported type: {type(deepstack_input_embeds)}')

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states, residual)
        return hidden_states
class Qwen3MoeForCausalLM(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds=None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
        if attn_metadata.prefill_metadata is not None:
            from .memory.memory_recycling import alloc_memory_recycler
            from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
            if hasattr(attn_metadata, 'num_prefill_tokens'):
                tokens = attn_metadata.num_prefill_tokens
            else:
                tokens = attn_metadata.prefill_metadata.num_prefill_tokens

            vllm_model_mode = "qwen3_moe"
            config_infos = vllm_vacc_config_manager().get_model_infos()
            if config_infos != "default":
                vllm_model_mode = config_infos

            if get_tp_group().rank_in_group == 0:
                logger.info(f'[MemoryRecycler] enable: {vllm_model_mode}')

            if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size, dtype=self.lm_head.weight.dtype):
                logger.warning("memory recycler allocation failed; the current request may run inefficiently (%s prefill tokens)", tokens)

        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, deepstack_input_embeds)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        from .memory.memory_recycling import init_huge_memory_allocator
        from .vars import LLM_MAX_PREFILL_SEQ_LEN
        from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

        # The default is "qwen3_moe"; the vacc config manager may override it
        # with a model-specific variant.
        model_name = "qwen3_moe"
        config_infos = vllm_vacc_config_manager().get_model_infos()
        if config_infos != "default":
            model_name = config_infos

        if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
            logger.warning("init huge memory allocator failed; prefill memory recycling will be disabled")

        from vllm.model_executor.models.utils import AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
362
vllm_vacc/vllm/model_executor/models/qwen3_vl.py
Normal file
@@ -0,0 +1,362 @@

"""Inference-only Qwen3VL model compatible with HuggingFace weights."""

from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, SupportsLoRA,
                                                   SupportsMultiModal, SupportsPP)
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                                              maybe_prefix, merge_multimodal_embeddings)
from vllm.logger import init_logger
from .hf_processor.qwenvl_processor import Qwen3VLProcessorWithVacc
from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce

from .vars import USE_FUSED_QWEN_ATTENTION

# from vacc_tools.trace_logger import get_trace_api
# trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
#     get_trace_api("Qwen3vl")
# )

logger = init_logger(__name__)


class Qwen3_VisionPatchEmbed(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self.proj, 'bias') and self.proj.bias is not None:
            return torch.nn.functional.linear(x, self.proj.weight.view(self.hidden_size, -1), self.proj.bias)
        return torch.matmul(x, self.proj.weight.view(self.hidden_size, -1).T)


class Qwen3_VisionBlock(nn.Module):

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | list[torch.Tensor],
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        if USE_FUSED_QWEN_ATTENTION:
            assert isinstance(rotary_pos_emb, list), "the qwen3vl vit-attention path needs rotary_pos_emb to be a list[torch.Tensor]"

            # Reduce inside the fused op only when the tensor is small (< 4 MiB).
            total_bytes = x.numel() * x.element_size() * get_tp_group().world_size
            reduce_result = get_tp_group().world_size > 1 and total_bytes < 4194304

            # norm1 is fused into fuse_atten_vit, so it is not applied here.
            # hidden_states = self.norm1(x)
            attn_outs = torch.vacc.fuse_atten_vit(
                hidden_states=x.view(-1, x.shape[-1]),
                hidden_states_norm_weight=self.norm1.weight,
                hidden_states_norm_bias=self.norm1.bias,
                # hidden_states_norm_weight = torch.Tensor(),
                # hidden_states_norm_bias = torch.Tensor(),
                qkv_proj_weight=self.attn.qkv.weight,
                qkv_proj_bias=self.attn.qkv.bias,
                sin_cache=rotary_pos_emb[0],
                cos_cache=rotary_pos_emb[1],
                o_proj_weight=self.attn.proj.weight,
                o_proj_bias=self.attn.proj.bias if self.attn.proj.tp_rank == 0 else torch.Tensor(),
                seq_lens=cu_seqlens,
                sm_scale=-1,
                num_attention_heads=self.attn.num_attention_heads_per_partition * get_tp_group().world_size,
                flash_attention=True,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos)
            attn_out = attn_outs[0] if reduce_result else tensor_model_parallel_all_reduce(attn_outs[0])
            attn_out = attn_out.view(x.shape)

            x = x + attn_out
        else:
            x = x + self.attn(self.norm1(x),
                              cu_seqlens=cu_seqlens,
                              rotary_pos_emb=rotary_pos_emb,
                              max_seqlen=max_seqlen,
                              seqlens=seqlens)

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen3_VisionTransformer(nn.Module):

    def rot_pos_emb(self, grid_thw):
        if USE_FUSED_QWEN_ATTENTION:
            try:
                from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl
                return rot_pos_emb_qwenvl(grid_thw, self.hidden_size, self.num_heads, self.spatial_merge_size, self.dtype, self.device)
            except Exception as e:
                logger.error(f"fused rot_pos_emb op failed, falling back to the native path: {e}")

        pos_ids = []
        # Support both Tensor and list inputs for the DP path.
        if isinstance(grid_thw, list):
            grid_list = grid_thw
            max_grid_size = max(max(h, w) for _, h, w in grid_list)
        else:
            grid_list = grid_thw.tolist()
            max_grid_size = int(grid_thw[:, 1:].max().item())
        for t, h, w in grid_list:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb
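The fallback rot_pos_emb path regroups row and column indices into spatial_merge_size x spatial_merge_size windows via reshape + permute. For a 4x4 grid with merge size 2, the row ids come out window by window, which this toy run makes concrete:

import torch

h = w = 4
m = 2  # spatial_merge_size
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos = hpos.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()
print(hpos.tolist())  # [0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3]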
    def fast_pos_embed_interpolate(self,
                                   grid_thw: list[list[int]]) -> torch.Tensor:
        num_grid_per_side = self.num_grid_per_side
        m_size = self.spatial_merge_size
        hidden_dim = self.pos_embed.embedding_dim

        try:
            from torch_vacc.vacc.custom_qwen3_ops import fast_pos_embed_interpolate_qwenvl
            return fast_pos_embed_interpolate_qwenvl(self.pos_embed.weight, grid_thw, num_grid_per_side, m_size, hidden_dim)
        except Exception as e:
            logger.error(f"fused fast_pos_embed_interpolate op failed, falling back to the native path: {e}")

        outputs = []
        for t, h, w in grid_thw:
            h_idxs = torch.linspace(0,
                                    num_grid_per_side - 1,
                                    h,
                                    dtype=torch.float32,
                                    device=self.device)
            w_idxs = torch.linspace(0,
                                    num_grid_per_side - 1,
                                    w,
                                    dtype=torch.float32,
                                    device=self.device)

            h_floor = h_idxs.to(torch.long)
            w_floor = w_idxs.to(torch.long)
            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            # Create meshgrid views for all h, w variables.
            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor,
                                                        w_floor,
                                                        indexing='ij')
            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil,
                                                      w_ceil,
                                                      indexing='ij')
            h_floor_grid_idx = h_floor_grid * num_grid_per_side
            h_ceil_grid_idx = h_ceil_grid * num_grid_per_side

            # Original computation of the bilinear weights:
            # w00 = (1 - dh_grid) * (1 - dw_grid)
            # w01 = (1 - dh_grid) * dw_grid
            # w10 = dh_grid * (1 - dw_grid)
            # w11 = dh_grid * dw_grid
            # We reuse w11 here to avoid computing dh_grid * dw_grid twice.
            w11 = dh_grid * dw_grid
            w10 = dh_grid - w11
            w01 = dw_grid - w11
            w00 = 1 - dh_grid - dw_grid + w11

            idx00 = h_floor_grid_idx + w_floor_grid
            idx01 = h_floor_grid_idx + w_ceil_grid
            idx10 = h_ceil_grid_idx + w_floor_grid
            idx11 = h_ceil_grid_idx + w_ceil_grid

            indices = torch.stack([idx00, idx01, idx10, idx11],
                                  dim=0).reshape(4, -1)
            weights = torch.stack([w00, w01, w10, w11],
                                  dim=0).reshape(4, -1, 1)
            weights = weights.to(dtype=self.dtype, device=self.device)

            embeds = self.pos_embed(indices)
            weighted_embeds = embeds * weights
            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
            combined = p0 + p1 + p2 + p3

            combined = combined.view(h * w, hidden_dim)
            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
            repeated = repeated.view(t, h // m_size, m_size, w // m_size,
                                     m_size, hidden_dim)
            repeated = repeated.permute(0, 1, 3, 2, 4,
                                        5).reshape(-1, hidden_dim)
            outputs.append(repeated)

        return torch.cat(outputs, dim=0)
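The weight-reuse trick in fast_pos_embed_interpolate replaces three of the four bilinear weights with expressions derived from w11, saving two elementwise multiplies per grid. The identities are easy to verify numerically:

import torch

dh, dw = torch.rand(3), torch.rand(3)
w11 = dh * dw
assert torch.allclose(dh - w11, dh * (1 - dw))                  # w10
assert torch.allclose(dw - w11, (1 - dh) * dw)                  # w01
assert torch.allclose(1 - dh - dw + w11, (1 - dh) * (1 - dw))   # w00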
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
grid_thw_tensor = torch.tensor(grid_thw,
|
||||
dtype=torch.int32)
|
||||
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
|
||||
grid_thw_tensor[:, 0]).cumsum(
|
||||
dim=0,
|
||||
dtype=grid_thw_tensor.dtype
|
||||
if torch.jit.is_tracing() else torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        hidden_states = hidden_states.unsqueeze(1)

        if isinstance(rotary_pos_emb, torch.Tensor):
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)

        if USE_FUSED_QWEN_ATTENTION:
            max_seqlen, seqlens = None, None
            cu_seqlens = cu_seqlens.tolist()
        else:
            max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(hidden_states,
                                cu_seqlens=cu_seqlens,
                                rotary_pos_emb=rotary_pos_emb,
                                max_seqlen=max_seqlen,
                                seqlens=seqlens)
            if layer_num in self.deepstack_visual_indexes:
                deepstack_merger_idx = self.deepstack_visual_indexes.index(
                    layer_num)
                deepstack_feature = self.deepstack_merger_list[
                    deepstack_merger_idx](hidden_states)
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)
        hidden_states = torch.cat(
            [hidden_states] + deepstack_feature_lists,
            dim=1)  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states

class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP):

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        deepstack_input_embeds = None
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            if self.use_deepstack:
                deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds(  # noqa: E501
                    input_ids, inputs_embeds, multimodal_embeddings)
                self._set_deepstack_input_embeds(deepstack_input_embeds)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id])

        # patched here to remove the redundant deepstack_input_embeds copy
        # if self.use_deepstack:
        #     if deepstack_input_embeds is None:
        #         deepstack_input_embeds = torch.zeros_like(
        #             inputs_embeds).unsqueeze(0).repeat(
        #                 self.deepstack_num_level, 1, 1).contiguous()
        #     self._set_deepstack_input_embeds(deepstack_input_embeds)

        return inputs_embeds

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        return  # patched to skip clearing deepstack_input_embeds

        # clear deepstack_input_embeds in buffer
        if num_tokens > 0:
            for idx in range(self.deepstack_num_level):
                self.deepstack_input_embeds[idx][:num_tokens].zero_()

class Qwen3VLProcessingInfo():

    def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessorWithVacc:
        processor = self.ctx.get_hf_processor(
            Qwen3VLProcessorWithVacc,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )
        return processor

    def get_image_processor(
            self, **kwargs: object) -> Qwen2VLImageProcessorFastWithVacc:
        return self.get_hf_processor(**kwargs).image_processor

    # def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
    #     return self.get_hf_processor(**kwargs).video_processor


class Qwen3_VisionPatchMerger():
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_postshuffle_norm:
            x = self.norm(x.view(-1, self.hidden_size))
        else:
            x = self.norm(x).view(-1, self.hidden_size)

        try:
            from torch_vacc.vacc import patch_merger_vision
            tp_rank_id = get_tp_group().rank_in_group
            fc2_bias = None if tp_rank_id > 0 else self.linear_fc2.bias

            hidden_states = patch_merger_vision(x,
                                                self.linear_fc1.weight,
                                                self.linear_fc2.weight,
                                                self.linear_fc1.bias,
                                                fc2_bias,
                                                0)  # 0 is gelu, 1 is silu
            return tensor_model_parallel_all_reduce(hidden_states)
        except Exception as e:
            logger.error(f"patch merger fused vision mlp run failed, caused by: {e}")

        x_parallel, _ = self.linear_fc1(x)
        x_parallel = self.act_fn(x_parallel)
        out, _ = self.linear_fc2(x_parallel)
        return out


class Qwen3_VisionMLP():
    def forward(self, x: torch.Tensor):
        try:
            from torch_vacc.vacc import fuse_mlp_vision
            hiddens_shape = x.shape
            tp_rank_id = get_tp_group().rank_in_group
            fc2_bias = None if tp_rank_id > 0 else self.linear_fc2.bias

            hidden_states = fuse_mlp_vision(x.view(-1, hiddens_shape[-1]),
                                            self.linear_fc1.weight,
                                            self.linear_fc2.weight,
                                            self.linear_fc1.bias,
                                            fc2_bias,
                                            0)  # 0 is gelu, 1 is silu
            return tensor_model_parallel_all_reduce(
                hidden_states).view(hiddens_shape)
        except Exception as e:
            logger.error(f"qwen3vl fused vision mlp run failed, caused by: {e}")
            return self.linear_fc2(self.act_fn(self.linear_fc1(x)))

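Both modules above use the same guard: try the torch_vacc fused kernel, and on any failure log the error and fall back to the unfused projections, so correctness never depends on the fused path. A minimal standalone sketch of that pattern, assuming plain weight matrices (the torch_vacc import name is taken from the code above; the bias-less call and everything else here are illustrative):

import torch
import torch.nn.functional as F

def mlp_forward(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    """GELU MLP that prefers a fused kernel but always has a native fallback."""
    try:
        # fused path; torch_vacc is only present on VACC builds, and passing
        # None for the biases is an assumption for illustration
        from torch_vacc.vacc import fuse_mlp_vision
        return fuse_mlp_vision(x, w1, w2, None, None, 0)  # 0 selects gelu
    except Exception:
        # native fallback: fc1 -> gelu -> fc2
        return F.gelu(x @ w1.t()) @ w2.t()

x = torch.randn(4, 16)
w1 = torch.randn(64, 16)   # fc1: 16 -> 64
w2 = torch.randn(16, 64)   # fc2: 64 -> 16
print(mlp_forward(x, w1, w2).shape)  # torch.Size([4, 16])
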
27
vllm_vacc/vllm/model_executor/models/roberta.py
Normal file
27
vllm_vacc/vllm/model_executor/models/roberta.py
Normal file
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from torch import nn

from vllm.model_executor.models.bert import _decode_token_type_ids


class RobertaEmbedding(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:

        token_type_ids = _decode_token_type_ids(input_ids)

        inputs_embeds = self.word_embeddings(input_ids)
        # position_embeddings = self.position_embeddings(position_ids)
        # token_type_embeddings = self.token_type_embeddings(token_type_ids)
        # embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        # embeddings = self.LayerNorm(embeddings)
        embeddings = torch.vacc.fuse_bge_embedding_stage1(
            inputs_embeds, position_ids, self.position_embeddings.weight,
            token_type_ids, self.token_type_embeddings.weight,
            self.LayerNorm.weight, self.LayerNorm.bias, self.LayerNorm.eps)
        return embeddings
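The fused call above collapses the commented-out native path into one kernel. A self-contained sketch of the math it replaces (pure PyTorch, following the commented-out lines; the fused op itself is opaque here and the toy shapes are assumptions):

import torch
import torch.nn.functional as F

def bge_embedding_stage1_reference(inputs_embeds, position_ids, pos_weight,
                                   token_type_ids, type_weight,
                                   ln_weight, ln_bias, ln_eps):
    # sum word, position and token-type embeddings, then LayerNorm --
    # the same steps the commented-out native path performed
    embeddings = (inputs_embeds
                  + F.embedding(position_ids, pos_weight)
                  + F.embedding(token_type_ids, type_weight))
    return F.layer_norm(embeddings, (embeddings.shape[-1],),
                        ln_weight, ln_bias, ln_eps)

# toy shapes: 4 tokens, hidden size 8, 512 positions, 2 token types
e = torch.randn(4, 8)
pw, tw = torch.randn(512, 8), torch.randn(2, 8)
lw, lb = torch.ones(8), torch.zeros(8)
out = bge_embedding_stage1_reference(e, torch.arange(4), pw,
                                     torch.zeros(4, dtype=torch.long), tw,
                                     lw, lb, 1e-12)
print(out.shape)  # torch.Size([4, 8])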
58
vllm_vacc/vllm/model_executor/models/vars.py
Normal file
58
vllm_vacc/vllm/model_executor/models/vars.py
Normal file
@@ -0,0 +1,58 @@
import os

# support Q,KV Gen with TP
USE_PARALLEL_Q_KV_GEN = True

# support merged Q,KV Gen and merged Q,QR weights
USE_MERGE_Q_KV_GEN_AND_Q_QR = True

# support FP8 weights for W_Q, W_QR, W_UV, W_UK
W_Q_W_QR_WUV_WUK_USE_FP8 = True

# fused prefill
USE_FUSED_PREFILL = True

# fused prefill stage1
USE_FUSED_PREFILL_STAGE1 = True

# all request seq lens
DO_SEQ_LENS = 0
def update_seqence_length(seq_num):
    global DO_SEQ_LENS
    DO_SEQ_LENS = seq_num

USE_DS3_SAMPLER = int(os.getenv("USE_DS3_SAMPLER", 1))
USE_DS3_SAMPLER_OP = int(os.getenv("USE_DS3_SAMPLER_OP", 1))

# cut prefill seq len
CUT_PREFILL_SEQ_LEN = int(os.getenv("CUT_PREFILL_SEQ_LEN", -1))

# llm max prefill seq len
LLM_MAX_PREFILL_SEQ_LEN = int(os.getenv("LLM_MAX_PREFILL_SEQ_LEN", 56 * 1024))

# fully fused decode; default is the cpu loop
USE_DECODER_LAYER_FUSE_MODE = int(os.getenv("USE_DECODER_LAYER_FUSE_MODE", 1))

# fuse all layers, use the cmcu loop
FUSE_ALL_DECODER_LAYERS = int(os.getenv("FUSE_ALL_DECODER_LAYERS", 1))

# whether to use flash attention (default: 1)
USE_FLASH_ATTENTION = int(os.getenv("USE_FLASH_ATTENTION", 1))

# transpose gptq weight KN => NK
TRANSPOSE_GPTQ_WEIGHT = True

# qwen fused attention
USE_FUSED_QWEN_ATTENTION = int(os.getenv("USE_FUSED_QWEN_ATTENTION", 1))

# support MTP eh_proj with TP
USE_PARALLEL_MTP_EH_PROJ = int(os.getenv("USE_PARALLEL_MTP_EH_PROJ", 1))

# kv_cache group size
BLOCK_GROUP_SIZE = int(os.getenv("BLOCK_GROUP_SIZE", 8192))

# bert fused attention
USE_FUSED_BERT_ATTENTION = int(os.getenv("USE_FUSED_BERT_ATTENTION", 1))

# fused mlp vision
USE_FUSED_MLP_VISION = int(os.getenv("USE_FUSED_MLP_VISION", 1))
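Every switch in this file follows the same int(os.getenv(name, default)) pattern, so features can be toggled from the shell without code changes. A small self-contained illustration (the flag names are from this file; setting them inline is just for demonstration, since the real module reads them once at import time):

import os

os.environ["USE_FUSED_MLP_VISION"] = "0"  # disable the fused vision MLP

# mirrors the pattern used throughout vars.py: the int() call accepts both
# the int default (env var unset) and the str value (env var set)
USE_FUSED_MLP_VISION = int(os.getenv("USE_FUSED_MLP_VISION", 1))
USE_FLASH_ATTENTION = int(os.getenv("USE_FLASH_ATTENTION", 1))

print(USE_FUSED_MLP_VISION, USE_FLASH_ATTENTION)  # 0 1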
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,491 @@
import torch
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
                                                    DeepseekV2MLAAttention,
                                                    DeepseekV2MLP,
                                                    DeepseekV2MoE)

from ..vars import *
from vllm_vacc.vllm.model_executor.models.vars import \
    BLOCK_GROUP_SIZE as env_blk_grp_size

OUTPUT_ARGS_LOGS = False
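
# NOTE (editorial): the numeric `_0_ ... _22_` attribute-name prefixes used by
# the arg classes below presumably pin a stable positional ordering when these
# bundles are marshalled to the fused C++/VACC kernels; the "pass only Tensors
# on the C++ side" remark near update_attn_args points the same way.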

class DistributedArgs():
    def __init__(self):
        self._0_world_size = 32
        self._1_rank = -1
        self._2_group_id = 0
        self._3_dev_info = []

    def logs(self):
        print("dist self._0_world_size = ", self._0_world_size)
        print("dist self._1_rank = ", self._1_rank)
        print("dist self._2_group_id = ", self._2_group_id)
        print("dist self._3_dev_info = ", self._3_dev_info)

class AttenArgs():
    def __init__(self):
        self._a_hidden_states_norm_weight = []
        self._0_merge_q_kv_weights = []  # fused Q,KV weights
        self._1_merge_q_kv_scale_inv = []  # fused Q,KV scales
        self._2_q_a_layernorm_weight = []
        self._3_W_Q = []
        self._4_W_Q_scales = []
        self._5_W_UK = []
        self._6_W_UK_scales = []
        self._7_W_QR = []
        self._8_W_QR_scales = []
        self._9_kv_a_layernorm_weight = []
        self._10_sin_cache = None
        self._11_cos_cache = None
        self._12_slot_mapping = None
        self._13_kv_cache = None
        self._14_block_tables = None
        self._15_env_blk_grp_size = env_blk_grp_size
        self._16_W_UV = []
        self._17_W_UV_scales = []
        self._18_o_proj_weight = []
        self._19_o_proj_weight_scale_inv = []
        # mla params
        self._20_seq_lens = []
        self._21_sm_scale = 0.0
        self._22_head_num = 128

    def logs(self):
        print("mla _20_seq_lens is:", self._20_seq_lens)
        print("mla _21_sm_scale is:", self._21_sm_scale)
        print("mla _22_head_num is:", self._22_head_num)


class MlpArgs():
    def __init__(self):
        # mlp params
        self._0_mlp_rms_weight = []
        self._1_mlp_w13 = []
        self._2_mlp_w2 = []
        self._3_mlp_w13_scale = []
        self._4_mlp_w2_scale = []
        self._5_mlp_w13_block_size = []
        self._6_mlp_w2_block_size = []

    def logs(self):
        print("mlp _5_mlp_w13_block_size is:", self._5_mlp_w13_block_size)
        print("mlp _6_mlp_w2_block_size is:", self._6_mlp_w2_block_size)

class MoeArgs():
    def __init__(self):
        # moe params
        self._0_moe_rms_weight = []
        self._1_moe_share_mlp_w13 = []
        self._2_moe_share_mlp_w2 = []
        self._3_moe_share_mlp_w13_scale = []
        self._4_moe_share_mlp_w2_scale = []
        self._5_moe_w13 = []
        self._6_moe_w2 = []
        self._7_moe_w13_scale = []
        self._8_moe_w2_scale = []
        self._9_gate_weight = []
        self._10_moe_bias = []
        self._11_moe_mlp_w13_block_size = []
        self._12_moe_mlp_w2_block_size = []
        self._13_moe_w13_block_size = []
        self._14_moe_w2_block_size = []

    def logs(self):
        print("moe _11_moe_mlp_w13_block_size is:", self._11_moe_mlp_w13_block_size)
        print("moe _12_moe_mlp_w2_block_size is:", self._12_moe_mlp_w2_block_size)
        print("moe _13_moe_w13_block_size is:", self._13_moe_w13_block_size)
        print("moe _14_moe_w2_block_size is:", self._14_moe_w2_block_size)

class WeightMapper():
    def __init__(self):
        self.attn_args = AttenArgs()
        self.mlp_args = MlpArgs()  # 3 mla+mlp layers
        self.moe_args = MoeArgs()  # 58 mla+moe layers
        self.dist_args = DistributedArgs()

# 1. weight loading
# 2. dequant block-size precomputation
# 3. parameter caching & extraction
class DeepseekWeightCapture():
    def __init__(self, layer: torch.nn.ModuleList,
                 start: int,
                 end: int):

        self.layer_mlp = WeightMapper()
        self.layer_moe = WeightMapper()

        self.sin_cache_all = None
        self.cos_cache_all = None

        self.mlp_nums = 3
        self.moe_nums = end - self.mlp_nums

        self.start_idx = start
        self.end_idx = end
        for i in range(start, end):
            if i < self.mlp_nums:
                self.capture_deepseek_mla_attn_weights(layer[i], self.layer_mlp.attn_args)
                self.capture_deepseek_mlp_weights(layer[i])
            else:
                self.capture_deepseek_mla_attn_weights(layer[i], self.layer_moe.attn_args)
                self.capture_deepseek_moe_weights(layer[i])

        if OUTPUT_ARGS_LOGS:
            self.layer_mlp.attn_args.logs()
            self.layer_mlp.mlp_args.logs()

            self.layer_moe.attn_args.logs()
            self.layer_moe.moe_args.logs()

        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()
        self.layer_mlp.dist_args._0_world_size = tp_group.world_size
        self.layer_mlp.dist_args._1_rank = tp_group.rank_in_group
        self.layer_mlp.dist_args._2_group_id = tp_group.group_id
        self.layer_mlp.dist_args._3_dev_info = tp_group.rank_device_infos

        self.layer_moe.dist_args._0_world_size = tp_group.world_size
        self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
        self.layer_moe.dist_args._2_group_id = tp_group.group_id
        self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_deepseek_mlp_weights(self, module: DeepseekV2DecoderLayer):
        assert isinstance(module.mlp, DeepseekV2MLP)

        mlp = module.mlp
        rms_norm = module.post_attention_layernorm

        w13_weight = mlp.gate_up_proj.weight
        w2_weight = mlp.down_proj.weight

        w13_wscale = mlp.gate_up_proj.weight_scale_inv
        w13_iscale = mlp.gate_up_proj.input_scale

        w2_wscale = mlp.down_proj.weight_scale_inv
        w2_iscale = mlp.down_proj.input_scale

        w13_block_size0, w13_block_size1 = mlp.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = mlp.gate_up_proj.quant_method.scale_n, mlp.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = mlp.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = mlp.down_proj.quant_method.scale_n, mlp.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k
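        # (editorial example, hypothetical values) weight_block_size = (128, 128)
        # with scale_n = scale_k = 1 stays (128, 128); a quant method packing
        # scale_n = 2 scales per block would store (64, 128) instead.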

        self.layer_mlp.mlp_args._0_mlp_rms_weight.append(rms_norm.weight)
        self.layer_mlp.mlp_args._1_mlp_w13.append(w13_weight)
        self.layer_mlp.mlp_args._2_mlp_w2.append(w2_weight)
        self.layer_mlp.mlp_args._3_mlp_w13_scale.append(w13_wscale)
        self.layer_mlp.mlp_args._4_mlp_w2_scale.append(w2_wscale)
        self.layer_mlp.mlp_args._5_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_mlp.mlp_args._6_mlp_w2_block_size = [w2_block_size0, w2_block_size1]

    def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
        assert isinstance(module.mlp, DeepseekV2MoE)

        share_expert_layer = module.mlp.shared_experts
        experts_layer = module.mlp.experts
        rms_norm = module.post_attention_layernorm
        gate = module.mlp.gate

        w13_weight = share_expert_layer.gate_up_proj.weight
        w2_weight = share_expert_layer.down_proj.weight

        w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
        # w13_iscale = share_expert_layer.gate_up_proj.input_scale

        w2_wscale = share_expert_layer.down_proj.weight_scale_inv
        # w2_iscale = share_expert_layer.down_proj.input_scale

        w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k

        hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
        hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
        block_size0, block_size1 = (
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        )

        self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
        self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
        self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
        self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
        self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
        self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
        self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
        self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
        self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
        self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
        self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
        self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
        self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
        self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]

    def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
                                          weight_mapper: AttenArgs):
        if self.sin_cache_all is None:
            self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
            self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache

        weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)

        fused_params = {}
        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.q_a_proj.named_parameters():
                fused_params['q_a_proj_' + name] = param

        for name, param in module.self_attn.q_a_layernorm.named_parameters():
            fused_params['q_a_layernorm_' + name] = param

        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
                fused_params['kv_a_proj_' + name] = param

        for name, param in module.self_attn.kv_a_layernorm.named_parameters():
            fused_params['kv_a_layernorm_' + name] = param

        for name, param in module.self_attn.o_proj.named_parameters():
            fused_params['o_proj_' + name] = param

        weight_mapper._15_env_blk_grp_size = env_blk_grp_size
        # init sin,cos cache

        mla_params = module.self_attn.mla_attn.impl.extract_weights()
        fused_params = {**fused_params, **mla_params}

        weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
        weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
        weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
        weight_mapper._3_W_Q.append(fused_params['W_Q'])
        weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
        weight_mapper._5_W_UK.append(fused_params['W_UK'])
        weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
        weight_mapper._7_W_QR.append(fused_params['W_QR'])
        weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
        weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
        # weight_mapper._10_sin_cache.append(None)
        # weight_mapper._11_cos_cache.append(None)
        # weight_mapper._12_slot_mapping.append(None)
        # weight_mapper._13_kv_cache.append(None)
        # weight_mapper._14_block_tables.append(None)
        # weight_mapper._15_env_blk_grp_size.append(None)
        weight_mapper._16_W_UV.append(fused_params['W_UV'])
        weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
        weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
        weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
        weight_mapper._20_seq_lens = None
        weight_mapper._21_sm_scale = module.self_attn.scaling
        weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size

    # Possible optimization: pass only Tensors on the C++ side
    def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
        positions = [i - 1 for i in seq_lens]
        cos_cache = [self.cos_cache_all[i] for i in positions]
        sin_cache = [self.sin_cache_all[i] for i in positions]

        self.layer_mlp.attn_args._10_sin_cache = sin_cache
        self.layer_mlp.attn_args._11_cos_cache = cos_cache

        self.layer_moe.attn_args._10_sin_cache = sin_cache
        self.layer_moe.attn_args._11_cos_cache = cos_cache

        self.layer_mlp.attn_args._20_seq_lens = seq_lens
        self.layer_moe.attn_args._20_seq_lens = seq_lens

        self.layer_mlp.attn_args._13_kv_cache = kv_caches_dense_layer
        self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer

        self.layer_mlp.attn_args._12_slot_mapping = slot_mapping
        self.layer_mlp.attn_args._14_block_tables = block_tables

        self.layer_moe.attn_args._12_slot_mapping = slot_mapping
        self.layer_moe.attn_args._14_block_tables = block_tables

        # for i in range(self.mlp_nums):
        #     if i < self.end_idx:
        #         self.layer_mlp.attn_args._12_slot_mapping[i] = slot_mapping
        #         self.layer_mlp.attn_args._14_block_tables[i] = block_tables

        # for i in range(self.moe_nums):
        #     if i < self.end_idx:
        #         self.layer_moe.attn_args._12_slot_mapping[i] = slot_mapping
        #         self.layer_moe.attn_args._14_block_tables[i] = block_tables

    def logs(self):
        print("current layer mlp attn: \n")
        self.layer_mlp.attn_args.logs()
        self.layer_mlp.dist_args.logs()

class DeepseekMTPWegitCapture():
    # Unlike DeepseekWeightCapture, MTP has only a single DeepseekDecoderLayer,
    # and that layer is a MoE layer
    def __init__(self, layer: torch.nn.Module):

        self.layer_moe = WeightMapper()

        self.sin_cache_all = None
        self.cos_cache_all = None

        self.capture_deepseek_mla_attn_weights(layer, self.layer_moe.attn_args)
        self.capture_deepseek_moe_weights(layer)

        if OUTPUT_ARGS_LOGS:
            self.layer_moe.attn_args.logs()
            self.layer_moe.moe_args.logs()

        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()

        self.layer_moe.dist_args._0_world_size = tp_group.world_size
        self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
        self.layer_moe.dist_args._2_group_id = tp_group.group_id
        self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
        assert isinstance(module.mlp, DeepseekV2MoE)

        share_expert_layer = module.mlp.shared_experts
        experts_layer = module.mlp.experts
        rms_norm = module.post_attention_layernorm
        gate = module.mlp.gate

        w13_weight = share_expert_layer.gate_up_proj.weight
        w2_weight = share_expert_layer.down_proj.weight

        w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
        w2_wscale = share_expert_layer.down_proj.weight_scale_inv

        w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k

        hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
        hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
        block_size0, block_size1 = (
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        )

        self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
        self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
        self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
        self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
        self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
        self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
        self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
        self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
        self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
        self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
        self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
        self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
        self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
        self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]

    def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
                                          weight_mapper: AttenArgs):
        if self.sin_cache_all is None:
            self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
            self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache

        weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)

        fused_params = {}
        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.q_a_proj.named_parameters():
                fused_params['q_a_proj_' + name] = param

        for name, param in module.self_attn.q_a_layernorm.named_parameters():
            fused_params['q_a_layernorm_' + name] = param

        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
                fused_params['kv_a_proj_' + name] = param

        for name, param in module.self_attn.kv_a_layernorm.named_parameters():
            fused_params['kv_a_layernorm_' + name] = param

        for name, param in module.self_attn.o_proj.named_parameters():
            fused_params['o_proj_' + name] = param

        weight_mapper._15_env_blk_grp_size = env_blk_grp_size
        # init sin,cos cache

        mla_params = module.self_attn.mla_attn.impl.extract_weights()
        fused_params = {**fused_params, **mla_params}

        weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
        weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
        weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
        weight_mapper._3_W_Q.append(fused_params['W_Q'])
        weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
        weight_mapper._5_W_UK.append(fused_params['W_UK'])
        weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
        weight_mapper._7_W_QR.append(fused_params['W_QR'])
        weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
        weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
        # weight_mapper._10_sin_cache.append(None)
        # weight_mapper._11_cos_cache.append(None)
        # weight_mapper._12_slot_mapping.append(None)
        # weight_mapper._13_kv_cache.append(None)
        # weight_mapper._14_block_tables.append(None)
        # weight_mapper._15_env_blk_grp_size.append(None)
        weight_mapper._16_W_UV.append(fused_params['W_UV'])
        weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
        weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
        weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
        weight_mapper._20_seq_lens = None
        weight_mapper._21_sm_scale = module.self_attn.scaling
        weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size

    # Possible optimization: pass only Tensors on the C++ side
    def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
        positions = [i - 1 for i in seq_lens]
        cos_cache = [self.cos_cache_all[i] for i in positions]
        sin_cache = [self.sin_cache_all[i] for i in positions]

        self.layer_moe.attn_args._10_sin_cache = sin_cache
        self.layer_moe.attn_args._11_cos_cache = cos_cache

        self.layer_moe.attn_args._20_seq_lens = seq_lens

        self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer

        self.layer_moe.attn_args._12_slot_mapping = slot_mapping
        self.layer_moe.attn_args._14_block_tables = block_tables

    def logs(self):
        print("current layer moe attn: \n")
        self.layer_moe.attn_args.logs()
        self.layer_moe.dist_args.logs()

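Both update_attn_args variants gather per-request rotary rows by indexing the cached cos/sin tables at position seq_len - 1, i.e. the position of the token currently being decoded. A toy illustration (cache shapes are hypothetical):

import torch

# hypothetical rotary caches: one row per absolute position, rot_dim = 4
max_pos, rot_dim = 8, 4
cos_cache_all = torch.randn(max_pos, rot_dim)
sin_cache_all = torch.randn(max_pos, rot_dim)

seq_lens = [3, 5, 1]                   # current lengths of three requests
positions = [i - 1 for i in seq_lens]  # last-token positions: [2, 4, 0]
cos_cache = [cos_cache_all[i] for i in positions]
sin_cache = [sin_cache_all[i] for i in positions]
print(torch.stack(cos_cache).shape)    # torch.Size([3, 4])
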
@@ -0,0 +1,154 @@
import torch
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer,
                                                  Qwen3MoeMLP)

class Qwen3Moe_DistributedArgs():
    def __init__(self):
        self._0_world_size = 32
        self._1_rank = -1
        self._2_group_id = 0
        self._3_dev_info = []

    def __repr__(self):
        dist_infos = f"[dist] world_size = {self._0_world_size} \n" \
                     + f"[dist] rank = {self._1_rank} \n" \
                     + f"[dist] group_id = {self._2_group_id} \n" \
                     + f"[dist] dev_info = {self._3_dev_info}"
        return dist_infos

class Qwen3Moe_AttenArgs():
    def __init__(self):
        self._0_input_layernorm_weight = []
        self._1_qkv_proj_weight = []
        self._2_qkv_proj_weight_scale = []
        self._3_qkv_proj_bias = []
        self._4_qkv_proj_qzeros = []
        self._5_q_norm_weight = []
        self._6_k_norm_weight = []
        self._7_sin_cache = None
        self._8_cos_cache = None
        self._9_slot_mapping = None
        self._10_kv_cache = None
        self._11_block_tables = None
        self._12_block_group_size = None
        self._13_o_proj_weight = []
        self._14_o_proj_weight_scale = []
        self._15_o_proj_bias = []
        self._16_o_proj_qzeros = []
        self._17_seq_lens = None
        self._18_sm_scale = None
        self._19_num_attention_heads = None
        self._20_num_key_value_heads = None

    def __repr__(self):
        attn_infos = "[qwen attn] 21 args \n" \
                     + f"[qwen attn] weight counts: {len(self._0_input_layernorm_weight)}"
        return attn_infos

class Qwen3Moe_MoeArgs():
    def __init__(self):
        # moe params
        self._0_rms_norm_weight = []
        self._1_w13_weight = []
        self._2_w2_weight = []
        self._3_w13_weight_scale_inv = []
        self._4_w2_weight_scale_inv = []
        self._5_gate_weight = []
        self._6_w13_block_size = None
        self._7_w2_block_size = None

    def __repr__(self):
        moe_infos = f"[moe] w13_block_size: {self._6_w13_block_size} \n" \
                    + f"[moe] w2_block_size: {self._7_w2_block_size} \n" \
                    + f"[moe] weight counts: {len(self._1_w13_weight)}"
        return moe_infos

class Qwen3Moe_WeightMapper:
    def __init__(self):
        self.attn_args = Qwen3Moe_AttenArgs()
        self.moe_args = Qwen3Moe_MoeArgs()
        self.dist_args = Qwen3Moe_DistributedArgs()

class Qwen3Moe_WeightCapture():
    def __init__(self, layers: torch.nn.ModuleList,
                 start: int,
                 end: int):
        self.layer_mapper = Qwen3Moe_WeightMapper()
        # qwen3 only supports fp8 for now
        self.support_fused_weights = False
        for i in range(start, end):
            layer = layers[i]
            self.capture_attn_weights(layer)
            self.capture_moe_weights(layer)

        # register multi-device environment info
        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()
        self.layer_mapper.dist_args._0_world_size = tp_group.world_size
        self.layer_mapper.dist_args._1_rank = tp_group.rank_in_group
        self.layer_mapper.dist_args._2_group_id = tp_group.group_id
        self.layer_mapper.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_attn_weights(self, layer):
        from vllm_vacc.vllm.model_executor.models.qwen3_moe import set_fused_params
        # register params for the fused operator
        fused_params = {}
        fused_params['input_layernorm_weight'] = layer.input_layernorm.weight
        fused_params['q_norm_weight'] = layer.self_attn.q_norm.weight
        fused_params['k_norm_weight'] = layer.self_attn.k_norm.weight
        set_fused_params(fused_params, layer.self_attn.qkv_proj.quant_method, layer.self_attn.qkv_proj, 'qkv_proj')
        set_fused_params(fused_params, layer.self_attn.o_proj.quant_method, layer.self_attn.o_proj, 'o_proj')

        self.support_fused_weights = hasattr(layer.mlp.experts.quant_method, 'quant_config') and hasattr(layer.mlp.experts.quant_method.quant_config, 'weight_block_size')
        if not hasattr(layer.self_attn, "fused_params"):
            layer.self_attn.fused_params = fused_params

        self.layer_mapper.attn_args._0_input_layernorm_weight.append(fused_params['input_layernorm_weight'])
        self.layer_mapper.attn_args._1_qkv_proj_weight.append(fused_params['qkv_proj_weight'])
        self.layer_mapper.attn_args._2_qkv_proj_weight_scale.append(fused_params['qkv_proj_weight_scale'])
        self.layer_mapper.attn_args._3_qkv_proj_bias.append(fused_params['qkv_proj_bias'])
        self.layer_mapper.attn_args._4_qkv_proj_qzeros.append(fused_params['qkv_proj_qzeros'])
        self.layer_mapper.attn_args._5_q_norm_weight.append(fused_params['q_norm_weight'])
        self.layer_mapper.attn_args._6_k_norm_weight.append(fused_params['k_norm_weight'])
        # self.layer_mapper.attn_args._7_sin_cache
        # self.layer_mapper.attn_args._8_cos_cache
        # self.layer_mapper.attn_args._9_slot_mapping
        # self.layer_mapper.attn_args._10_kv_cache
        # self.layer_mapper.attn_args._11_block_tables
        # self.layer_mapper.attn_args._12_block_group_size
        self.layer_mapper.attn_args._13_o_proj_weight.append(fused_params['o_proj_weight'])
        self.layer_mapper.attn_args._14_o_proj_weight_scale.append(fused_params['o_proj_weight_scale'])
        self.layer_mapper.attn_args._15_o_proj_bias.append(fused_params['o_proj_bias'])
        self.layer_mapper.attn_args._16_o_proj_qzeros.append(fused_params['o_proj_qzeros'])
        # self.layer_mapper.attn_args._17_seq_lens
        self.layer_mapper.attn_args._18_sm_scale = layer.self_attn.scaling
        self.layer_mapper.attn_args._19_num_attention_heads = layer.self_attn.total_num_heads
        self.layer_mapper.attn_args._20_num_key_value_heads = layer.self_attn.total_num_kv_heads

    def capture_moe_weights(self, layer: Qwen3MoeDecoderLayer):
        from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
        quant_method = layer.mlp.experts.quant_method if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock) \
            else layer.mlp.down_proj.quant_method

        if not isinstance(quant_method, MoeWNA16Method):
            from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import recompute_moe_layer_blocksize
            recompute_moe_layer_blocksize(layer.mlp.experts)
            self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
            self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_weight)
            self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_weight)
            self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_weight_scale_inv)
            self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_weight_scale_inv)
            self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
            self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
            self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
        else:
            self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
            self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_qweight)
            self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_qweight)
            self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_scales)
            self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_scales)
            self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
            self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
            self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size

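Taken together, the capture classes walk a slice of decoder layers once at startup and pin every tensor the fused VACC kernels will need. A hedged usage sketch -- `model` is a hypothetical, already-initialized Qwen3-MoE model handle, and any import path is a placeholder since this diff omits the file's name:

# hypothetical wiring; only Qwen3Moe_WeightCapture itself is defined above
layers = model.model.layers  # the decoder torch.nn.ModuleList
capture = Qwen3Moe_WeightCapture(layers, start=0, end=len(layers))

# the arg containers' __repr__ methods summarize what was captured
print(capture.layer_mapper.attn_args)
print(capture.layer_mapper.moe_args)
print(capture.layer_mapper.dist_args)
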
0
vllm_vacc/vllm/model_executor/ops/__init__.py
Normal file
0
vllm_vacc/vllm/model_executor/ops/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff