2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,38 @@
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
class CustomOp(nn.Module):
def forward_vacc(self, *args, **kwargs):
raise NotImplementedError
def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
enabled = self.enabled()
logger.debug("custom op %s %s", self.__class__.name,
"enabled" if enabled else "disabled")
if not enabled:
return self.forward_native
return self.forward
if current_platform.is_rocm():
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_hpu():
return self.forward_hpu
elif current_platform.is_tpu():
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
        elif current_platform.is_vacc():
            # Route to the vacc-specific implementation defined above.
            return self.forward_vacc
else:
return self.forward_cuda
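
# Example (illustrative only, not from the upstream API): a subclass ported
# to this backend overrides forward_vacc, and dispatch_forward() then routes
# to it when running on a vacc device:
#
#     class MyOp(CustomOp):
#         def forward_native(self, x):
#             return x * 2.0
#
#         def forward_vacc(self, x):
#             return torch.vacc.my_kernel(x)  # hypothetical vacc kernel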

View File

@@ -0,0 +1,8 @@
import torch
def SiluAndMul_forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
    # Fused SwiGLU kernel on vacc, assumed to compute
    # silu(x[..., :d]) * x[..., d:] with d = half the last dimension.
    return torch.vacc.swiglu(x)


def QuickGELU_forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation, equivalent to forward_native()."""
    return x * torch.sigmoid(1.702 * x)
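
# Presumed wiring (a sketch only; the plugin's actual patch loader may
# differ): these free functions are bound onto the upstream activation
# classes as their vacc forward methods, e.g.
#
#     from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
#     SiluAndMul.forward_vacc = SiluAndMul_forward_vacc
#     QuickGELU.forward_vacc = QuickGELU_forward_vacc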

View File

@@ -0,0 +1,82 @@
import torch
from typing import Optional, Tuple

from vllm.logger import init_logger

logger = init_logger(__name__)
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, -1)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
return topk_weights, selected_experts
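
def _fused_topk_example() -> None:
    # Illustrative sanity check only (not part of the kernel path): route two
    # tokens across four experts, keeping the top-2 experts per token.
    hidden_states = torch.randn(2, 8)
    gating_output = torch.randn(2, 4)
    weights, ids = fused_topk(hidden_states, gating_output, topk=2,
                              renormalize=True)
    assert weights.shape == (2, 2) and ids.shape == (2, 2)
    # With renormalize=True each token's routing weights sum to 1.
    assert torch.allclose(weights.sum(dim=-1), torch.ones(2), atol=1e-3)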
def grouped_topk_with_itype(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
    # Prefer the fused on-device kernel; fall back to the eager
    # implementation below if it is unavailable or fails.
    try:
        from torch_vacc.vacc.custom_ops import fused_moe_preprocess
        return fused_moe_preprocess(gating_output, e_score_correction_bias)
    except Exception as e:
        logger.warning(
            "fused grouped topk failed, falling back to unfused grouped "
            "topk: %s", e)
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int32)
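
def _grouped_topk_example() -> None:
    # Illustrative only: 8 experts in 4 groups, with routing restricted to
    # the top-2 groups before the final top-2 experts are picked. On a host
    # without torch_vacc this exercises the eager fallback above.
    hidden_states = torch.randn(2, 16)
    gating_output = torch.randn(2, 8)
    weights, ids = grouped_topk_with_itype(hidden_states, gating_output,
                                           topk=2, renormalize=True,
                                           num_expert_group=4, topk_group=2)
    assert ids.dtype == torch.int32 and ids.shape == (2, 2)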

View File

@@ -0,0 +1,715 @@
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_tpu():
    # The iterative MoE implementation is used until moe_pallas is fixed.
from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
def FusedMoE_init_(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
):
from vllm.model_executor.layers.fused_moe import FusedMoE
super(FusedMoE, self).__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
vllm_config = get_current_vllm_config()
# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
# operation is typically unquantized (i.e. float16/bfloat16).
if vllm_config.model_config is not None:
moe_in_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
moe_in_dtype = params_dtype
tp_size_ = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
tp_size_=tp_size_,
dp_size_=dp_size_,
vllm_parallel_config=vllm_config.parallel_config))
self.global_num_experts = num_experts + num_redundant_experts
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.enable_eplb = enable_eplb
self.expert_load_view: Optional[torch.Tensor] = None
self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None
# Determine expert maps
if self.use_ep:
if self.enable_eplb:
assert self.global_num_experts % self.ep_size == 0, \
"EPLB currently only supports even distribution of " \
"experts across ranks."
else:
assert num_redundant_experts == 0, \
"Redundant experts are only supported with EPLB."
self.local_num_experts, self.expert_map = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts)
else:
self.local_num_experts, self.expert_map = (self.global_num_experts,
None)
self.top_k = top_k
assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
)
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.")
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
    # Scale replication factors; quant methods that shrink the quant block
    # override these below.
    self.scale_n = 1
    self.scale_k = 1
self.scale_n_prefill = 1
if hasattr(self.quant_method, "scale_n") and hasattr(self.quant_method, "scale_k"):
self.scale_n = self.quant_method.scale_n
self.scale_k = self.quant_method.scale_k
if hasattr(self.quant_method, "scale_n_prefill"):
self.scale_n_prefill = self.quant_method.scale_n_prefill
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
class FusedMoE(torch.nn.Module):
    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int,
                  load_full: bool = False, expert_id: int = 0):
        # Replicate block scales along n/k when the effective quant block on
        # vacc is smaller than the checkpoint's (scale_n / scale_k > 1).
        if (self.scale_n > 1 and len(loaded_weight.shape) == 2
                and torch.finfo(loaded_weight.dtype).bits > 8):
            n_v, k_v = loaded_weight.shape
            loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v)  # [1,n,1,k]
            if (self.scale_n_prefill != self.scale_n
                    and hasattr(self, 'w13_weight_scale_inv_prefill')):
                loaded_weight0 = loaded_weight.repeat(
                    self.scale_n_prefill, 1, self.scale_k,
                    1).permute(1, 0, 3, 2).reshape(
                        [n_v * self.scale_n_prefill, k_v * self.scale_k])
                shard_size = self.w13_weight_scale_inv_prefill.data[
                    expert_id].shape[shard_dim] // 2
                loaded_weight0 = loaded_weight0.narrow(
                    shard_dim, shard_size * tp_rank, shard_size)
                if shard_id == "w1":
                    self.w13_weight_scale_inv_prefill.data[
                        expert_id, :loaded_weight0.shape[0],
                        :loaded_weight0.shape[1]] = loaded_weight0
                elif shard_id == "w3":
                    self.w13_weight_scale_inv_prefill.data[
                        expert_id, -loaded_weight0.shape[0]:,
                        -loaded_weight0.shape[1]:] = loaded_weight0
                else:
                    raise ValueError(f"unexpected shard_id: {shard_id}")
            loaded_weight = loaded_weight.repeat(
                self.scale_n, 1, self.scale_k,
                1).permute(1, 0, 3, 2).reshape(
                    [n_v * self.scale_n, k_v * self.scale_k])
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False,
                 expert_id: int = 0):
        # Replicate block scales along n/k, mirroring _load_w13 (w2 is the
        # down projection, so the roles of n and k are swapped).
        if (self.scale_n > 1 and len(loaded_weight.shape) == 2
                and torch.finfo(loaded_weight.dtype).bits > 8):
            n_v, k_v = loaded_weight.shape
            loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v)  # [1,n,1,k]
            if (self.scale_n_prefill != self.scale_n
                    and hasattr(self, 'w2_weight_scale_inv_prefill')):
                loaded_weight0 = loaded_weight.repeat(
                    self.scale_k, 1, self.scale_n_prefill,
                    1).permute(1, 0, 3, 2).reshape(
                        [n_v * self.scale_k, k_v * self.scale_n_prefill])
                shard_size = self.w2_weight_scale_inv_prefill.data[
                    expert_id].shape[shard_dim]
                if not load_full:
                    loaded_weight0 = loaded_weight0.narrow(
                        shard_dim, shard_size * tp_rank, shard_size)
                self.w2_weight_scale_inv_prefill.data[
                    expert_id] = loaded_weight0
            loaded_weight = loaded_weight.repeat(
                self.scale_k, 1, self.scale_n,
                1).permute(1, 0, 3, 2).reshape(
                    [n_v * self.scale_k, k_v * self.scale_n])
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      weight_name: str,
                      shard_id: str,
                      expert_id: int,
                      return_success: bool = True) -> Optional[bool]:
if self.quant_config and self.quant_config.get_name() == "mxfp4":
# (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
param.data[:, :dim1].copy_(loaded_weight)
else:
dim1 = loaded_weight.shape[1]
dim2 = loaded_weight.shape[2]
param.data[:, :dim1, :dim2].copy_(loaded_weight)
return True if return_success else None
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return False if return_success else None
quant_method_name = self.quant_method.__class__.__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
if self.quant_method.__class__.__name__ in (
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod"):
loaded_weight = loaded_weight.t().contiguous()
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
param.data.copy_(loaded_weight)
return True if return_success else None
# Case for BitsAndBytes
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit:
shard_dim = 0
expert_data = param.data[expert_id]
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
# BNB inflight quantization has already sharded the weights
full_load = True
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank,
load_full=full_load,
)
return True if return_success else None
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = int(not shard_dim)
full_load = len(loaded_weight.shape) == 3
if full_load:
shard_dim += 1
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
final_shape = list(loaded_weight.shape)
if shard_id in ["w1", "w3"]:
final_shape[1] *= 2
final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
param.materialize(final_shape, dtype=loaded_weight.dtype)
expert_data = param.data if full_load else param.data[expert_id]
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return True if return_success else None
# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name:
# Determine per-tensor weight scale patterns based on variant
# Use the dedicated method instead of brittle string matching
            uses_weight_scale_2 = (
                self.quant_method.uses_weight_scale_2_pattern())
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
# weights scales.
# Input scales are always per-tensor.
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
# "weight_scale" for per-tensor scales.
is_per_tensor = ("weight_scale_2" in weight_name
if uses_weight_scale_2 else "weight_scale"
in weight_name) or "input_scale" in weight_name
if is_per_tensor:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
return True if return_success else None
# If the weight is w13_weight_scale and w13_weight_scales are
# combined into single loaded_weight, call
# _load_combined_w13_weight_scale() to load it.
# This is checked by comparing the hidden_out dims of the
# loaded_weight and the param.
if "w13_weight_scale" in weight_name:
loaded_weight_hidden_out = loaded_weight.shape[-2]
param_hidden_out = param.data.shape[-2] * self.tp_size
if loaded_weight_hidden_out == param_hidden_out:
self._load_combined_w13_weight_scale(
shard_dim=shard_dim,
loaded_weight=loaded_weight,
param=param,
tp_rank=self.tp_rank,
)
return True if return_success else None
# For other weights, call _load_model_weight_or_group_weight_scale()
# to load it.
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
# Case weight scales, zero_points and offset
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank,
load_full_w2=getattr(param, "load_full_w2", False),
expert_id=expert_id)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return True if return_success else None
# Case weight_shape
if "weight_shape" in weight_name:
# only required by compressed-tensors
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return True if return_success else None
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
return False if return_success else None
def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full_w2: bool = False,
expert_id: int = 0):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
# In the case where we have actorder/g_idx, we do not partition the
# w2 scales, as indicated by `load_full` argument, for all tp cases
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
load_full=load_full_w2,
expert_id=expert_id)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
expert_id=expert_id)
class UnquantizedFusedMoEMethod:

    @staticmethod
    def forward_vacc(
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
        # NOTE: this path implements plain softmax top-k routing only; the
        # grouped-topk / custom-routing / scoring-func arguments are accepted
        # for signature compatibility but not honored here.
        hidden_states = x
        orig_shape = hidden_states.shape
        hidden_size = hidden_states.shape[-1]
        num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
intermediate_size = layer.w2_weight.shape[-1]
        gating_output = router_logits
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, global_num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
final_hidden_states = torch.zeros_like(hidden_states)
sel_experts = topk_ids.shape[1]
if hidden_states.shape[0] == 1:
            for slot in range(sel_experts):
                expert_idx = topk_ids[0][slot]
                expert_w1 = layer.w13_weight[expert_idx].contiguous()
                expert_w2 = layer.w2_weight[expert_idx].contiguous()
                expert_weights = topk_weights[0][slot].to(hidden_states.dtype)
x = hidden_states
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
final_hidden_states += current_hidden_states
else:
for expert_idx in range(global_num_experts):
# topk_ids [tokens, experts] => sample:[10, 8]
# expert_mask [tokens, experts] => sample:[10, 8]
expert_mask = topk_ids == expert_idx
idx = torch.where(expert_mask)[0]
if idx.numel() == 0:
continue
expert_w1 = layer.w13_weight[expert_idx].contiguous()
expert_w2 = layer.w2_weight[expert_idx].contiguous()
# [seq, experts]
expert_weights = (
topk_weights.masked_select(expert_mask)
.unsqueeze(1)
.to(hidden_states.dtype)
)
x = hidden_states[idx]
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
# final_hidden_states[idx] += current_hidden_states
final_hidden_states.index_add_(0, idx, current_hidden_states)
return final_hidden_states.view(orig_shape) # type: ignore
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
custom_forward = self.forward
if x.device.type == "vacc":
custom_forward = UnquantizedFusedMoEMethod.forward_vacc
return custom_forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
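
# Presumed wiring (a sketch only; the plugin's actual patch loader may
# differ): the fragments above are expected to replace their upstream
# counterparts at import time, e.g.
#
#     import vllm.model_executor.layers.fused_moe.layer as _moe_layer
#     _moe_layer.FusedMoE.__init__ = FusedMoE_init_
#     _moe_layer.FusedMoE.weight_loader = FusedMoE.weight_loader
#     _moe_layer.UnquantizedFusedMoEMethod.apply = (
#         UnquantizedFusedMoEMethod.apply)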

View File

@@ -0,0 +1,32 @@
from typing import Optional, Tuple, Union
import torch
def RMSNorm_forward_vacc(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # The fused kernel folds the residual add (x = x + residual;
    # residual = x) into the RMSNorm itself.
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
        x_var = x[..., :self.variance_size_override]
    # torch.vacc.fused_residual_rmsnorm appears to take its output buffers as
    # the trailing arguments (x_var and residual are reused in place).
    out = torch.vacc.fused_residual_rmsnorm(x_var, self.weight, residual,
                                            self.variance_epsilon, x_var,
                                            residual)
    return out
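
def _fused_residual_rmsnorm_reference(
        x: torch.Tensor, weight: torch.Tensor,
        residual: Optional[torch.Tensor], eps: float) -> torch.Tensor:
    # Eager reference for what the fused vacc kernel is assumed to compute
    # (residual add folded into the RMSNorm); illustrative only.
    if residual is not None:
        x = x + residual
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x = x.float() * torch.rsqrt(variance + eps)
    return x.to(weight.dtype) * weight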

View File

@@ -0,0 +1,465 @@
import itertools
import math
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
WEIGHT_LOADER_V2_SUPPORTED,
LinearBase,
RowParallelLinear)
def ReplicatedLinear__init__(self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super(ReplicatedLinear,self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix)
# All the linear layer supports quant method.
assert self.quant_method is not None
        # scale_k divides quant_block_k: e.g. with quant_block_k = 128 and
        # scale_k = 2, the effective quant block along k is 64.
        self.scale_k = 1
        self.scale_k_slice = 1
        self.scale_n = 1
        self.scale_n_slice = 1
if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
gcd_value = quant_config.weight_block_size[1]
import math
if input_size % quant_config.weight_block_size[1]:
gcd_value = math.gcd(input_size % quant_config.weight_block_size[1], quant_config.weight_block_size[1])
self.scale_k =self.scale_k * quant_config.weight_block_size[1] // gcd_value
self.scale_k_slice = input_size // gcd_value
if output_size % quant_config.weight_block_size[0]:
gcd_value = math.gcd(output_size % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value
self.scale_n_slice = output_size // gcd_value
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
scale_k = self.scale_k,
scale_n = self.scale_n,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def ReplicatedLinear_weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    # If the weight on disk does not have a shape, give it one
    # (e.g. per-tensor scales for AutoFP8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_k > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,2,0).reshape([loaded_weight.shape[1], -1])[:, :self.scale_k_slice]
#[1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
assert param.size() == loaded_weight.size(), f'{param.size()}, {loaded_weight.size()}'
param.data.copy_(loaded_weight)
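
def _scale_broadcast_example() -> None:
    # Illustrative check of the replication above: each block scale is
    # repeated scale_k times along k, i.e. [n, k] -> [n, k * scale_k].
    scale = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # [n=2, k=2]
    scale_k = 2
    expanded = scale.unsqueeze(0).expand(scale_k, 2, 2).permute(
        1, 2, 0).reshape(2, -1)
    assert torch.equal(
        expanded, torch.tensor([[1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]]))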
def refine_block(block_size: list[int],
                 weight_size: list[int],
                 dim: int = 0,
                 pingpong_size: float = 2.5 * 1024 * 1024,  # bytes
                 core_number: int = 4,
                 data_type: int = 2,  # bytes per element (2 = bfloat16)
                 max_iter_number: int = 2):
    '''
    When the blocks do not divide evenly across cores, each core must hold
    <= 2.5 MB for ping-pong buffering to work. The imbalance between cores is
    block_size[dim] * weight_size[1 - dim], so shrinking block_size narrows
    the gap until the largest core fits within the budget. If an even split
    already exceeds the budget, or the uneven split already fits, there is
    nothing to adjust.
    '''
    if dim < 0:
        dim = 2 + dim
    pingpong_size = pingpong_size / data_type  # budget in elements
    block_size_refine = block_size[dim]
    all_block_number = weight_size[dim] // block_size_refine
    if all_block_number % core_number == 0:
        # Even split: per-core size is fixed, so no adjustment can help.
        return block_size_refine
    block_number_tiny = all_block_number // core_number
    block_number_big = all_block_number // core_number + 1
    if block_number_tiny * block_size_refine * weight_size[1 - dim] >= pingpong_size or \
        block_number_big * block_size_refine * weight_size[1 - dim] <= pingpong_size:
        # Either the small cores already exceed the budget (nothing to gain)
        # or the big cores already fit (nothing to fix).
        return block_size_refine
    all_block_number_tmp = all_block_number
    block_size_refine_tmp = block_size_refine
    for _ in range(max_iter_number):
        all_block_number_tmp = all_block_number_tmp * 2
        block_size_refine_tmp = block_size_refine_tmp // 2
        if all_block_number_tmp % core_number == 0:
            block_number_tiny = all_block_number_tmp // core_number
            if block_number_tiny * block_size_refine_tmp * weight_size[1 - dim] <= pingpong_size:
                return block_size_refine_tmp
            else:
                # Even the even split exceeds the budget; give up.
                return block_size_refine
        else:
            block_number_big = all_block_number_tmp // core_number + 1
            if block_number_big * block_size_refine_tmp * weight_size[1 - dim] <= pingpong_size:
                return block_size_refine_tmp
    return block_size_refine
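
def _refine_block_example() -> None:
    # Worked example (illustrative): a 640 x 8192 bf16 weight with 128 x 128
    # blocks gives 5 blocks over 4 cores, a 2-vs-1 imbalance whose larger
    # share (2 * 128 * 8192 elements * 2 bytes = 4 MB) exceeds the 2.5 MB
    # ping-pong budget. Halving the block twice yields 20 blocks of 32 that
    # split evenly at exactly 5 * 32 * 8192 * 2 bytes = 2.5 MB per core.
    assert refine_block([128, 128], [640, 8192]) == 32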
def ColumnParallelLinear__init__(self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,):
# Divide the weight matrix along the last dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
super(ColumnParallelLinear,self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
    self.scale_n = 1
    if (quant_config is not None
            and hasattr(quant_config, "weight_block_size")
            and quant_config.weight_block_size is not None):
        gcd_value = quant_config.weight_block_size[0]
        if hasattr(self, "output_sizes"):
            # For a merged ColumnParallelLinear the GCD has to be computed
            # from the shape of every constituent partition.
            output_size_no_merge = self.output_partition_sizes
            block_values = [
                o % quant_config.weight_block_size[0]
                for o in output_size_no_merge
            ]
            is_gcd_recompute = sum(block_values)
            if is_gcd_recompute:
                block_values.append(quant_config.weight_block_size[0])
                gcd_value = math.gcd(*block_values)
            # Notice: the flow for unaligned partition weights may still
            # need validation. For DeepSeek, the merged column linears (MLP
            # and MoE) all have identically shaped partitions. Qwen3 has QKV
            # column linears with differently shaped partitions, but its
            # current sharding scheme leaves gcd_value untouched, so this
            # branch is not reached for now.
            if (len(output_size_no_merge) == 2
                    and output_size_no_merge[0] == output_size_no_merge[1]):
                # only refine mlp w13
                gcd_value = refine_block(
                    [gcd_value, quant_config.weight_block_size[1]],
                    [output_size_no_merge[0], input_size])
            self.scale_n = (self.scale_n * quant_config.weight_block_size[0]
                            // gcd_value)
        else:
            # For a non-merged ColumnParallelLinear, compute the GCD from
            # the current partition shape alone.
            output_size_no_merge = self.output_size_per_partition
            is_gcd_recompute = (output_size_no_merge
                                % quant_config.weight_block_size[0])
            if is_gcd_recompute:
                gcd_value = math.gcd(
                    output_size_no_merge % quant_config.weight_block_size[0],
                    quant_config.weight_block_size[0])
            self.scale_n = (self.scale_n * quant_config.weight_block_size[0]
                            // gcd_value)
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
scale_n = self.scale_n,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
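
def _scale_n_example() -> None:
    # Illustrative: an output size of 576 with 128 x 128 quant blocks is not
    # block aligned; gcd(576 % 128, 128) = gcd(64, 128) = 64, so scale_n =
    # 128 // 64 = 2 (and, where sliced, scale_n_slice = 576 // 64 = 9),
    # matching the expansion example in RowParallelLinear__init__ below.
    assert math.gcd(576 % 128, 128) == 64
    assert 128 // 64 == 2 and 576 // 64 == 9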
def ColumnParallelLinear_weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
        if self.scale_n > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            loaded_weight = loaded_weight.expand(
                self.scale_n, loaded_weight.shape[1],
                loaded_weight.shape[2]).permute(1, 0, 2).reshape(
                    [-1, loaded_weight.shape[-1]])
            # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
param.load_column_parallel_weight(loaded_weight=loaded_weight)
class MergedColumnParallelLinear(ColumnParallelLinear):
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
            if self.scale_n > 1 and len(loaded_weight.shape) == 2:
                loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
                loaded_weight = loaded_weight.expand(
                    self.scale_n, loaded_weight.shape[1],
                    loaded_weight.shape[2]).permute(1, 0, 2).reshape(
                        [-1, loaded_weight.shape[-1]])
                # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
        if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
            if (self.quant_method.scale_k > 1
                    and len(loaded_weight.shape) == 2 and loaded_weight.dtype
                    in [torch.float16, torch.bfloat16, torch.float32]):
                loaded_weight = loaded_weight.unsqueeze(1)  # [k, 1, n]
                loaded_weight = loaded_weight.expand(
                    loaded_weight.shape[0], self.quant_method.scale_k,
                    loaded_weight.shape[2]).reshape(
                        [-1, loaded_weight.shape[2]])
                # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
            block_n = weight_block_size[0] // self.scale_n
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
def RowParallelLinear__init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
super(RowParallelLinear, self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
        # scale_k divides quant_block_k: e.g. with quant_block_k = 128 and
        # scale_k = 2, the effective quant block along k is 64.
        self.scale_k = 1
        self.scale_n = 1
        self.scale_n_slice = 1
        if (quant_config is not None
                and hasattr(quant_config, "weight_block_size")
                and quant_config.weight_block_size is not None):
            gcd_value = quant_config.weight_block_size[1]
            if self.input_size_per_partition % quant_config.weight_block_size[1]:
                gcd_value = math.gcd(
                    self.input_size_per_partition
                    % quant_config.weight_block_size[1],
                    quant_config.weight_block_size[1])
            self.scale_k = (self.scale_k * quant_config.weight_block_size[1]
                            // gcd_value)
            if output_size % quant_config.weight_block_size[0]:
                gcd_value = math.gcd(
                    output_size % quant_config.weight_block_size[0],
                    quant_config.weight_block_size[0])
            self.scale_n = (self.scale_n * quant_config.weight_block_size[0]
                            // gcd_value)
            self.scale_n_slice = output_size // gcd_value
            # Example: N = 576, block = 128. Expanding the scale along n
            # needs two numbers: how many copies to make (scale_n) and how
            # many entries remain valid after slicing (scale_n_slice):
            #   scale = [s0,s1,s2,s3,s4], copied scale_n = 2 times ->
            #   [s0,s0,s1,s1,s2,s2,s3,s3,s4,s4], sliced to scale_n_slice = 9
            #   entries -> [s0,s0,s1,s1,s2,s2,s3,s3,s4]
        if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
            gcd_value = quant_config.group_size
            if self.input_size_per_partition % quant_config.group_size:
                gcd_value = math.gcd(
                    self.input_size_per_partition % quant_config.group_size,
                    quant_config.group_size)
            self.quant_method.scale_k = (self.quant_method.scale_k
                                         * quant_config.group_size
                                         // gcd_value)
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
scale_k = self.scale_k,
scale_n = self.scale_n,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def RowParallelLinear_weight_loader_v2_vacc(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_k > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,2,0).reshape([loaded_weight.shape[1], -1])
#[1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
#[1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n,k]
    elif self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
        # broadcast scale TODO: broadcast zero
        if (self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2
                and loaded_weight.dtype
                in [torch.float16, torch.float32, torch.bfloat16]):
            loaded_weight = loaded_weight.unsqueeze(1)  # [k, 1, n]
            loaded_weight = loaded_weight.expand(
                loaded_weight.shape[0], self.quant_method.scale_k,
                loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
            # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]
param.load_row_parallel_weight(loaded_weight=loaded_weight)
class UnquantizedLinearMethod:
    """Linear method without quantization."""
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if bias is not None:
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
            memory_recycler)
        # Reuse the preallocated embedding-output buffer when its leading
        # dimension matches the current token count, avoiding an allocation.
        parallel_embedding_output = None
        if memory_recycler is not None:
            if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == x.size(0):
                parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER
        return torch.mm(x.view(-1, x.shape[-1]),
                        layer.weight.transpose(1, 0),
                        out=parallel_embedding_output).view(
                            *(x.shape[:-1]), layer.weight.shape[0])

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that compute logits from hidden_stats."""
from typing import Optional
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.platforms import current_platform
def DumpLogits(logits):
    import os
    import time
    VLLM_VACC_DUMP_LOGITS = os.getenv('VLLM_VACC_DUMP_LOGITS')
    if VLLM_VACC_DUMP_LOGITS:
        # Dump raw fp32 logits to <dir>/logit_<timestamp>.bin and record the
        # path in <dir>/summary.txt for later inspection.
        logit_arr = logits.cpu().to(torch.float32).numpy()
        timestamp = time.time()
        if not os.path.exists(VLLM_VACC_DUMP_LOGITS):
            os.makedirs(VLLM_VACC_DUMP_LOGITS)
        logit_path = os.path.join(VLLM_VACC_DUMP_LOGITS,
                                  f'logit_{timestamp}.bin')
        summary_path = os.path.join(VLLM_VACC_DUMP_LOGITS, 'summary.txt')
        with open(summary_path, 'a') as f:
            f.write(f'{logit_path}\n')
        logit_arr.tofile(logit_path)
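
def _read_dumped_logits(path: str, vocab_size: int):
    # Illustrative helper (hypothetical, not used by the layer): the dumps
    # are raw fp32 with no header, so the vocab size must be supplied to
    # recover the 2-D [num_tokens, vocab] shape.
    import numpy as np
    return np.fromfile(path, dtype=np.float32).reshape(-1, vocab_size)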
class LogitsProcessor(nn.Module):
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
        from vllm.distributed.parallel_state import get_tp_group
        total_size_all_gather_input = (hidden_states.size(0)
                                       * lm_head.weight.shape[0]
                                       * hidden_states.element_size()
                                       * get_tp_group().world_size)
        # Fuse matmul and all_gather when the total size of the all_gather
        # inputs is at most 4 MB. (An earlier heuristic from the DSP team
        # fused whenever the hidden size was <= 7168 and divisible by 32.)
        if total_size_all_gather_input <= 4 * 1024 * 1024:
            try:
                from torch_vacc.vacc.custom_ops import fused_matmul_allgather
                logits = fused_matmul_allgather(
                    hidden_states, lm_head.weight.T,
                    get_tp_group().world_size,
                    get_tp_group().rank_in_group,
                    get_tp_group().group_id,
                    get_tp_group().rank_device_infos)
                # On V1 every TP rank keeps the gathered logits; otherwise
                # only TP rank 0 (the PP last-stage driver) does.
                if (hasattr(current_platform, 'supports_v1')
                        and current_platform.supports_v1(current_platform)
                    ) or get_tp_group().rank_in_group == 0:
seq, hidden_dims = hidden_states.shape
logits = logits.movedim(0, 1)
logits = logits.reshape(seq, -1)
logits = logits[..., :self.org_vocab_size]
if get_tp_group().rank_in_group == 0:
DumpLogits(logits)
else:
logits = None
return logits
            except Exception as e:
                print("Fused matmul+all_gather failed, falling back to the "
                      "unfused path:", e)
# Get the logits for the next tokens.
logits = lm_head.quant_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
#print("quant method:", lm_head, lm_head.quant_method, embedding_bias, logits.shape)
# Gather logits for TP
logits = self._gather_logits(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]
return logits

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
import torch
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolerActivation, get_pooling_params
class ClassifierPooler(Pooler):
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]
if pooled_data.dtype != self.head_dtype:
pooled_data = pooled_data.to(self.head_dtype)
if self.classifier is not None:
pooled_data = self.classifier(pooled_data)
# pooled_data shape: [batchsize, num_labels]
if self.logit_bias is not None:
pooled_data -= self.logit_bias
pooling_params = get_pooling_params(pooling_metadata)
flags = [p.activation for p in pooling_params]
if len(set(flags)) == 1:
scores = self.act_fn(pooled_data) if flags[0] else pooled_data
else:
scores = [
self.act_fn(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]
# scores shape: [batchsize, num_labels]
return scores
class PoolerNormalize(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
return torch.vacc.l2_norm(pooled_data, epsilon=1e-12)
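
def _l2_norm_reference(pooled_data: torch.Tensor) -> torch.Tensor:
    # Eager reference for what torch.vacc.l2_norm is assumed to compute;
    # illustrative only.
    import torch.nn.functional as F
    return F.normalize(pooled_data, p=2.0, dim=-1, eps=1e-12)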

View File

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Literal, Type, get_args
QuantizationMethods = Literal[
# "aqlm",
"awq",
"deepspeedfp",
"tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
"modelopt_fp4",
"bitblas",
"gguf",
"gptq_marlin_24",
"gptq_marlin",
"gptq_bitblas",
"awq_marlin",
"gptq",
"compressed-tensors",
"bitsandbytes",
"hqq",
"experts_int8",
"ipex",
"quark",
"moe_wna16",
"torchao",
"auto-round",
"rtn",
"inc",
"mxfp4",
"petit_nvfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

View File

@@ -0,0 +1,615 @@
import functools
import importlib.util
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, expert_weight_is_col_major,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, is_layer_skipped)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.utils import has_deep_gemm
logger = init_logger(__name__)
# has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def Fp8LinearMethod__init(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.quant_config.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
    if self.block_quant:
        self.block_size = self.quant_config.weight_block_size
        # Marlin doesn't support block-wise fp8.
        self.use_marlin = False
self.scale_k = 1
self.scale_n = 1
self.scale_n_prefill = 1 # only for fp8 moe
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)
class Fp8LinearMethod(LinearMethodBase):
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
if self.block_quant:
scale_n = extra_weight_attrs.get("scale_n")
scale_k = extra_weight_attrs.get("scale_k")
if scale_n is not None:
self.scale_n = scale_n
if scale_k is not None:
self.scale_k = scale_k
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None
            block_n, block_k = (
                self.quant_config.weight_block_size[0] // self.scale_n,
                self.quant_config.weight_block_size[1] // self.scale_k,
            )
# Required by row parallel
if (tp_size > 1
and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition
== tp_size) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if not self.block_quant:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
else:
assert self.quant_config.activation_scheme == "dynamic"
scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
# TODO(rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
if current_platform.is_fp8_fnuz():
weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv)
else:
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
            if isinstance(layer, QKVParallelLinear):
                # NOTE: for QKVParallelLinear, the number of weight_scale rows
                # should be divisible by the 8 DSPs, so repeat-interleave the
                # scales until it is (e.g. 3 rows -> repeat 8 -> 24 rows).
                shape = weight_scale_inv.shape[0]
                repeat = 1
                while (shape * repeat) % 8 != 0:
                    repeat *= 2
                weight_scale_inv = torch.repeat_interleave(
                    weight_scale_inv, repeats=repeat, dim=0)
# weight = self._maybe_pad_weight(weight)
# if self.block_quant:
# maybe_post_process_fp8_weight_block(
# layer, self.cutlass_block_fp8_supported)
# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False)
return
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
# Note: lazy import to avoid triton import error.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
if self.block_quant:
assert self.quant_config.weight_block_size is not None
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
                block_size=[
                    layer.weight.shape[0] // layer.weight_scale_inv.shape[0],
                    layer.weight.shape[1] // layer.weight_scale_inv.shape[1],
                ],
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)
# return apply_fp8_linear(
# input=x,
# weight=layer.weight,
# weight_scale=layer.weight_scale,
# input_scale=layer.input_scale,
# bias=bias,
# cutlass_fp8_supported=self.cutlass_fp8_supported,
# # Default to using per_token quantization if cutlass is supported
# use_per_token_if_dynamic=self.cutlass_fp8_supported)
def Fp8MoEMethod_init_(self, quant_config: Fp8Config, layer: torch.nn.Module):
self.layer = layer
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
self.flashinfer_moe_backend = None
self.scale_k = 1
self.scale_n = 1
self.scale_n_prefill = 1
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm and vacc
    if current_platform.is_rocm() or current_platform.is_vacc():
        self.use_marlin = False
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm():
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
else:
logger.warning_once(
"DeepGemm not supported on the current platform.")
# Check for CutlassBlockScaledGroupedGemm support.
self.allow_cutlass_block_scaled_grouped_gemm = False
if not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(100)):
logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
)
self.allow_cutlass_block_scaled_grouped_gemm = True
else:
logger.warning_once(
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8MoEMethod(FusedMoEMethodBase):
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.quant_config.weight_block_size is not None
scale_n = extra_weight_attrs.get("scale_n")
scale_n_prefill = extra_weight_attrs.get("scale_n_prefill")
scale_k = extra_weight_attrs.get("scale_k")
if scale_n is not None:
self.scale_n = scale_n
if scale_k is not None:
self.scale_k = scale_k
if scale_n_prefill is not None:
self.scale_n_prefill = scale_n_prefill
            if self.quant_config is not None and self.quant_config.weight_block_size is not None:
                import math
                self.gcd_value = self.quant_config.weight_block_size[0]
                output_size_no_merge = intermediate_size_per_partition
                # assert isinstance(output_size_no_merge, int), f"merged output size should be an int, value is: {output_size_no_merge}"
                if output_size_no_merge % self.quant_config.weight_block_size[0]:
                    gcd_value = math.gcd(output_size_no_merge % self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[0])
                    self.scale_n = self.scale_n * self.quant_config.weight_block_size[0] // gcd_value
                    self.scale_n_prefill = self.scale_n_prefill * self.quant_config.weight_block_size[0] // gcd_value
                if hidden_size % self.quant_config.weight_block_size[1]:
                    gcd_value = math.gcd(hidden_size % self.quant_config.weight_block_size[1], self.quant_config.weight_block_size[1])
                    self.scale_k = self.scale_k * self.quant_config.weight_block_size[1] // gcd_value
                # self.scale_k = self.scale_n
                # print('output_size_no_merge', output_size_no_merge)
                # Split the work across cores by block size:
                # output_size_no_merge = 384
                #   block_size = 128: 384 = 3x128, can only use 3 cores x 128
                #   block_size = 16:  384 = 24x16 = 8 cores x (3x16), all 8 cores usable
                # output_size_no_merge = 512
                #   block_size = 128: 512 = 4x128, can only use 4 cores x 128
                #   block_size = 64:  512 = 8x64, all 8 cores usable (8 cores x 64)
                # output_size_no_merge = 768
                #   block_size = 128: 768 = 6x128, can only use 6 cores x 128
                #   block_size = 32:  768 = 8x(3x32), all 8 cores usable (8 cores x (3x32))
                core_num = 8
                min_block_size = 4
                block_size_tmp = self.quant_config.weight_block_size[0] // self.scale_n
                if output_size_no_merge > block_size_tmp and \
                        output_size_no_merge % block_size_tmp == 0 and \
                        output_size_no_merge // block_size_tmp < core_num and \
                        output_size_no_merge % core_num == 0:
                    core_num_old = output_size_no_merge // block_size_tmp
                    gcd_value = math.gcd(core_num, core_num_old)
                    new_scale = core_num // gcd_value
                    if block_size_tmp // new_scale >= min_block_size:
                        self.scale_n = new_scale * self.scale_n
                # print("moe scale n is:", self.scale_n, self.scale_k, intermediate_size_per_partition)
tp_size = get_tensor_model_parallel_world_size()
if self.scale_n != self.scale_n_prefill:
block_n_prefill = self.quant_config.weight_block_size[0] // self.scale_n_prefill
block_n, block_k = (
self.quant_config.weight_block_size[0] // self.scale_n,
self.quant_config.weight_block_size[1] // self.scale_k,
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}.")
if (tp_size > 1
and hidden_size % block_k != 0):
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if not self.block_quant:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
else:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) //
block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_n - 1) // block_n,
dtype=torch.float32,
),
requires_grad=False,
)
if self.scale_n != self.scale_n_prefill:
w13_weight_scale_prefill = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n_prefill - 1) //
block_n_prefill),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale_prefill = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_n_prefill - 1) // block_n_prefill,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv_prefill", w13_weight_scale_prefill)
layer.register_parameter("w2_weight_scale_inv_prefill", w2_weight_scale_prefill)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.
value} if self.block_quant else
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if self.scale_n != self.scale_n_prefill:
set_weight_attrs(w13_weight_scale_prefill, extra_weight_attrs)
set_weight_attrs(w2_weight_scale_prefill, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def moe_fp8_apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import FusedMoE
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
try:
from torch_vacc.vacc.custom_ops import fused_experts
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
experts_output = None
if memory_recycler is not None:
# remove MOE_EXPERT_OUT_BUFFER
# experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
decode_with_batch=layer.is_decode and x.shape[0] > 1,
output_opt=experts_output
)
except Exception as e:
print(f"vacc fused_expert run fail, now using unfused ops: {e}")
from torch_vacc.vacc.custom_ops_cpu import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
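# Design note: like the other vacc ops in this plugin, moe_fp8_apply tries the
# fused torch_vacc kernel first and falls back to the CPU reference
# implementation in torch_vacc.vacc.custom_ops_cpu if the fused path raises.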

View File

@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig as GPTQConfigOrig
from vllm.model_executor.layers.quantization.gptq import ExllamaState
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT
import math
def GPTQLinearMethod__init(self, quant_config: GPTQConfigOrig):
self.quant_config = quant_config
self.scale_k = 1
self.split_num = 4
def int32_to_int4(s0, axis = -2):
    # Flatten to shape [1, n] first.
    # Each int32 splits into 8 int4 values, giving [8, n]:
    # x32 (int32) => 32 bits => 8 x 4-bit values x4[0..7]
    # x32 bits 31-28 => x4[7]
    # x32 bits 27-24 => x4[6]
    # ...
    # x32 bits 3-0  => x4[0]
    # x32[index=0] => x4[7,6,5,4,3,2,1,0]
    # Mapping a 4-bit pattern to its real value is NOT two's complement;
    # subtract 8 instead:
    # 1111 => 15 => 15 - 8 = 7
    # 0110 => 6  => 6 - 8  = -2
    # Example 0x6ACB372B (laid out in memory as 2B 37 CB 6A; reading the low
    # nibble of each byte first gives the nibble order B273BCA6) => subtract
    # 8 => int4: 3, -6, -1, -5, 3, 4, 2, -2
    # Memory layout is little-endian:
    # int32 bytes 2B 37 CB 6A => nibbles 2,11,3,7,12,11,6,10 => subtract 8 =>
    # -6,3, -5,-1, 4,3, -2,2 => swapping the two nibbles within each byte
    # gives 3, -6, -1, -5, 3, 4, 2, -2
    # int4: 3, -6, -1, -5, 3, 4, 2, -2
s = s0.view(torch.uint32)
all = []
for i in range(8):
x = 15 << (i*4)
# s2 = torch.bitwise_and(x,s)
s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
s3 = s2 / (2 ** (i*4))
s4 = s3.to(torch.int32)
        # Two's-complement decoding gives wrong results:
        # s4[s4 > 7] = s4[s4 > 7] - 16
        # Subtracting 8 directly is correct; value range: -8..7
s4 = s4 - 8
all.append(s4.reshape(1,*s4.shape))
all = torch.concatenate(all, 0)
if axis == -2 or axis == 0:
# 8,K//8,N => K//8,8,N => K,N
all = all.transpose(-2,0).reshape(-1,all.shape[-1]).contiguous()
else:
# 8,N,K//8 => N,K//8,8 => N,K
all = all.permute(1,2,0).reshape(all.shape[-2],-1).contiguous()
return all
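# A minimal self-check sketch (not part of the original flow; defined here for
# illustration and never called): it decodes the 0x6ACB372B example from the
# comments above and verifies the documented int4 sequence.
def _check_int32_to_int4_example():
    packed = torch.tensor([[0x6ACB372B]], dtype=torch.int32)
    decoded = int32_to_int4(packed, axis=-2).flatten().tolist()
    assert decoded == [3, -6, -1, -5, 3, 4, 2, -2], decoded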
def dequant_weight(qw, scales, group_size = 128):
N = qw.shape[1]
int4_to_int32_axis = -2
if TRANSPOSE_GPTQ_WEIGHT:
N = qw.shape[0]
int4_to_int32_axis = -1
    qweight = int32_to_int4(qw, int4_to_int32_axis).to(torch.float16)  # int32 => 8 int4 => fp16
if TRANSPOSE_GPTQ_WEIGHT:
scales = scales.T.contiguous()
qweight = qweight.T.contiguous()
    scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)  # expand scales by group_size: every group_size values share one scale
# print('qweight', qweight.shape, qweight.dtype)
# print('scale', scales.shape, scales.dtype)
    dequantized = qweight * scales  # dequantize
    return dequantized
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
class GPTQLinearMethod(LinearMethodBase):
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
# if input_size_per_partition % self.quant_config.group_size != 0:
# raise ValueError(
# "The input size is not aligned with the quantized "
# "weight shape. This can be caused by too large "
# "tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
# For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED
else:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
layer.exllama_state = exllama_state
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
# self.quant_config.weight_bits == 4
if TRANSPOSE_GPTQ_WEIGHT:
layer.qzeros = Parameter(layer.qzeros.data.T.contiguous(), requires_grad=False)
layer.qweight = Parameter(layer.qweight.data.T.contiguous(), requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data.T.contiguous(), requires_grad=False)
else:
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], ) # M,N
reshaped_x = x.reshape(-1, x.shape[-1])
# print(f"~~~~ start dequant")
# import time
# start_quant_time = time.time()
# weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(), self.quant_config.group_size // self.scale_k).to(layer.qweight.device)
# end_quant_time = time.time()
# print(f"~~~~ dequant time: {end_quant_time - start_quant_time}")
# if torch.distributed.get_rank() == 0:
# print(f"~~~~ weight shape: {weight.shape}, dtype: {weight.dtype}")
# output = torch.matmul(reshaped_x, weight)
# print("entering GPTQLinearMethod apply, reshaped_x shape:", reshaped_x.shape, "reshaped_x stride", reshaped_x.stride(), "input_tensor", x.shape, "qweight shape:", layer.qweight.shape, "scales shape:", layer.scales.shape)
output = torch.vacc.w4a8_block_int4_matmul(
reshaped_x,
layer.qweight.transpose(-1, -2),
layer.scales.transpose(-1, -2),
[1, self.quant_config.group_size // self.scale_k],
)
# print("exiting GPTQLinearMethod apply, output shape:", output.shape)
# end_gemm_time = time.time()
# if torch.distributed.get_rank() == 0:
# print(f"~~~~ gemm time: {end_gemm_time - end_quant_time}")
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)

View File

@@ -0,0 +1,372 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
# [num_experts, N, K//8], int32 ==> [num_experts, N, K], int4 ==> [num_experts, N//8, K], int32
def repack_quant_moe_weight_old(original_packed_tensor):
num_experts = original_packed_tensor.shape[0]
N = original_packed_tensor.shape[1]
K = original_packed_tensor.shape[2] * 8
if original_packed_tensor.dtype != torch.int32:
raise ValueError("data type of input tensor should be int32")
if N % 8 != 0:
raise ValueError("N of input tensor should be divisible by 8")
    # --- 1. Unpack: expand the int32 tensor into a logical int4 tensor ---
    # Create a temporary tensor to hold all unpacked int4 values; torch.uint8
    # serves as temporary storage since PyTorch has no native int4 dtype.
    unpacked_int4_tensor = torch.zeros(
        (num_experts, N, K),
        dtype=torch.uint8,
        device=original_packed_tensor.device
    )
    mask = 0b1111
    for i in range(8):
        # Extract the i-th int4 of each int32 block: shifting right by
        # (i * 4) bits moves the i-th 4-bit integer to the least significant
        # position, and the bitwise AND with the mask keeps just those 4 bits.
        extracted_int4s = (original_packed_tensor >> (i * 4)) & mask
        # Place the extracted int4 values at the right positions: the slice
        # `i::8` means "start at index i and fill every 8th position".
        unpacked_int4_tensor[:, :, i::8] = extracted_int4s
    # --- 2. Repack: pack the logical int4 tensor into a new int32 tensor ---
    new_packed_tensor = torch.zeros(
        (num_experts, N//8, K),
        dtype=torch.int32,
        device=original_packed_tensor.device
    )
    for i in range(8):
        # Take the next int4 sequence to pack, using the slice `i::8` along
        # the N dimension.
        current_int4_segment = unpacked_int4_tensor[:, i::8, :]
        # Convert the int4 sequence to int32 (it is packed into int32),
        # shift it to its position within the new int32 block, and merge it
        # into new_packed_tensor with bitwise OR.
        new_packed_tensor |= (current_int4_segment.to(torch.int32) << (i * 4))
    return new_packed_tensor
def repack_quant_moe_weight(original_packed_tensor):
if original_packed_tensor.dtype != torch.int32:
raise ValueError("data type of input tensor should be int32")
num_experts, N, K_packed = original_packed_tensor.shape
K = K_packed * 8
if N % 8 != 0:
raise ValueError("N of input tensor should be divisible by 8")
new_packed_tensor = torch.zeros((num_experts, N // 8, K),
dtype=torch.int32,
device=original_packed_tensor.device)
for i in range(8):
source_slice = original_packed_tensor[:, i::8, :]
for j in range(8):
unpacked_strip = (source_slice >> (j * 4)) & 0b1111
new_packed_tensor[:, :, j::8] |= (unpacked_strip.to(torch.int32) << (i * 4))
return new_packed_tensor
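# A minimal equivalence-check sketch (not part of the original flow; defined
# here for illustration and never called): the vectorized repack above should
# agree with the reference implementation repack_quant_moe_weight_old.
def _check_repack_quant_moe_weight():
    t = torch.randint(torch.iinfo(torch.int32).min,
                      torch.iinfo(torch.int32).max,
                      (2, 16, 4), dtype=torch.int32)
    assert torch.equal(repack_quant_moe_weight(t),
                       repack_quant_moe_weight_old(t))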
class MoeWNA16Method(FusedMoEMethodBase):
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
self.moe = layer
layer.quant_config = self.quant_config
bit8_pack_factor = self.quant_config.bit8_pack_factor
bit32_pack_factor = 32 // self.quant_config.weight_bits
group_size = self.quant_config.group_size
group_size_div_factor = 1
group_size_w13 = self.quant_config.group_size
group_size_div_factor_w13 = 1
group_size_w2 = self.quant_config.group_size
group_size_div_factor_w2 = 1
# make intermediate_size and hidden_size divisible by group_size
# we reduce the group size to ensure that
# and we would repeat the loaded_weight later
while intermediate_size_per_partition % group_size or \
hidden_size % group_size:
group_size = group_size // 2
group_size_div_factor *= 2
assert group_size >= 32
layer.group_size = group_size
layer.group_size_div_factor = group_size_div_factor
while intermediate_size_per_partition % group_size_w2:
group_size_w2 = group_size_w2 // 2
group_size_div_factor_w2 *= 2
assert group_size_w2 >= 32
layer.w2_block_size = group_size_w2
layer.group_size_div_factor_w2 = group_size_div_factor_w2
while hidden_size % group_size_w13:
group_size_w13 = group_size_w13 // 2
group_size_div_factor_w13 *= 2
assert group_size_w13 >= 32
layer.w13_block_size = group_size_w13
layer.group_size_div_factor_w13 = group_size_div_factor_w13
strategy = FusedMoeWeightScaleSupported.GROUP.value
extra_weight_attrs.update({
"quant_method": strategy,
"is_transposed": False
})
assert 'weight_loader' in extra_weight_attrs
weight_loader = extra_weight_attrs['weight_loader']
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
layer, weight_loader)
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
# w13_qweight = torch.nn.Parameter(torch.empty(
# num_experts,
# 2 * intermediate_size_per_partition,
# hidden_size // bit8_pack_factor,
# dtype=torch.uint8),
# requires_grad=False)
w13_qweight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // bit32_pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
# w2_qweight = torch.nn.Parameter(torch.empty(
# num_experts,
# hidden_size,
# intermediate_size_per_partition // bit8_pack_factor,
# dtype=torch.uint8),
# requires_grad=False)
w2_qweight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // bit32_pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
w13_scales = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // group_size_w13,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = torch.nn.Parameter(torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // group_size_w2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
if self.quant_config.has_zp:
w13_qzeros = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition // bit8_pack_factor,
hidden_size // group_size,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = torch.nn.Parameter(torch.zeros(
num_experts,
hidden_size // bit8_pack_factor,
intermediate_size_per_partition // group_size,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
if self.quant_config.linear_quant_method == "gptq":
# some param are unused, but we need to init them in order to
# load weights
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
if not self.quant_config.has_zp:
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
for key in invalid_param_keys:
param = torch.nn.Parameter(torch.empty((0, ),
dtype=torch.int32),
requires_grad=False)
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)
@staticmethod
def get_weight_loader(layer, weight_loader):
def convert_awq_tensor(tensor, tensor_type):
# convert awq qweight/qzeros to a standard format (assume int4)
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
# qzeros: (k // group_size, n // pack_factor_bit32) ->
# (n // pack_factor_bit8, k // group_size)
# pack_factor_bit32 = 32 // weight_bits
# pack_factor_bit8 = 8 // weight_bits
# 0. suppose origin shape (a, b), dtype int32
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
size0 = tensor.size(0)
tensor = tensor.view(torch.uint8)
# 2. unpack to uint4 (only when weight_bits == 4)
# shape (a, 4 * b) -> (a, 4 * b, 2)
shifter = torch.tensor([0, 4],
dtype=torch.uint8,
device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
# 3. change order, see
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
# shape -> (a, 4 * b * pack_factor_bit8)
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
tensor = tensor.view(size0, -1)
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
tensor = tensor.T.contiguous()
# 5. repack (only when weight_bits == 4)
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
# qzeros shape -> (4 * b, a)
if tensor_type == "qweight":
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
elif tensor_type == "qzeros":
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
return tensor
def convert_gptq_int4_qzeros(tensor):
tensor = tensor.view(torch.uint8)
shifter = torch.tensor([0, 4],
dtype=torch.uint8,
device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
tensor = tensor + 1
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
return tensor
def moe_wna16_weight_loader(param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
return_success: bool = False):
if "g_idx" in weight_name:
return False if return_success else None
if not layer.quant_config.has_zp and "qzeros" in weight_name:
return False if return_success else None
device = get_tp_group().device
tp_rank = get_tensor_model_parallel_rank()
loaded_weight = loaded_weight.to(device)
shard_size = layer.intermediate_size_per_partition
# convert gptq and awq weight to a standard format
if layer.quant_config.linear_quant_method == "awq":
assert layer.quant_config.weight_bits == 4
if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight,
"qweight")
elif "zeros" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
else:
loaded_weight = loaded_weight.T
elif layer.quant_config.linear_quant_method == "gptq":
assert layer.quant_config.weight_bits in [4, 8]
if "weight" in weight_name:
# loaded_weight = loaded_weight.T.contiguous().view(
# torch.uint8)
loaded_weight = loaded_weight.T.contiguous()
elif "zeros" in weight_name:
# add 1 to gptq qzeros to align with awq
loaded_weight = loaded_weight.view(torch.uint8)
if layer.quant_config.weight_bits == 4:
loaded_weight = convert_gptq_int4_qzeros(
loaded_weight).T
else:
loaded_weight = loaded_weight.T + 1
else:
# loaded_weight = loaded_weight.T
loaded_weight = loaded_weight.T.contiguous()
            # repeat the qzeros/scales to fit the new group size
            if layer.group_size_div_factor_w13 > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name) and \
                    shard_id in ("w1", "w3"):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor_w13, 1)
            elif layer.group_size_div_factor_w2 > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name) and \
                    shard_id == "w2":
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor_w2, 1)
            elif layer.group_size_div_factor > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1)
if "w13_qzeros" in weight_name:
tensor = loaded_weight.view(layer.tp_size, -1,
loaded_weight.size(1))[tp_rank]
if shard_id == "w1":
param.data[expert_id, :shard_size // 2] = tensor
else:
param.data[expert_id, shard_size // 2:] = tensor
return True if return_success else None
elif "w2_qzeros" in weight_name:
param.data[expert_id] = loaded_weight.view(
loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
return True if return_success else None
else:
# Delegate to the original loader, passing return_success
return weight_loader(param,
loaded_weight,
weight_name,
shard_id,
expert_id,
return_success=return_success)
return moe_wna16_weight_loader
def process_weights_after_loading(self, layer):
dev_w2 = layer.w2_qweight.device
# torch.Size([128, 2048, 24]), torch.int32, strides: (49152, 24, 1)
# ======>
# torch.Size([128, 256, 192]), torch.int32, strides: (49152, 1, 256)
        layer.w2_qweight = torch.nn.Parameter(
            repack_quant_moe_weight(layer.w2_qweight.cpu())
            .transpose(-1, -2).contiguous().transpose(-1, -2)
            .to(device=dev_w2),
            requires_grad=False)
# torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 3, 1)
# ======>
# torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 1, 2048)
        layer.w2_scales = torch.nn.Parameter(
            layer.w2_scales.transpose(-1, -2).contiguous().transpose(-1, -2),
            requires_grad=False)
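        # NOTE: transpose(-1, -2).contiguous().transpose(-1, -2) keeps the
        # logical shape but stores the data transposed in memory (see the
        # stride comments above), which is presumably the layout the vacc
        # kernels expect.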

View File

@@ -0,0 +1,39 @@
import torch
from typing import List, Optional
def _apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = True,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
    assert len(block_size) == 2, "only 2-D block sizes are supported for now"
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
try:
from torch_vacc.vacc.custom_ops import w8a8_block_fp8_linear
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
mla_oproj_output = None
if memory_recycler is not None:
os1, os2 = memory_recycler.MLA_OPROJ_OUT_BUFFER.shape
if os1 == input_2d.size(0) and os2 == weight.size(0):
mla_oproj_output = memory_recycler.MLA_OPROJ_OUT_BUFFER
        output = w8a8_block_fp8_linear(input_2d, weight, input_scale, weight_scale, block_size, output=mla_oproj_output)
except Exception as e:
print("vacc fuse fp8 matmul run fail:", e, " , now use unfused ops")
from torch_vacc.vacc.custom_ops_cpu import w8a8_block_fp8_linear
output = w8a8_block_fp8_linear(input_2d, weight, input_scale, weight_scale, block_size)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
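# A pure-PyTorch reference sketch (an assumption for illustration, not the
# torch_vacc API): block-wise W8A8 FP8 linear can be emulated by dequantizing
# the FP8 weight with one scale per (block_n, block_k) tile and running a
# plain matmul. It ignores the activation-side quantization, so it is a
# numerical reference only, not a performance path.
def _reference_block_fp8_linear(
    input_2d: torch.Tensor,      # [M, K], float16/bfloat16
    weight: torch.Tensor,        # [N, K], torch.float8_e4m3fn
    weight_scale: torch.Tensor,  # [ceil(N/block_n), ceil(K/block_k)], float32
    block_size: List[int],
) -> torch.Tensor:
    block_n, block_k = block_size
    # Expand each per-block scale to cover its (block_n, block_k) tile,
    # then trim to the exact weight shape.
    scale = weight_scale.repeat_interleave(block_n, dim=0)
    scale = scale.repeat_interleave(block_k, dim=1)
    scale = scale[:weight.shape[0], :weight.shape[1]]
    w = weight.to(torch.float32) * scale
    return (input_2d.to(torch.float32) @ w.t()).to(input_2d.dtype)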

View File

@@ -0,0 +1,9 @@
from .rotary_embedding import (
RotaryEmbedding_init_vacc,
RotaryEmbedding_forward_vacc,
ScalingRotaryEmbedding_forward_vacc,
_compute_inv_freq_vacc,
_deepseek_compute_cos_sin_cache_vacc,
_yarn_compute_cos_sin_cache_vacc,
_compute_cos_sin_cache_vacc
)

View File

@@ -0,0 +1,101 @@
import itertools
from typing import Optional, Union
import numpy as np
import torch
from transformers import PretrainedConfig
class MRotaryEmbedding:
@classmethod
def _qwen3vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
for _ in range(t)]
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
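        # Illustrative example (hypothetical sizes): one image with grid
        # (t=1, h=4, w=4) and spatial_merge_size=2 becomes a 2x2 merged patch
        # grid, i.e. 4 vision tokens with (t, h, w) positions (0,0,0),
        # (0,0,1), (0,1,0), (0,1,1), each offset by the preceding text length.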

View File

@@ -0,0 +1,203 @@
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
# from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb
# from vllm.model_executor.layers.rotary_embedding import _yarn_find_correction_range, _yarn_linear_ramp_mask
from vllm.model_executor.layers.rotary_embedding.common import yarn_find_correction_range as _yarn_find_correction_range
from vllm.model_executor.layers.rotary_embedding.common import yarn_linear_ramp_mask as _yarn_linear_ramp_mask
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from ...ops.mrope_op import get_sin_cos_mrope
def RotaryEmbedding_init_vacc(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super(CustomOp, self).__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
# cache = self._compute_cos_sin_cache()
cos, sin = self._compute_cos_sin_cache()
cos = cos.to(dtype)
sin = sin.to(dtype)
self.register_buffer("cos_cache", cos, persistent=False)
self.register_buffer("sin_cache", sin, persistent=False)
# cache = cache.to(dtype)
# self.cos_sin_cache: torch.Tensor
# self.register_buffer("cos_sin_cache", cache, persistent=False)
def RotaryEmbedding_forward_vacc(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-vacc implementation of forward()."""
if offsets is not None:
positions = positions + offsets
num_tokens = positions.numel()
# positions = positions.flatten()
# num_tokens = positions.shape[0]
# cos_sin = self.cos_sin_cache.index_select(0, positions)
# cos_sin = self.cos_sin_cache[positions]
# cos, sin = cos_sin.chunk(2, dim=-1)
if isinstance(self, MRotaryEmbedding):
# get mrope sin/cos
cos, sin = get_sin_cos_mrope(self, positions)
num_tokens = num_tokens//3
else:
positions = positions.flatten()
cos = self.cos_cache[positions]
sin = self.sin_cache[positions]
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
mode = "neox"
if not self.is_neox_style:
mode = "gptj"
    query_rot, key_rot = torch.vacc.RotaryPosEmbedding(query_rot, key_rot, cos, sin, 0, mode)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
return query, key
def ScalingRotaryEmbedding_forward_vacc(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
# self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
# positions.device)
# cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
# if offsets is not None else positions]
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
# cos_sin = self.cos_sin_cache[positions]
# cos, sin = cos_sin.chunk(2, dim=-1)
cos = self.cos_cache[positions]
sin = self.sin_cache[positions]
# TODO: to be removed (require odsp support)
# if self.is_neox_style:
# # NOTE(woosuk): Here we assume that the positions tensor has the
# # shape [batch_size, seq_len].
# cos = cos.repeat(1, 1, 2).unsqueeze(-2)
# sin = sin.repeat(1, 1, 2).unsqueeze(-2)
# else:
# cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
# sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
# rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
mode = "neox" if self.is_neox_style else "gptj"
# query_rot = query_rot * cos + rotate_fn(query_rot) * sin
# key_rot = key_rot * cos + rotate_fn(key_rot) * sin
    query_rot, key_rot = torch.vacc.RotaryPosEmbedding(query_rot, key_rot, cos, sin, 0, mode)
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key
def _compute_inv_freq_vacc(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device=current_platform.device_type) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _deepseek_compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device=current_platform.device_type,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
    # NOTE: for ODSP friendliness, separate cos/sin caches guarantee that
    # cos/sin always have a contiguous layout along dim[-1].
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
return cos, sin
# cache = torch.cat((cos, sin), dim=-1)
# return cache
def _yarn_compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device=current_platform.device_type,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
return cos, sin
# cache = torch.cat((cos, sin), dim=-1)
# return cache
def _compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings,
device=current_platform.device_type,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
    # NOTE: for ODSP friendliness, separate cos/sin caches guarantee that
    # cos/sin always have a contiguous layout along dim[-1].
cos = freqs.cos()
sin = freqs.sin()
return cos, sin
# cache = torch.cat((cos, sin), dim=-1)
# return cache
# import vllm.model_executor.layers.rotary_embedding as rotary_embedding
# rotary_embedding.RotaryEmbedding.forward_vacc=RotaryEmbedding_forward_vacc
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_inv_freq=_compute_inv_freq_vacc
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_cos_sin_cache=_compute_cos_sin_cache_vacc

View File

@@ -0,0 +1,542 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union
import msgspec
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.model_executor.layers.sampler import (SamplerOutput,
_apply_min_tokens_penalty,
_apply_top_k_top_p,
_apply_min_p,
_sample,
SampleResultArgsType,
get_logprobs,
_build_sampler_output,
SampleReturnType,
SampleResultsDictType,
SampleMetadataType,
MultinomialSamplesType,
_modify_greedy_probs_inplace,
_top_k_top_p_multinomial_with_flashinfer,
_multinomial,
get_pythonized_sample_results,
)
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER as use_ds3_sampler
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER_OP as use_ds3_sampler_op
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
# yapf: enable
else:
flashinfer_top_k_top_p_sampling = None
class SamplerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This data structure implements methods, so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs: List[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional[torch.Tensor] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# Holds either (1) the pythonized sampler result (single-step scheduling)
# or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
deferred_sample_results_args: Optional[SampleResultArgsType] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncLLMEngine to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None
# On-device tensor containing the sampled token embeddings (embeddings
# corresponding to the sampled token ids). Used when prompt embeddings are
# specified in lieu of prompt token ids or text.
sampled_token_embeds: Optional[torch.Tensor] = None
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states: Optional[torch.Tensor] = None
# Time taken in the forward pass for this across all workers
model_forward_time: Optional[float] = None
# Time taken in the model execute function. This will include model forward,
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None
def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
return self.outputs[idx]
def __setitem__(self, idx: int, value):
self.outputs[idx] = value
def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
return iter(self.outputs)
def __len__(self):
return len(self.outputs)
def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs
def __repr__(self) -> str:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
else self.sampled_token_probs.shape)
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
self.sampled_token_ids.shape)
return (
f"SamplerOutput(outputs={self.outputs}, "
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr},")
def Sampler_forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
# print(f'Sampler_forward all_greedy={all_greedy}')
# Prepare sampling tensors with pinned memory to avoid blocking.
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)
assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p
is_greedy = (len(sampling_metadata.categorized_sample_indices[SamplingType.GREEDY]) == logits.shape[0])
is_random = (len(sampling_metadata.categorized_sample_indices[SamplingType.RANDOM]) == logits.shape[0])
is_random_seed = (len(sampling_metadata.categorized_sample_indices[SamplingType.RANDOM_SEED]) == logits.shape[0])
max_n_in_batch = sampling_metadata.seq_groups[0].sampling_params.n
generator = sampling_metadata.seq_groups[0].generator
min_tokens = sampling_metadata.seq_groups[0].sampling_params.min_tokens
# print("use_ds3_sampler ", use_ds3_sampler)
if use_ds3_sampler and (is_greedy or (
(is_random or is_random_seed)
and not do_penalties
and flashinfer_top_k_top_p_sampling is None
and min_tokens <= 0
and not do_min_p
and max_n_in_batch == 1
# and not self._should_modify_greedy_probs_inplace
# and not self.include_gpu_probs_tensor
)):
sampling_type = SamplingType.GREEDY
sample_metadata: SampleMetadataType = {}
multinomial_samples: MultinomialSamplesType = {}
greedy_samples: Optional[torch.Tensor] = None
multinomial_out: Optional[torch.Tensor] = None
vacc_device = logits.device
# Create output tensor for sampled token ids.
if self.include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.full((logits.shape[0], 1),
VLLM_INVALID_TOKEN_ID,
dtype=torch.long,
device=vacc_device)
probs_out = torch.empty_like(logits)
logprobs_out = torch.empty_like(logits)
else:
probs_out = None
logprobs_out = None
sampled_token_ids_tensor = None
if is_greedy:
greedy_samples, _ = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 0)
sampling_type = SamplingType.GREEDY
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor = greedy_samples.unsqueeze(-1).to(torch.long)
if probs_out is not None:
# probs_out = torch.softmax(logits.to(torch.float), dim=-1, dtype=torch.float).to(logits)
probs_out = torch.softmax(logits, dim=-1)
if self._should_modify_greedy_probs_inplace:
sample_indices = (sampling_metadata.categorized_sample_indices[SamplingType.GREEDY]).long()
probs_out[sample_indices, :] = 0
probs_out[sample_indices, greedy_samples] = 1.0
elif is_random and do_top_p_top_k:
if use_ds3_sampler_op:
logits = logits.to(torch.float)
multinomial_out, probs_out = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 2)
multinomial_out = multinomial_out.view(-1, max_n_in_batch)
else:
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))
logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps, sampling_tensors.top_ks)
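# torch.vacc.topk_topp is assumed to behave like _apply_top_k_top_p_vacc
# below: logits outside each row's top-k/top-p nucleus are masked to -inf.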
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_out = probs
# multinomial_out = torch.multinomial(probs, 1)
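# Exponential-race sampling: for i.i.d. q_i ~ Exp(1), argmax_i(p_i / q_i)
# is distributed as Categorical(p) (the Gumbel-max trick in disguise), so
# dividing the probabilities by exponential noise and taking argmax is
# equivalent to torch.multinomial(probs, 1) without the extra kernel.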
q = torch.empty_like(probs)
q.exponential_()
multinomial_out = probs.div_(q).argmax(dim=1).view(-1, max_n_in_batch)
sampling_type = SamplingType.RANDOM
elif is_random_seed and generator is not None and do_top_p_top_k:
if use_ds3_sampler_op:
# print("is_random_seed ", is_random_seed)
logits = logits.to(torch.float)
multinomial_out, probs_out = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 1, generator)
multinomial_out = multinomial_out.view(-1, max_n_in_batch)
else:
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))
logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps, sampling_tensors.top_ks).to(torch.float)
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_out = probs
# torch.manual_seed(sampling_metadata.seq_groups[0].sampling_params.seed)
# multinomial_out = torch.multinomial(probs, 1)
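# Same exponential-race trick as above, but driven by the per-request generator.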
q = torch.empty_like(probs)
q.exponential_(generator=generator)
multinomial_out = probs.div_(q).argmax(dim=1).view(-1, max_n_in_batch)
sampling_type = SamplingType.RANDOM_SEED
multinomial_samples[sampling_type] = multinomial_out
if sampled_token_ids_tensor is not None:
if sampling_type != SamplingType.GREEDY:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor = multinomial_samples[sampling_type].to(torch.long)
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
t: []
for t in SamplingType
}
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
sample_results_dict: SampleResultsDictType = {}
maybe_deferred_args = SampleResultArgsType(
sampling_metadata=sampling_metadata,
sample_metadata=sample_metadata,
multinomial_samples=multinomial_samples,
greedy_samples=greedy_samples,
# beam_search_logprobs=None,
sample_results_dict=sample_results_dict)
if not sampling_metadata.skip_sampler_cpu_output:
# GPU<->CPU sync happens here.
# This also converts the sampler output to a Python object.
# Return Pythonized sampler result & sampled token ids
maybe_deferred_sample_results = get_pythonized_sample_results(
maybe_deferred_args)
maybe_sampled_tokens_tensor = sampled_token_ids_tensor
else:
# Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = (
maybe_deferred_args,
sampled_token_ids_tensor,
)
if self.include_gpu_probs_tensor:
on_device_tensors = (probs_out, logprobs_out, maybe_sampled_tokens_tensor)
else:
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
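# NOTE: the ds3 fast path skips log_softmax and passes raw logits as logprobs.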
logprobs = logits
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Apply presence and frequency penalties.
# if do_penalties:
# logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
# sampling_tensors.output_tokens,
# sampling_tensors.presence_penalties.to(logits.device),
# sampling_tensors.frequency_penalties.to(logits.device),
# sampling_tensors.repetition_penalties.to(logits.device))
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps.to(logits.device),
sampling_tensors.top_ks.to(logits.device))
if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
def rejection_forward(
self,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
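# Standard speculative-decoding rejection sampling accepts draft token t with
# probability min(1, p_target(t) / p_draft(t)) and resamples rejections from
# the normalized residual max(0, p_target - p_draft). The torch.vacc op below
# is assumed to implement that contract; the trailing integer appears to
# select the unseeded (1) vs. seeded (0) path.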
if seeded_seqs is None:
out, index = torch.vacc.rejection_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, 1)
else:
out, index = torch.vacc.rejection_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, 0, seeded_seqs[0])
return out
class Sampler(nn.Module):
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
_, vocab_size = logits.shape
# Prepare sampling tensors with pinned memory to avoid blocking.
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)
assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Apply presence and frequency penalties.
if do_penalties:
logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
sampling_tensors.output_tokens,
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
# print("tempratures is:", temperatures)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1).to(logits.device))
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
def _apply_top_k_top_p_vacc(
logits: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1).to(probs_sum.device)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = torch.empty_like(logits_sort).scatter_(dim=-1,
index=logits_idx,
src=logits_sort)
return logits
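# A minimal sketch of exercising this filter in isolation (values are
# illustrative only):
#   logits = torch.randn(2, 32000)
#   p = torch.tensor([0.9, 1.0]); k = torch.tensor([50, 32000])
#   probs = torch.softmax(_apply_top_k_top_p_vacc(logits, p, k), dim=-1)
# Entries outside each row's top-k/top-p nucleus come back as -inf logits
# and thus zero probability; p=1.0 with k=vocab_size leaves a row untouched.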

View File

@@ -0,0 +1,69 @@
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.platforms import current_platform
from typing import Tuple
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
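# Worked example (hypothetical 2-way TP shard with no padding or added vocab,
# not taken from this repo): with an original vocab of 1000 split across two
# ranks, rank 0 owns ids [0, 500). For input_ = tensor([3, 700]) on rank 0:
#   org_vocab_mask -> [True, False]   (700 lives on another rank)
#   masked_input   -> [3, 0]          (out-of-shard ids collapse to 0)
#   returned mask  -> [False, True]   (used later to zero those embedding rows)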
def VocabParallelEmbedding_forward(self, input_):
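# Try the fused vacc embedding kernel first; any failure falls back to the
# masked-gather reference implementation below.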
try:
if self.tp_size > 1:
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
parallel_embedding_output = None
if memory_recycler is not None:
if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == input_.size(0):
parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER.to(self.weight.dtype)
output_parallel = torch.vacc.parallel_embedding(
input_,
self.weight,
self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index,
output = parallel_embedding_output
)
else:
raise ValueError("not support non-tp")
except:
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
#TODO: fuse all_reduce
return tensor_model_parallel_all_reduce(output_parallel)

View File

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from .vars import *
class BertLayer(nn.Module):
def forward(self, hidden_states: torch.Tensor):
if USE_FUSED_BERT_ATTENTION:
tp_group = get_tp_group()
world_size = tp_group.world_size
rank = tp_group.rank_in_group
total_bytes = hidden_states.numel() * hidden_states.element_size() * world_size
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
if isinstance(attn_metadata_all, dict):
attn_metadata = next(iter(attn_metadata_all.values()))
else:
attn_metadata = attn_metadata_all
# For (matmul + bias_add) layers under TP followed by an all_reduce, only
# rank 0 is given the bias: the all_reduce sums the partial outputs, so a
# bias added on every rank would be counted world_size times.
# In the BERT layer, the BertSelfOutput and BertOutput modules have this
# structure; the corresponding bias arguments are self_bias and output_bias below.
if total_bytes < 4194304 or world_size == 1:
# 1. Under TP, when the all_reduce input is smaller than 4MB the fused op
# below performs the DSP all_reduce internally; at 4MB or larger, a
# hardware limit requires calling the VCCL all_reduce outside the op.
# 2. The non-TP case also goes through the fused op below.
output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
qkv_weight=self.attention.self.qkv_proj.weight,
qkv_bias=self.attention.self.qkv_proj.bias,
self_weight=self.attention.output.dense.weight,
self_bias=self.attention.output.dense.bias if rank == 0 else torch.Tensor(),
self_norm_weight=self.attention.output.LayerNorm.weight,
self_norm_bias=self.attention.output.LayerNorm.bias,
intermediate_weight=self.intermediate.dense.weight,
intermediate_bias=self.intermediate.dense.bias,
output_weight=self.output.dense.weight,
output_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
output_norm_weight=self.output.LayerNorm.weight,
output_norm_bias=self.output.LayerNorm.bias,
dense_out=torch.Tensor(),
seqs=attn_metadata.seq_lens,
vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.FullStage,
sm_scale=self.attention.self.scaling,
num_q_heads=self.attention.self.num_heads * world_size,
num_kv_heads=self.attention.self.num_kv_heads * world_size,
flash_attention=False,
reduce_result=world_size > 1,
world_size=world_size,
rank=rank,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos)
else:
attn_out_stage_output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
qkv_weight=self.attention.self.qkv_proj.weight,
qkv_bias=self.attention.self.qkv_proj.bias,
self_weight=self.attention.output.dense.weight,
self_bias=self.attention.output.dense.bias if rank == 0 else torch.Tensor(),
self_norm_weight=torch.Tensor(),
self_norm_bias=torch.Tensor(),
intermediate_weight=torch.Tensor(),
intermediate_bias=torch.Tensor(),
output_weight=torch.Tensor(),
output_bias=torch.Tensor(),
output_norm_weight=torch.Tensor(),
output_norm_bias=torch.Tensor(),
dense_out=torch.Tensor(),
seqs=attn_metadata.seq_lens,
vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.AttnOutStage,
sm_scale=self.attention.self.scaling,
num_q_heads=self.attention.self.num_heads * world_size,
num_kv_heads=self.attention.self.num_kv_heads * world_size,
flash_attention=False,
reduce_result=False,
world_size=world_size,
rank=rank,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos)
if world_size > 1:
attn_out_stage_output = tensor_model_parallel_all_reduce(attn_out_stage_output)
if USE_FUSED_MLP_VISION:
attn_output = self.attention.output.LayerNorm(attn_out_stage_output + hidden_states)
inter_out_stage_output = torch.vacc.fuse_mlp_vision(src=attn_output,
weights_13=self.intermediate.dense.weight,
weights_2=self.output.dense.weight,
weights_13_bias=self.intermediate.dense.bias,
weights_2_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
act_type=0 # gelu
)
else:
inter_out_stage_output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
qkv_weight=torch.Tensor(),
qkv_bias=torch.Tensor(),
self_weight=torch.Tensor(),
self_bias=torch.Tensor(),
self_norm_weight=self.attention.output.LayerNorm.weight,
self_norm_bias=self.attention.output.LayerNorm.bias,
intermediate_weight=self.intermediate.dense.weight,
intermediate_bias=self.intermediate.dense.bias,
output_weight=self.output.dense.weight,
output_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
output_norm_weight=torch.Tensor(),
output_norm_bias=torch.Tensor(),
dense_out=attn_out_stage_output,
seqs=attn_metadata.seq_lens,
vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.InterOutStage,
sm_scale=self.attention.self.scaling,
num_q_heads=self.attention.self.num_heads * world_size,
num_kv_heads=self.attention.self.num_kv_heads * world_size,
flash_attention=False,
reduce_result=False,
world_size=world_size,
rank=rank,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos)
if world_size > 1:
inter_out_stage_output = tensor_model_parallel_all_reduce(inter_out_stage_output)
if USE_FUSED_MLP_VISION:
output = self.output.LayerNorm(inter_out_stage_output + attn_output)
else:
output = self.output.LayerNorm(inter_out_stage_output + attn_out_stage_output)
else:
attn_output = self.attention(hidden_states)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output

View File

@@ -0,0 +1,292 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple, Optional
import torch
import torch.nn as nn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.forward_context import ForwardContext, get_forward_context
from .vars import *
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictorLayer as DeepSeekMultiTokenPredictorLayerOrig
from vllm.distributed import get_tp_group
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
def DeepSeekMultiTokenPredictorLayer__init__(self, vllm_config: VllmConfig, prefix: str) -> None:
super(DeepSeekMultiTokenPredictorLayerOrig, self).__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if USE_PARALLEL_MTP_EH_PROJ:
self.eh_proj = RowParallelLinear(config.hidden_size * 2,
config.hidden_size,
bias=False,
return_bias=False)
else:
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
from vllm.model_executor.models.deepseek_mtp import SharedHead
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer)
class DeepSeekMultiTokenPredictorLayer(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
if not hasattr(self, "weight_capture"):
from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekMTPWegitCapture
self.weight_capture = DeepseekMTPWegitCapture(self.mtp_block)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
if inputs_embeds.shape[0] > 256:
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
deepseek_mtp_layer_input_buffer = None
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
deepseek_mtp_layer_input_buffer = memory_recycler.DEEPSEEK_MTP_LAYER_INPUT
from torch_vacc.vacc.custom_ops import fuse_mtp_stage0
hidden_states = fuse_mtp_stage0(
inputs_embeds,
previous_hidden_states,
positions,
self.enorm.weight,
self.hnorm.weight,
self.enorm.variance_epsilon,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos,
output=deepseek_mtp_layer_input_buffer,
)
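# Unfused reference for fuse_mtp_stage0 (a sketch of the assumed contract,
# mirroring the upstream MTP layer: mask embeddings at position 0, RMSNorm
# both streams, then concatenate to form the eh_proj input):
#   inputs_embeds[positions == 0] = 0
#   hidden_states = torch.cat([self.enorm(inputs_embeds),
#                              self.hnorm(previous_hidden_states)], dim=-1)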
# if USE_PARALLEL_MTP_EH_PROJ:
# tp_size = get_tensor_model_parallel_world_size()
# rank_id = get_tensor_model_parallel_rank()
# last_dim = hidden_states.shape[-1]
# if tp_size > 1:
# hiddens_tp = last_dim//tp_size
# hidden_states = hidden_states[...,rank_id*hiddens_tp : (rank_id+1)*hiddens_tp]
hidden_states = self.eh_proj(hidden_states)
else:
hidden_states = torch.vacc.fuse_mtp_allreduce(
inputs_embeds,
previous_hidden_states,
positions,
self.enorm.weight,
self.hnorm.weight,
self.eh_proj.weight,
self.enorm.variance_epsilon,
world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
rank = self.weight_capture.layer_moe.dist_args._1_rank,
group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)
if attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE:
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
else:
from torch_vacc.vacc.custom_ops import fuse_mla_moe_v2_allreduce_decode
layer = self.mtp_block
layer_id = 0
kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]
# For the MTP layer, residual is None on input, and the fused op must return a residual.
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
hidden_states = hidden_states,
residual = None,
hidden_states_norm_weight = self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[layer_id],
q_a_proj_weight = self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[layer_id],
q_a_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[layer_id],
q_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[layer_id],
w_q = self.weight_capture.layer_moe.attn_args._3_W_Q[layer_id],
w_q_scale = self.weight_capture.layer_moe.attn_args._4_W_Q_scales[layer_id],
w_uk = self.weight_capture.layer_moe.attn_args._5_W_UK[layer_id],
w_uk_scale = self.weight_capture.layer_moe.attn_args._6_W_UK_scales[layer_id],
w_qr = self.weight_capture.layer_moe.attn_args._7_W_QR[layer_id],
w_qr_scale = self.weight_capture.layer_moe.attn_args._8_W_QR_scales[layer_id],
kv_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[layer_id],
sin_cache = sin_cache,
cos_cache = cos_cache,
slot_mapping = attn_metadata.slot_mapping,
kv_cache = kv_cache,
block_tables = attn_metadata.decode_metadata.block_tables,
block_group_size = self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
w_uv = self.weight_capture.layer_moe.attn_args._16_W_UV[layer_id],
w_uv_scale = self.weight_capture.layer_moe.attn_args._17_W_UV_scales[layer_id],
o_proj_weight = self.weight_capture.layer_moe.attn_args._18_o_proj_weight[layer_id],
o_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[layer_id],
# mla params
seq_lens = attn_metadata.decode_metadata.seq_lens,
sm_scale = self.weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = self.weight_capture.layer_moe.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
rms_weight = self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[layer_id],
mlp_weight_13 = self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[layer_id],
mlp_weight_2 = self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[layer_id],
mlp_weight_scale_13 = self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[layer_id],
mlp_weight_scale_2 = self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[layer_id],
moe_weight_13 = self.weight_capture.layer_moe.moe_args._5_moe_w13[layer_id],
moe_weight_2 = self.weight_capture.layer_moe.moe_args._6_moe_w2[layer_id],
moe_weight_scale_13 = self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[layer_id],
moe_weight_scale_2 = self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[layer_id],
mm_weight = self.weight_capture.layer_moe.moe_args._9_gate_weight[layer_id],
moe_bias = self.weight_capture.layer_moe.moe_args._10_moe_bias[layer_id],
# moe params
mlp_block_size_w13 = self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
rank = self.weight_capture.layer_moe.dist_args._1_rank,
group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)
# hidden_states = residual + hidden_states, done in place to avoid a new tensor
hidden_states = residual.add_(hidden_states)
return hidden_states
class DeepSeekMTP(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if USE_MERGE_Q_KV_GEN_AND_Q_QR:
from vllm.model_executor.models.utils import PPMissingLayer
for layer_id in self.model.layers:
layer = self.model.layers[layer_id]
if isinstance(layer, PPMissingLayer):
continue
layer.mtp_block.self_attn.merge_qkv_weights()
return loaded_params
class SharedHead(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
try:
from torch_vacc.vacc.custom_ops import rms_norm
return rms_norm(hidden_states, self.norm.weight, output=hidden_states)
except Exception as e:
print(f"fuse rms_norm run fail, now use unfused ops: {e}")
return self.norm(hidden_states)

View File

@@ -0,0 +1,658 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tp_group,
get_tensor_model_parallel_world_size,get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.model_executor.models.deepseek_v2 import yarn_get_mscale, DeepseekV2MLAAttention, Indexer
from vllm.logger import init_logger
logger = init_logger(__name__)
from .vars import *
from ..ops.deepseek_fused_mlp_moe import (vacc_fused_decode_moe_fp8,
vacc_fused_prefill_moe_fp8,
vacc_fused_mlp_fp8)
from .fused_forward import *
import os
test_layer_en = os.getenv("test_layer_en", "0")
# class DeepseekV2MLAAttention(nn.Module):
# def __init__(
# self,
# vllm_config: VllmConfig,
# config: Union[DeepseekV2Config, DeepseekV3Config],
# hidden_size: int,
# num_heads: int,
# qk_nope_head_dim: int,
# qk_rope_head_dim: int,
# v_head_dim: int,
# q_lora_rank: Optional[int],
# kv_lora_rank: int,
# rope_theta: float = 10000,
# rope_scaling: Optional[dict[str, Any]] = None,
# max_position_embeddings: int = 8192,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# prefix: str = "",
# topk_indices_buffer: Optional[torch.Tensor] = None,
# ) -> None:
# super(DeepseekV2MLAAttention,self).__init__()
# self.hidden_size = hidden_size
# self.qk_nope_head_dim = qk_nope_head_dim
# self.qk_rope_head_dim = qk_rope_head_dim
# self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
# self.v_head_dim = v_head_dim
# self.q_lora_rank = q_lora_rank
# self.kv_lora_rank = kv_lora_rank
# self.num_heads = num_heads
# tp_size = get_tensor_model_parallel_world_size()
# assert num_heads % tp_size == 0
# self.num_local_heads = num_heads // tp_size
# self.scaling = self.qk_head_dim**-0.5
# self.rope_theta = rope_theta
# self.max_position_embeddings = max_position_embeddings
# if self.q_lora_rank is not None:
# if USE_PARALLEL_Q_KV_GEN:
# self.q_a_proj = RowParallelLinear(self.hidden_size,
# self.q_lora_rank,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_a_proj")
# else:
# self.q_a_proj = ReplicatedLinear(self.hidden_size,
# self.q_lora_rank,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_a_proj")
# self.q_a_layernorm = RMSNorm(self.q_lora_rank,
# eps=config.rms_norm_eps)
# self.q_b_proj = ColumnParallelLinear(q_lora_rank,
# self.num_heads *
# self.qk_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_b_proj")
# else:
# self.q_proj = ColumnParallelLinear(self.hidden_size,
# self.num_heads *
# self.qk_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_proj")
# if USE_PARALLEL_Q_KV_GEN:
# self.kv_a_proj_with_mqa = RowParallelLinear(
# self.hidden_size,
# self.kv_lora_rank + self.qk_rope_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_a_proj_with_mqa")
# else:
# self.kv_a_proj_with_mqa = ReplicatedLinear(
# self.hidden_size,
# self.kv_lora_rank + self.qk_rope_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_a_proj_with_mqa")
# self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
# eps=config.rms_norm_eps)
# self.kv_b_proj = ColumnParallelLinear(
# self.kv_lora_rank,
# self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_b_proj")
# self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
# self.hidden_size,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.o_proj")
# rope_scaling["rope_type"] = 'deepseek_yarn'
# self.rotary_emb = get_rope(qk_rope_head_dim,
# rotary_dim=qk_rope_head_dim,
# max_position=max_position_embeddings,
# base=rope_theta,
# rope_scaling=rope_scaling,
# is_neox_style=False)
# if rope_scaling:
# mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
# scaling_factor = rope_scaling["factor"]
# mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
# self.scaling = self.scaling * mscale * mscale
# self.is_v32 = hasattr(config, "index_topk")
# if self.is_v32:
# self.indexer = Indexer(vllm_config, config, hidden_size,
# q_lora_rank, quant_config, cache_config,
# topk_indices_buffer, f"{prefix}.indexer")
# else:
# self.indexer = None
# self.mla_attn = Attention(
# num_heads=self.num_local_heads,
# head_size=self.kv_lora_rank,
# scale=self.scaling,
# num_kv_heads=1,
# cache_config=cache_config,
# quant_config=quant_config,
# prefix=f"{prefix}.attn",
# use_mla=True,
# # MLA Args
# q_lora_rank=self.q_lora_rank,
# kv_lora_rank=self.kv_lora_rank,
# qk_nope_head_dim=self.qk_nope_head_dim,
# qk_rope_head_dim=self.qk_rope_head_dim,
# qk_head_dim=self.qk_head_dim,
# v_head_dim=self.v_head_dim,
# rotary_emb=self.rotary_emb,
# q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
# kv_b_proj=self.kv_b_proj,
# o_proj=self.o_proj,
# )
# self.prefix = prefix
# self.debug_layer_idx = int(self.prefix.split(".")[-2])
# def forward(
# self,
# positions: torch.Tensor,
# hidden_states: torch.Tensor,
# kv_cache: torch.Tensor,
# attn_metadata: AttentionMetadata,
# ) -> torch.Tensor:
# tp_size = get_tensor_model_parallel_world_size()
# rank_id = get_tensor_model_parallel_rank()
# last_dim = hidden_states.shape[-1]
# if USE_PARALLEL_Q_KV_GEN: #tp qa and kva
# hidden_states_split = hidden_states
# if tp_size > 1:
# hiddens_tp = last_dim//tp_size
# hidden_states_split = hidden_states[...,rank_id*hiddens_tp : (rank_id+1)*hiddens_tp].contiguous()
# if self.q_lora_rank is not None:
# ckq = self.q_a_proj(hidden_states_split)[0]
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
# else:
# hidden_states_or_q_c = hidden_states
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states_split)[0].split(
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
# attn_metadata)
# if self.q_lora_rank is not None:
# ckq = self.q_a_proj(hidden_states)[0]
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
# else:
# hidden_states_or_q_c = hidden_states
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
# attn_metadata)
class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, residual = None, rms_norm = None):
# moe layer support prefill&decode vacc ops
if residual is not None:
try:
reduce_result = self.tp_size > 1
# decode moe, first seq
if self.is_decode:
hidden_states, residual = vacc_fused_decode_moe_fp8(self, self.shared_experts,
hidden_states, residual,
rms_norm, self.gate, self.experts,
self.routed_scaling_factor,
reduce_result)
return hidden_states, residual
# prefill moe, first expert
else:
hidden_states, residual = vacc_fused_prefill_moe_fp8(self, self.shared_experts,
hidden_states, residual,
rms_norm, self.gate, self.experts,
self.routed_scaling_factor,
reduce_result)
return hidden_states, residual
except Exception as e:
logger.warning("vacc fused moe run fail, now use unfused ops %s", e)
hidden_states, residual = rms_norm(hidden_states, residual)
self.experts.is_decode = self.is_decode
# 1. fuse_prefill_pre_moe
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
try:
shared_output = vacc_fused_mlp_fp8(self.shared_experts, hidden_states, moe_share=True)
except Exception as e:
logger.warning("fused mlp is Error, now use Default:%s", e)
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
# 2. fused_moe
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
# 3. add_reduce
# now fuse share_mlp add to experts
# if shared_output is not None:
# # out = input + other * alpha
# final_hidden_states = shared_output.add_(final_hidden_states, alpha=self.routed_scaling_factor)
# else:
# final_hidden_states = final_hidden_states * self.routed_scaling_factor
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if residual is not None:
return final_hidden_states.view(num_tokens, hidden_dim), residual
return final_hidden_states.view(num_tokens, hidden_dim)
class DeepseekV2MLP(nn.Module):
def forward(self, x, residual = None, rms_norm = None):
# use all fused ops
if residual is not None:
reduce_result = self.down_proj.reduce_results and self.down_proj.tp_size > 1
hidden_states, residual = vacc_fused_mlp_fp8(self,
x, residual,
rms_norm,
reduce_result)
return hidden_states, residual
# use default fuse ops
try:
output_parallel = vacc_fused_mlp_fp8(self, x, residual, rms_norm)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
x = tensor_model_parallel_all_reduce(output_parallel)
else:
x = output_parallel
except Exception as e:
logger.warning("fuse_mlp run fail, now use default: %s", e)
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
first_k_dense_replace = self.config.first_k_dense_replace if hasattr(self.config, "first_k_dense_replace") else 3
if not hasattr(self, "weight_capture"):
from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekWeightCapture
self.weight_capture = DeepseekWeightCapture(self.layers, self.start_layer, self.end_layer)
self.cached_weights_state = True
self.cached_batch = 1
self.layer_nums = self.end_layer - self.start_layer
self.is_pipeline_first = get_pp_group().is_first_rank
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
if attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE:
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual)
else:
# update global seq lens, use for serve infos
# update_seqence_length(attn_metadata.decode_metadata.seq_lens)
if FUSE_ALL_DECODER_LAYERS:
self.weight_capture.update_attn_args(attn_metadata.decode_metadata.seq_lens,
attn_metadata.slot_mapping,
[self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(self.start_layer, first_k_dense_replace)],
[self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(first_k_dense_replace, self.end_layer)],
attn_metadata.decode_metadata.block_tables)
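# The first `first_k_dense_replace` layers are dense MLA+MLP layers; the
# hardcoded layer ids 0-2 below assume the default of 3 set above.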
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 0)
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 1)
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 2)
if hidden_states.shape[0] != self.cached_batch:
# Batch size changed; redo the weight caching pass.
self.cached_weights_state = True
self.cached_batch = hidden_states.shape[0]
if self.cached_weights_state:
self.cached_weights_state = False
hidden_states, residual = forward_mla_moe_layers_with_weights(hidden_states, residual, self.weight_capture)
else:
hidden_states, residual = forward_mla_moe_layers_without_weights(hidden_states, residual, self.weight_capture)
else:
from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode,fuse_mla_moe_v2_allreduce_decode
for i in range(0, self.layer_nums):
layer_id = i + self.start_layer
layer = self.layers[layer_id]
kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]
if layer_id < first_k_dense_replace:
hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = self.weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[i],
q_a_proj_weight = self.weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[i],
q_a_proj_weight_scale_inv = self.weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[i],
q_a_layernorm_weight = self.weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[i],
w_q = self.weight_capture.layer_mlp.attn_args._3_W_Q[i],
w_q_scale = self.weight_capture.layer_mlp.attn_args._4_W_Q_scales[i],
w_uk = self.weight_capture.layer_mlp.attn_args._5_W_UK[i],
w_uk_scale = self.weight_capture.layer_mlp.attn_args._6_W_UK_scales[i],
w_qr = self.weight_capture.layer_mlp.attn_args._7_W_QR[i],
w_qr_scale = self.weight_capture.layer_mlp.attn_args._8_W_QR_scales[i],
kv_a_layernorm_weight = self.weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[i],
sin_cache = sin_cache,# self.weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = cos_cache,# self.weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = attn_metadata.slot_mapping,#self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i],
kv_cache = kv_cache,#self.weight_capture.layer_mlp.attn_args._13_kv_cache[i],
block_tables = attn_metadata.decode_metadata.block_tables,#self.weight_capture.layer_mlp.attn_args._14_block_tables[i],
block_group_size = self.weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
w_uv = self.weight_capture.layer_mlp.attn_args._16_W_UV[i],
w_uv_scale = self.weight_capture.layer_mlp.attn_args._17_W_UV_scales[i],
o_proj_weight = self.weight_capture.layer_mlp.attn_args._18_o_proj_weight[i],
o_proj_weight_scale_inv = self.weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[i],
# mla params
seq_lens = attn_metadata.decode_metadata.seq_lens,
sm_scale = self.weight_capture.layer_mlp.attn_args._21_sm_scale,
head_num = self.weight_capture.layer_mlp.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# mlp weight
rms_weight = self.weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[i],
mlp_weight_13 = self.weight_capture.layer_mlp.mlp_args._1_mlp_w13[i],
mlp_weight_2 = self.weight_capture.layer_mlp.mlp_args._2_mlp_w2[i],
mlp_weight_scale_13 = self.weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[i],
mlp_weight_scale_2 = self.weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[i],
# mlp params
mlp_block_size_w13 = self.weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
mlp_block_size_w2 = self.weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
# vccl info
world_size = self.weight_capture.layer_mlp.dist_args._0_world_size,
rank = self.weight_capture.layer_mlp.dist_args._1_rank,
group_id = self.weight_capture.layer_mlp.dist_args._2_group_id,
dev_info = self.weight_capture.layer_mlp.dist_args._3_dev_info)
else:
wid = i - first_k_dense_replace if self.is_pipeline_first else i
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[wid],
q_a_proj_weight = self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[wid],
q_a_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[wid],
q_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[wid],
w_q = self.weight_capture.layer_moe.attn_args._3_W_Q[wid],
w_q_scale = self.weight_capture.layer_moe.attn_args._4_W_Q_scales[wid],
w_uk = self.weight_capture.layer_moe.attn_args._5_W_UK[wid],
w_uk_scale = self.weight_capture.layer_moe.attn_args._6_W_UK_scales[wid],
w_qr = self.weight_capture.layer_moe.attn_args._7_W_QR[wid],
w_qr_scale = self.weight_capture.layer_moe.attn_args._8_W_QR_scales[wid],
kv_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[wid],
sin_cache = sin_cache,# self.weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = cos_cache,# self.weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = attn_metadata.slot_mapping,#self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i],
kv_cache = kv_cache,#self.weight_capture.layer_mlp.attn_args._13_kv_cache[i],
block_tables = attn_metadata.decode_metadata.block_tables,
block_group_size = self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
w_uv = self.weight_capture.layer_moe.attn_args._16_W_UV[wid],
w_uv_scale = self.weight_capture.layer_moe.attn_args._17_W_UV_scales[wid],
o_proj_weight = self.weight_capture.layer_moe.attn_args._18_o_proj_weight[wid],
o_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[wid],
# mla params
seq_lens = attn_metadata.decode_metadata.seq_lens,
sm_scale = self.weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = self.weight_capture.layer_moe.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
rms_weight = self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[wid],
mlp_weight_13 = self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[wid],
mlp_weight_2 = self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[wid],
mlp_weight_scale_13 = self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[wid],
mlp_weight_scale_2 = self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[wid],
moe_weight_13 = self.weight_capture.layer_moe.moe_args._5_moe_w13[wid],
moe_weight_2 = self.weight_capture.layer_moe.moe_args._6_moe_w2[wid],
moe_weight_scale_13 = self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[wid],
moe_weight_scale_2 = self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[wid],
mm_weight = self.weight_capture.layer_moe.moe_args._9_gate_weight[wid],
moe_bias = self.weight_capture.layer_moe.moe_args._10_moe_bias[wid],
# moe params
mlp_block_size_w13 = self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
rank = self.weight_capture.layer_moe.dist_args._1_rank,
group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
from .memory.memory_recycling import init_huge_memory_allocator
from .vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
# default is deepseek, config can set to ['deepseek_mtp',]
model_name = "deepseek"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
if config_infos in ['mtp']:
model_name = "deepseek_mtp"
else:
model_name = config_infos
if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
logger.warning("init huge memory allocator fail. prefill memory recycling will disable")
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if test_layer_en == "1":
test_layer = 5
if name not in ['model.embed_tokens.weight', 'model.norm.weight', 'lm_head.weight']:
if int(name.split(".")[2]) > test_layer:
continue
# TODO(simon): support nextn predict layers
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if USE_MERGE_Q_KV_GEN_AND_Q_QR:
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
layer.self_attn.merge_qkv_weights()
return loaded_params
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
if attn_metadata.prefill_metadata is not None:
from .memory.memory_recycling import alloc_memory_recycler
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
if hasattr(attn_metadata, 'num_prefill_tokens'):
tokens = attn_metadata.num_prefill_tokens
else:
tokens = attn_metadata.prefill_metadata.num_prefill_tokens
vllm_model_mode = "deepseek"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
if config_infos in ['mtp']:
vllm_model_mode = "deepseek_mtp"
else:
vllm_model_mode = config_infos
if get_tp_group().rank_in_group == 0:
memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
logger.info(memory_infos)
if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size):
logger.warning("deepseek memory recycler allock fail. current request may inefficient %s", tokens)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states

File diff suppressed because it is too large

View File

@@ -0,0 +1,216 @@
import torch
from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode_layers, fuse_mla_moe_v2_allreduce_decode_layers, fuse_mla_moe_v2_allreduce_decode_layers_v2
from typing import Optional
from .weight_capture.deepseek_weight_capture import DeepseekWeightCapture
import time
from .vars import *
# single-layer mla + mlp
def forward_mla_mlp_single_layer(hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture,
layer_id: int):
from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode
hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[layer_id],
q_a_proj_weight = weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[layer_id],
q_a_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[layer_id],
q_a_layernorm_weight = weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[layer_id],
w_q = weight_capture.layer_mlp.attn_args._3_W_Q[layer_id],
w_q_scale = weight_capture.layer_mlp.attn_args._4_W_Q_scales[layer_id],
w_uk = weight_capture.layer_mlp.attn_args._5_W_UK[layer_id],
w_uk_scale = weight_capture.layer_mlp.attn_args._6_W_UK_scales[layer_id],
w_qr = weight_capture.layer_mlp.attn_args._7_W_QR[layer_id],
w_qr_scale = weight_capture.layer_mlp.attn_args._8_W_QR_scales[layer_id],
kv_a_layernorm_weight = weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[layer_id],
sin_cache = weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = weight_capture.layer_mlp.attn_args._12_slot_mapping,
kv_cache = weight_capture.layer_mlp.attn_args._13_kv_cache[layer_id],
block_tables = weight_capture.layer_mlp.attn_args._14_block_tables,
block_group_size = weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
w_uv = weight_capture.layer_mlp.attn_args._16_W_UV[layer_id],
w_uv_scale = weight_capture.layer_mlp.attn_args._17_W_UV_scales[layer_id],
o_proj_weight = weight_capture.layer_mlp.attn_args._18_o_proj_weight[layer_id],
o_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[layer_id],
# mla params
seq_lens = weight_capture.layer_mlp.attn_args._20_seq_lens,
sm_scale = weight_capture.layer_mlp.attn_args._21_sm_scale,
head_num = weight_capture.layer_mlp.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# mlp weight
rms_weight = weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[layer_id],
mlp_weight_13 = weight_capture.layer_mlp.mlp_args._1_mlp_w13[layer_id],
mlp_weight_2 = weight_capture.layer_mlp.mlp_args._2_mlp_w2[layer_id],
mlp_weight_scale_13 = weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[layer_id],
mlp_weight_scale_2 = weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[layer_id],
# mlp params
mlp_block_size_w13 = weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
mlp_block_size_w2 = weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
# vccl info
world_size = weight_capture.layer_mlp.dist_args._0_world_size,
rank = weight_capture.layer_mlp.dist_args._1_rank,
group_id = weight_capture.layer_mlp.dist_args._2_group_id,
dev_info = weight_capture.layer_mlp.dist_args._3_dev_info)
return hidden_states, residual
# multi-layer mla + mlp
def forward_mla_mlp_layers(hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture):
if residual is None:
residual = torch.zeros_like(hidden_states)
hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode_layers(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight,
q_a_proj_weight = weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights,
q_a_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv,
q_a_layernorm_weight = weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight,
w_q = weight_capture.layer_mlp.attn_args._3_W_Q,
w_q_scale = weight_capture.layer_mlp.attn_args._4_W_Q_scales,
w_uk = weight_capture.layer_mlp.attn_args._5_W_UK,
w_uk_scale = weight_capture.layer_mlp.attn_args._6_W_UK_scales,
w_qr = weight_capture.layer_mlp.attn_args._7_W_QR,
w_qr_scale = weight_capture.layer_mlp.attn_args._8_W_QR_scales,
kv_a_layernorm_weight = weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight,
sin_cache = weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = weight_capture.layer_mlp.attn_args._12_slot_mapping,
kv_cache = weight_capture.layer_mlp.attn_args._13_kv_cache,
block_tables = weight_capture.layer_mlp.attn_args._14_block_tables,
block_group_size = weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
w_uv = weight_capture.layer_mlp.attn_args._16_W_UV,
w_uv_scale = weight_capture.layer_mlp.attn_args._17_W_UV_scales,
o_proj_weight = weight_capture.layer_mlp.attn_args._18_o_proj_weight,
o_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv,
# mla params
seq_lens = weight_capture.layer_mlp.attn_args._20_seq_lens,
sm_scale = weight_capture.layer_mlp.attn_args._21_sm_scale,
head_num = weight_capture.layer_mlp.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# mlp weight
rms_weight = weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight,
mlp_weight_13 = weight_capture.layer_mlp.mlp_args._1_mlp_w13,
mlp_weight_2 = weight_capture.layer_mlp.mlp_args._2_mlp_w2,
mlp_weight_scale_13 = weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale,
mlp_weight_scale_2 = weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale,
# mlp params
mlp_block_size_w13 = weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
mlp_block_size_w2 = weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
# vccl info
world_size = weight_capture.layer_mlp.dist_args._0_world_size,
rank = weight_capture.layer_mlp.dist_args._1_rank,
group_id = weight_capture.layer_mlp.dist_args._2_group_id,
dev_info = weight_capture.layer_mlp.dist_args._3_dev_info)
return hidden_states, residual
# multi-layer mla + moe, weights not yet cached
def forward_mla_moe_layers_with_weights(hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture):
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode_layers(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight,
q_a_proj_weight = weight_capture.layer_moe.attn_args._0_merge_q_kv_weights,
q_a_proj_weight_scale_inv = weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv,
q_a_layernorm_weight = weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight,
w_q = weight_capture.layer_moe.attn_args._3_W_Q,
w_q_scale = weight_capture.layer_moe.attn_args._4_W_Q_scales,
w_uk = weight_capture.layer_moe.attn_args._5_W_UK,
w_uk_scale = weight_capture.layer_moe.attn_args._6_W_UK_scales,
w_qr = weight_capture.layer_moe.attn_args._7_W_QR,
w_qr_scale = weight_capture.layer_moe.attn_args._8_W_QR_scales,
kv_a_layernorm_weight = weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight,
sin_cache = weight_capture.layer_moe.attn_args._10_sin_cache,
cos_cache = weight_capture.layer_moe.attn_args._11_cos_cache,
slot_mapping = weight_capture.layer_moe.attn_args._12_slot_mapping,
kv_cache = weight_capture.layer_moe.attn_args._13_kv_cache,
block_tables = weight_capture.layer_moe.attn_args._14_block_tables,
block_group_size = weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
w_uv = weight_capture.layer_moe.attn_args._16_W_UV,
w_uv_scale = weight_capture.layer_moe.attn_args._17_W_UV_scales,
o_proj_weight = weight_capture.layer_moe.attn_args._18_o_proj_weight,
o_proj_weight_scale_inv = weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv,
# mla params
seq_lens = weight_capture.layer_moe.attn_args._20_seq_lens,
sm_scale = weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = weight_capture.layer_moe.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
rms_weight = weight_capture.layer_moe.moe_args._0_moe_rms_weight,
mlp_weight_13 = weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13,
mlp_weight_2 = weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2,
mlp_weight_scale_13 = weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale,
mlp_weight_scale_2 = weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale,
moe_weight_13 = weight_capture.layer_moe.moe_args._5_moe_w13,
moe_weight_2 = weight_capture.layer_moe.moe_args._6_moe_w2,
moe_weight_scale_13 = weight_capture.layer_moe.moe_args._7_moe_w13_scale,
moe_weight_scale_2 = weight_capture.layer_moe.moe_args._8_moe_w2_scale,
mm_weight = weight_capture.layer_moe.moe_args._9_gate_weight,
moe_bias = weight_capture.layer_moe.moe_args._10_moe_bias,
# moe params
mlp_block_size_w13 = weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = weight_capture.layer_moe.dist_args._0_world_size,
rank = weight_capture.layer_moe.dist_args._1_rank,
group_id = weight_capture.layer_moe.dist_args._2_group_id,
dev_info = weight_capture.layer_moe.dist_args._3_dev_info)
return hidden_states, residual
# multi-layer mla + moe with cached weights; may only be called after the uncached-weights op has run
def forward_mla_moe_layers_without_weights(hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture):
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode_layers_v2(
hidden_states = hidden_states,
residual = residual,
sin_cache = weight_capture.layer_moe.attn_args._10_sin_cache,
cos_cache = weight_capture.layer_moe.attn_args._11_cos_cache,
slot_mapping = weight_capture.layer_moe.attn_args._12_slot_mapping,
kv_cache = weight_capture.layer_moe.attn_args._13_kv_cache,
block_tables = weight_capture.layer_moe.attn_args._14_block_tables,
block_group_size = weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
# mla params
seq_lens = weight_capture.layer_moe.attn_args._20_seq_lens,
sm_scale = weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = weight_capture.layer_moe.attn_args._22_head_num,
# flash_attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
# moe params
mlp_block_size_w13 = weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = weight_capture.layer_moe.dist_args._0_world_size,
rank = weight_capture.layer_moe.dist_args._1_rank,
group_id = weight_capture.layer_moe.dist_args._2_group_id,
dev_info = weight_capture.layer_moe.dist_args._3_dev_info)
return hidden_states, residual
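# Full decode step: DeepSeek-V3 runs its leading dense (MLP) decoder layers
# before the MoE decoder layers, so one step chains the two fused multi-layer
# kernels below.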
def forward_deepseekv3(model: torch.nn.Module,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture):
hidden_states, residual = forward_mla_mlp_layers(hidden_states, residual, weight_capture)
hidden_states, residual = forward_mla_moe_layers_with_weights(hidden_states, residual, weight_capture)
return hidden_states, residual

View File

@@ -0,0 +1,239 @@
import warnings
from transformers.utils import (
CONFIG_NAME,
IMAGE_PROCESSOR_NAME,
cached_file,
is_timm_config_dict,
is_timm_local_checkpoint,
is_torchvision_available,
is_vision_available,
logging,
)
from transformers.models.auto.configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from transformers.image_processing_utils import ImageProcessingMixin
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.image_processing_auto import (
AutoImageProcessor,
logger,
FORCE_FAST_IMAGE_PROCESSOR,
IMAGE_PROCESSOR_MAPPING_NAMES,
IMAGE_PROCESSOR_MAPPING,
get_image_processor_class_from_name,
resolve_trust_remote_code,
_warning_fast_image_processor_available,
get_class_from_dynamic_module
)
def check_vacc_support_module(module_class):
if module_class.__name__ == "Qwen2VLImageProcessorFast":
from .qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
return Qwen2VLImageProcessorFastWithVacc
return module_class
"""AutoImageProcessor class."""
class AutoImageProcessorWithVacc(AutoImageProcessor):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if kwargs.get("token") is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
kwargs["token"] = use_auth_token
config = kwargs.pop("config", None)
# TODO: @yoni, change in v4.48 (use_fast set to True by default)
use_fast = kwargs.pop("use_fast", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
# Resolve the image processor config filename
if "image_processor_filename" in kwargs:
image_processor_filename = kwargs.pop("image_processor_filename")
elif is_timm_local_checkpoint(pretrained_model_name_or_path):
image_processor_filename = CONFIG_NAME
else:
image_processor_filename = IMAGE_PROCESSOR_NAME
# Load the image processor config
try:
# Main path for all transformers models and local TimmWrapper checkpoints
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
)
except Exception as initial_exception:
# Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
# instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
# except the model name, the only way to check if a remote checkpoint is a timm model is to try to
# load `config.json` and if it fails with some error, we raise the initial exception.
try:
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
)
except Exception:
raise initial_exception
# In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
# because only timm models have image processing in `config.json`.
if not is_timm_config_dict(config_dict):
raise initial_exception
image_processor_type = config_dict.get("image_processor_type", None)
# Shortcut for the vacc preprocessing-op replacement (disabled; handled later
# via check_vacc_support_module):
# if image_processor_type == "Qwen2VLImageProcessorFast":
# from .qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
# return Qwen2VLImageProcessorFastWithVacc.from_dict(config_dict, **kwargs)
image_processor_auto_map = None
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
# If we still don't have the image processor class, check if we're loading from a previous feature extractor config
# and if so, infer the image processor class from there.
if image_processor_type is None and image_processor_auto_map is None:
feature_extractor_class = config_dict.pop("feature_extractor_type", None)
if feature_extractor_class is not None:
image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")
# If we don't find the image processor class in the image processor config, let's try the model config.
if image_processor_type is None and image_processor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
# It could be in `config.image_processor_type`
image_processor_type = getattr(config, "image_processor_type", None)
if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
image_processor_auto_map = config.auto_map["AutoImageProcessor"]
image_processor_class = None
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
if image_processor_type is not None:
# if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
if use_fast is None:
use_fast = image_processor_type.endswith("Fast")
if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
use_fast = True
logger.warning_once(
f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
"This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
"Note that this behavior will be extended to all models in a future release."
)
if not use_fast:
logger.warning_once(
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
"`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
)
if use_fast and not image_processor_type.endswith("Fast"):
image_processor_type += "Fast"
if use_fast and not is_torchvision_available():
# check if there is a slow image processor class to fallback to
image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
if image_processor_class is None:
raise ValueError(
f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
)
logger.warning_once(
"Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
)
use_fast = False
if use_fast:
for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
if image_processor_type in image_processors:
break
else:
image_processor_type = image_processor_type[:-4]
use_fast = False
logger.warning_once(
"`use_fast` is set to `True` but the image processor class does not have a fast version. "
" Falling back to the slow version."
)
image_processor_class = get_image_processor_class_from_name(image_processor_type)
else:
image_processor_type_slow = image_processor_type.removesuffix("Fast")
image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
if image_processor_class is None and image_processor_type.endswith("Fast"):
raise ValueError(
f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
)
has_remote_code = image_processor_auto_map is not None
has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
if has_remote_code:
if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
# In some configs, only the slow image processor class is stored
image_processor_auto_map = (image_processor_auto_map, None)
if use_fast and image_processor_auto_map[1] is not None:
class_ref = image_processor_auto_map[1]
else:
class_ref = image_processor_auto_map[0]
if "--" in class_ref:
upstream_repo = class_ref.split("--")[0]
else:
upstream_repo = None
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
)
if has_remote_code and trust_remote_code:
if not use_fast and image_processor_auto_map[1] is not None:
_warning_fast_image_processor_available(image_processor_auto_map[1])
image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
image_processor_class.register_for_auto_class()
# check preprocess module supported by vacc
image_processor_class = check_vacc_support_module(image_processor_class)
return image_processor_class.from_dict(config_dict, **kwargs)
elif image_processor_class is not None:
# check preprocess module supported by vacc
image_processor_class = check_vacc_support_module(image_processor_class)
return image_processor_class.from_dict(config_dict, **kwargs)
# Last try: we use the IMAGE_PROCESSOR_MAPPING.
elif type(config) in IMAGE_PROCESSOR_MAPPING:
image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
image_processor_class_py, image_processor_class_fast = image_processor_tuple
if not use_fast and image_processor_class_fast is not None:
_warning_fast_image_processor_available(image_processor_class_fast)
if image_processor_class_fast and (use_fast or image_processor_class_py is None):
# check preprocess module supported by vacc
image_processor_class_fast = check_vacc_support_module(image_processor_class_fast)
return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
if image_processor_class_py is not None:
# check preprocess module supported by vacc
image_processor_class_py = check_vacc_support_module(image_processor_class_py)
return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
raise ValueError(
"This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
)
raise ValueError(
f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
)
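# Illustrative usage (the checkpoint path is hypothetical):
#   processor = AutoImageProcessorWithVacc.from_pretrained(
#       "/models/Qwen2-VL-7B-Instruct", use_fast=True)
#   # when the checkpoint resolves to Qwen2VLImageProcessorFast, this returns the
#   # vacc-accelerated Qwen2VLImageProcessorFastWithVacc instead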

View File

@@ -0,0 +1,402 @@
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Qwen2-VL."""
from typing import Optional, Union
import torch
from torchvision.transforms.v2 import functional as F
from transformers.image_processing_utils import BatchFeature
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
TensorType,
auto_docstring,
logging,
)
from transformers.video_utils import VideoInput, make_batched_videos
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import (
Qwen2VLFastImageProcessorKwargs,
logger
)
# resize + normalize + repeat + transpose + reshape, fused into one preprocessing op
def fuse_qwen2_vl_preprocess_img_cpu(
image: "torch.Tensor",
do_resize: bool,
min_pixels: int,
max_pixels: int,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
resized_height: int,
resized_width: int,
interpolation: Optional["F.InterpolationMode"],
patch_size: int,
temporal_patch_size: int,
merge_size: int,
image_mean_0: float,
image_mean_1: float,
image_mean_2: float,
image_std_0: float,
image_std_1: float,
image_std_2: float,
batch_size: int = 1,
grid_t: int = 1,
channel: int = 3,
) -> "torch.Tensor":
def resize(
image: "torch.Tensor",
size_h: int,
size_w: int,
interpolation: Optional["F.InterpolationMode"] = None,
antialias: bool = True,
) -> "torch.Tensor":
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if size_h and size_w:
new_size = (size_h, size_w)
else:
raise ValueError(
f"Both size_h and size_w must be non-zero; got {size_h} and {size_w}."
)
return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
def fuse_mean_std_and_rescale_factor(
image_mean_0: float,
image_mean_1: float,
image_mean_2: float,
image_std_0: float,
image_std_1: float,
image_std_2: float,
do_normalize: Optional[bool] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
device: Optional["torch.device"] = None,
) -> tuple:
image_mean = torch.tensor([image_mean_0, image_mean_1, image_mean_2], device=device)
image_std = torch.tensor([image_std_0, image_std_1, image_std_2], device=device)
if do_rescale and do_normalize:
# fold the rescale factor into the normalization statistics
image_mean = image_mean * (1.0 / rescale_factor)
image_std = image_std * (1.0 / rescale_factor)
return image_mean, image_std
def rescale_and_normalize(
images: "torch.Tensor",
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean_0: float,
image_mean_1: float,
image_mean_2: float,
image_std_0: float,
image_std_1: float,
image_std_2: float
) -> "torch.Tensor":
image_mean, image_std = fuse_mean_std_and_rescale_factor(
image_mean_0, image_mean_1, image_mean_2,
image_std_0, image_std_1, image_std_2,
do_normalize,
do_rescale,
rescale_factor,
device=images.device,
)
if do_normalize:
images = F.normalize(images.to(dtype=torch.float32), image_mean, image_std)
return images
if image.dim() == 3:
image = image.unsqueeze(0)
stacked_images = resize(
image=image,
size_h=resized_height,
size_w=resized_width,
interpolation=interpolation, # BICUBIC interpolation
)
patches = rescale_and_normalize(
stacked_images,
do_rescale,
rescale_factor,
do_normalize,
image_mean_0, image_mean_1, image_mean_2,
image_std_0, image_std_1, image_std_2
)
if patches.ndim == 4:
patches = patches.unsqueeze(1)
if patches.shape[1] % temporal_patch_size != 0:
repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
patches = torch.cat([patches, repeats], dim=1)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
patches = patches.view(
batch_size, # 1 - batch size
grid_t, # 1 - temporal grid count
temporal_patch_size, # 2 - temporal patch size
channel, # 3 - channels (RGB)
grid_h // merge_size, # 12 // 2 = 6 - merged grid count along height
merge_size, # 2 - merge size along height
patch_size, # 14 - patch size along height
grid_w // merge_size, # 38 // 2 = 19 - merged grid count along width
merge_size, # 2 - merge size along width
patch_size, # 14 - patch size along width
)
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
flatten_patches = patches.reshape(
# batch_size, # 1
grid_t * grid_h * grid_w, # 1 * 12 * 38 = 456 (total grid cells)
channel * temporal_patch_size * patch_size * patch_size, # 3 * 2 * 14 * 14 = 1176 (features per cell)
)
return flatten_patches
class Qwen2VLImageProcessorFastWithVacc(BaseImageProcessorFast):
do_resize = True
resample = PILImageResampling.BICUBIC
size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
do_rescale = True
do_normalize = True
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
do_convert_rgb = True
patch_size = 14
temporal_patch_size = 2
merge_size = 2
min_pixels = None
max_pixels = None
valid_kwargs = Qwen2VLFastImageProcessorKwargs
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
def __init__(self, **kwargs: Unpack[Qwen2VLFastImageProcessorKwargs]):
size = kwargs.pop("size", None)
min_pixels = kwargs.pop("min_pixels", None)
max_pixels = kwargs.pop("max_pixels", None)
# backward compatibility: override size with min_pixels and max_pixels if they are provided
size = self.size if size is None else size
if min_pixels is not None:
size["shortest_edge"] = min_pixels
size.pop("min_pixels", None)
if max_pixels is not None:
size["longest_edge"] = max_pixels
size.pop("max_pixels", None)
if "shortest_edge" not in size or "longest_edge" not in size:
raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
def _further_process_kwargs(
self,
size: Optional[SizeDict] = None,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
**kwargs,
) -> dict:
"""
Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if min_pixels is not None and max_pixels is not None:
size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
elif size is not None:
if "shortest_edge" not in size or "longest_edge" not in size:
raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
min_pixels = size["shortest_edge"]
max_pixels = size["longest_edge"]
else:
size = {**self.size}
return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
def preprocess(
self,
images: ImageInput,
videos: Optional[VideoInput] = None,
**kwargs: Unpack[Qwen2VLFastImageProcessorKwargs],
) -> BatchFeature:
return super().preprocess(images, videos, **kwargs)
def _preprocess_image_like_inputs(
self,
images: ImageInput,
videos: VideoInput,
do_convert_rgb: bool,
input_data_format: ChannelDimension,
device: Optional[Union[str, "torch.device"]] = None,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> BatchFeature:
"""
Preprocess image-like inputs.
To be overridden by subclasses when image-like inputs other than images should be processed.
It can be used for segmentation maps, depth maps, etc.
"""
# Prepare input images
batch_feature = BatchFeature()
if images is not None:
images = self._prepare_image_like_inputs(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
batch_feature = self._preprocess(images, **kwargs)
if videos is not None:
logger.warning(
"`Qwen2VLImageProcessorFast` works only with image inputs and doesn't process videos anymore. "
"This is a deprecated behavior and will be removed in v5.0. "
"Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
)
# Can't change _prepare_images_structure to work with videos because it also needs to work with images.
videos = make_batched_videos(videos)
videos = [
torch.stack(self._prepare_image_like_inputs(video, do_convert_rgb, input_data_format, device))
for video in videos
]
video_outputs = self._preprocess(videos, **kwargs)
batch_feature.update(
{"pixel_values_videos": video_outputs.pixel_values, "video_grid_thw": video_outputs.image_grid_thw}
)
return batch_feature
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
patch_size: int,
temporal_patch_size: int,
merge_size: int,
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
):
min_pixels = size["shortest_edge"]
max_pixels = size["longest_edge"]
processed_images = []
processed_grids = []
for img in images:
height, width = img.shape[-2:]
if do_resize:
resized_height, resized_width = smart_resize(
height,
width,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
# resize_normalize_repeat = fuse_qwen2_vl_preprocess_img_cpu(
# img,
# do_resize,
# min_pixels,
# max_pixels,
# do_rescale,
# rescale_factor,
# do_normalize,
# resized_height,
# resized_width,
# interpolation,
# patch_size,
# temporal_patch_size,
# merge_size,
# image_mean[0], image_mean[1], image_mean[2],
# image_std[0], image_std[1], image_std[2],
# 1,1,3
# )
import torch_vacc
if img.device.type != "vacc":
img = img.to("vacc")
resize_normalize_repeat = torch_vacc.vacc.custom_qwen3_ops.qwen2vl_img_preprocess(img,
do_resize,
min_pixels,
max_pixels,
do_rescale,
rescale_factor,
do_normalize,
resized_height,
resized_width,
0x1003, # interpolation mode id (BICUBIC)
patch_size,
temporal_patch_size,
merge_size,
image_mean[0], image_mean[1], image_mean[2],
image_std[0], image_std[1], image_std[2])
processed_images.append(resize_normalize_repeat)
grid_t = 1
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
grid_thw_ = torch.tensor([[grid_t, grid_h, grid_w]])
processed_grids.append(grid_thw_)
pixel_values = torch.cat(processed_images, dim=0)
image_grid_thw = torch.cat(processed_grids, dim=0)
return BatchFeature(
data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
)
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Note: Do not remove this method! It is used by vLLM to infer the number of patches and placeholders
without an image input.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of image patches per image.
"""
images_kwargs = images_kwargs or {}
min_pixels = images_kwargs.get("min_pixels", self.size["shortest_edge"])
max_pixels = images_kwargs.get("max_pixels", self.size["longest_edge"])
patch_size = images_kwargs.get("patch_size", self.patch_size)
merge_size = images_kwargs.get("merge_size", self.merge_size)
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * grid_w
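# Worked example (illustrative numbers): an image smart-resized to 476x644 with
# patch_size=14 gives grid_h=34 and grid_w=46, i.e. 34 * 46 = 1564 patches; with
# merge_size=2 these later collapse into 1564 // 4 = 391 placeholder tokens.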
__all__ = ["Qwen2VLImageProcessorFastWithVacc"]

View File

@@ -0,0 +1,91 @@
from transformers.models.qwen3_vl import Qwen3VLProcessor
# from transformers.models.auto.image_processing_auto import AutoImageProcessor
from transformers.models.qwen2_vl import Qwen2VLProcessor
class Qwen3VLProcessorWithVacc(Qwen3VLProcessor):
def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
@classmethod
def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Identify and instantiate the subcomponents of Processor classes, like image processors and
tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
will be unable to find the relevant subcomponent class and will raise an error.
"""
args = []
for attribute_name in cls.attributes:
class_name = getattr(cls, f"{attribute_name}_class")
if isinstance(class_name, tuple):
classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
if attribute_name == "image_processor":
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
use_fast = kwargs.get("use_fast")
if use_fast is None:
print(
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
"`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
)
else:
use_fast = kwargs.get("use_fast", True)
if use_fast and classes[1] is not None:
attribute_class = classes[1]
else:
attribute_class = classes[0]
else:
attribute_class = cls.get_possibly_dynamic_module(class_name)
if attribute_class.__name__ == "AutoImageProcessor":
from .auto_image_preprocessor import AutoImageProcessorWithVacc
attribute_class = AutoImageProcessorWithVacc
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
return args
class Qwen2VLProcessorWithVacc(Qwen2VLProcessor):
def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
@classmethod
def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Identify and instantiate the subcomponents of Processor classes, like image processors and
tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
will be unable to find the relevant subcomponent class and will raise an error.
"""
args = []
for attribute_name in cls.attributes:
class_name = getattr(cls, f"{attribute_name}_class")
if isinstance(class_name, tuple):
classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
if attribute_name == "image_processor":
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
use_fast = kwargs.get("use_fast")
if use_fast is None:
print(
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
"`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
)
else:
use_fast = kwargs.get("use_fast", True)
if use_fast and classes[1] is not None:
attribute_class = classes[1]
else:
attribute_class = classes[0]
else:
attribute_class = cls.get_possibly_dynamic_module(class_name)
if attribute_class.__name__ == "AutoImageProcessor":
from .auto_image_preprocessor import AutoImageProcessorWithVacc
attribute_class = AutoImageProcessorWithVacc
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
return args

View File

@@ -0,0 +1,235 @@
import torch
import torch.nn as nn
from typing import List
class VaccHugeMemoryAllocator(nn.Module):
# self._active_bytes is the number of bytes actually in use by live tensors;
# use it to slice the source buffer.
# Never free self._src_buffer_array: the source buffer is sized for the worst case.
# self._block_bytes is the maximum size of one partition of the source buffer.
def __init__(self, blocks, dtype = torch.bfloat16, use_contiguous = False):
super().__init__()  # initialize nn.Module state before assigning attributes
self._total_blocks = blocks
self._dtype = dtype
self._enable = False
self._src_buffer_array = None
self._block_bytes = 0
self._max_tokens = 0
self._hiddens = 0
self._active_bytes = 0
self._use_contiguous_buffer = use_contiguous
# max tokens for the dynamic buffer; the dynamic buffer is usually
# larger than a normal one
self._dynamic_max_tokens = 0
self._dynamic_block_bytes = 0
# allocate the maximum-size buffer once; it is never freed
def init_buffers(self, max_tokens, hiddens):
self._max_tokens = max_tokens
self._hiddens = hiddens
try:
import torch_vacc
self._block_bytes = self._max_tokens * self._hiddens \
* self.get_dtype_bytes(self._dtype)
self._all_bytes = self._block_bytes * self._total_blocks
if self._use_contiguous_buffer:
# allocate a single [N*3]-sized byte buffer in one shot
self._src_buffer = torch.zeros(self._all_bytes,
dtype = torch.uint8,
device = "vacc")
tmp_buffer_array = self._src_buffer.view(self._total_blocks, -1)
self._src_buffer_array = [tmp_buffer_array[i]
for i in range(self._total_blocks)]
else:
# allocate 3 separate [N]-sized byte buffers in one shot
self._src_buffer_array = [torch.zeros(self.block_bytes,
dtype = torch.uint8,
device = "vacc")
for i in range(self._total_blocks)]
self._enable = True
except Exception as e:
print(f"vacc huge buffer alloc fail: {e}")
# Designed for networks that need a dynamic buffer; the dynamic buffer may be somewhat larger than the regular max_tokens buffers
def init_buffers_with_dynamic(self, max_tokens, dynamic_tokens, hiddens, dynamic_buffers_mask: List):
self._max_tokens = max_tokens
self._dynamic_max_tokens = dynamic_tokens
self._hiddens = hiddens
dynamic_buffers_count = sum(dynamic_buffers_mask)
normal_buffers_count = self._total_blocks - dynamic_buffers_count
self._block_bytes = self._max_tokens * self._hiddens \
* self.get_dtype_bytes(self._dtype)
self._dynamic_block_bytes = self._dynamic_max_tokens * self._hiddens \
* self.get_dtype_bytes(self._dtype)
self._all_bytes = self._block_bytes * normal_buffers_count + \
self._dynamic_block_bytes * dynamic_buffers_count
# print("创建重复利用buffer: dynamic的数量->", dynamic_buffers_count,
# " 正常的数量->", normal_buffers_count,
# " dynamic buffer大小->", self._dynamic_block_bytes,
# " 正常大小->", self._block_bytes)
try:
assert self._use_contiguous_buffer is False, "dynamic recycling buffers currently support only separate (non-contiguous) buffers"
self._src_buffer_array = []
for i in range(self._total_blocks):
if dynamic_buffers_mask[i]:
self._src_buffer_array.append(torch.zeros(self._dynamic_block_bytes,
dtype = torch.uint8,
device = "vacc"))
else:
# allocate one [N]-sized byte buffer per block
self._src_buffer_array.append(torch.zeros(self._block_bytes,
dtype = torch.uint8,
device = "vacc"))
self._enable = True
except Exception as e:
print("vacc huge buffer alloc fail.", e)
# Slice buffers out of the source buffer (48K tokens * blocks) for the target tensors.
# Which tensors can be recycled must be analysed per model (e.g. deepseek has
# several such buffers).
# Call this with the real input token count whenever a new request arrives.
# Note: mind the dtype - the underlying buffer is created as uint8.
def alloc_memory_buffers(self, tokens, dtype=torch.bfloat16):
if tokens > self._max_tokens:
print("alloc memory buffer fail, tokens is large than max_tokens.", self._max_tokens)
return None
self._active_bytes = tokens * self._hiddens * self.get_dtype_bytes(dtype)
return [sub_array[:self._active_bytes]
for sub_array in self._src_buffer_array]
@property
def memory_buffers(self):
return self._src_buffer_array
# alloc 1/N buffers
# @param part: total number of partitions
# @param return_buffer_list: indices of the partitions to return; empty means all
# Splits the contiguous buffer into `part` equal slices
# and returns the requested slice(s).
def alloc_1_div_N_buffers(self, part = 2,
return_buffer_list = None):
if return_buffer_list is None:
return_buffer_list = []
if not hasattr(self, "_src_buffer"):
print("the 1/N allocator requires a contiguous buffer")
return None
assert isinstance(return_buffer_list, list), "return_buffer_list must be a list"
# split the data into `part` slices, viewed as int8
tmp_buffer_array = self._src_buffer.view(part, -1)
# if return_buffer_list is empty, return every partition
if len(return_buffer_list) == 0:
return [tmp_buffer_array[i] for i in range(part)]
return [tmp_buffer_array[i] for i in return_buffer_list]
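# Example (illustrative): take the 4th of six equal slices of the contiguous
# buffer, e.g. for an MTP layer input:
#   mtp_input_buffer = allocator.alloc_1_div_N_buffers(part=6, return_buffer_list=[3])[0]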
def get_dtype_bytes(self, dtype):
if isinstance(dtype, torch.dtype):
if dtype in [torch.float16, torch.bfloat16, torch.half]:
return 2
elif dtype in [torch.float32, torch.float, torch.int32]:
return 4
elif dtype in [torch.float64, torch.double, torch.int64]:
return 8
elif dtype in [torch.int8, torch.uint8, torch.bool]:
return 1
else:
return 1
elif dtype == int:
return 8
elif dtype == float:
return 8
elif dtype == bool:
return 1
return 0
@property
def enable(self):
return self._enable
@property
def blocks(self):
return self._total_blocks
@property
def max_tokens(self):
return self._max_tokens
@property
def hiddens(self):
return self._hiddens
@property
def active_bytes(self):
return self._active_bytes
@property
def block_bytes(self):
return self._block_bytes
@property
def dynamic_block_bytes(self):
return self._dynamic_block_bytes
class LLMMemoryRecycler:
def __init__(self):
self.count = 3
self.embedding_output = None
self.moe_shared_mlp_output = None
self.mla_oproj_output = None
#self.moe_expert_output = None
def clear(self):
self.embedding_output = None
self.moe_shared_mlp_output = None
self.mla_oproj_output = None
#self.moe_expert_output = None
@property
def EMBEDDING_OUT_BUFFER(self):
return self.embedding_output
@property
def MOE_SHARED_MLP_OUT_BUFFER(self):
return self.moe_shared_mlp_output
@property
def MLA_OPROJ_OUT_BUFFER(self):
return self.mla_oproj_output
def alloc_memory_recycler_llm(tokens,
alloctor:VaccHugeMemoryAllocator,
recycler:LLMMemoryRecycler,
dtype:torch.dtype = torch.bfloat16):
if not alloctor.enable:
print("memory allocator is not initialized.")
return False
recycler.clear()
out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
if out_buffers is None:
print("llm memory recycler buffer allocation failed; disabling it")
return False
if len(out_buffers) != recycler.count:
print("memory recycler buffer count does not match the llm recycler.")
return False
recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
#recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
return True
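# Minimal usage sketch (illustrative sizes; the `allocator` name is ours):
#   allocator = VaccHugeMemoryAllocator(blocks=3)
#   allocator.init_buffers(max_tokens=48 * 1024, hiddens=7168)
#   recycler = LLMMemoryRecycler()
#   if alloc_memory_recycler_llm(4096, allocator, recycler):
#       hidden = recycler.EMBEDDING_OUT_BUFFER  # (4096, 7168) bf16 view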

View File

@@ -0,0 +1,116 @@
import torch
from .allocator import VaccHugeMemoryAllocator,LLMMemoryRecycler
'''
DeepSeek supports 48K input tokens; there are 3 recyclable buffers:
1. parallel_embedding output buffer
2. mla_oproject output buffer
3. moe_shared_mlp output buffer
# 4. moe_expert output buffer
Each buffer is 48K * 7168 * 2 bytes.
'''
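# Size check (using the numbers quoted above): one buffer is
#   48 * 1024 tokens * 7168 hidden * 2 bytes = 704,643,072 bytes = 672 MiB,
# so the three recyclable buffers pin roughly 2 GiB of device memory in total.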
class DeepseekV3MemoryRecycler(LLMMemoryRecycler):
def __init__(self):
super().__init__()
#self.moe_expert_output = None
# @property
# def MOE_EXPERT_OUT_BUFFER(self):
# return self.moe_expert_output
class DeepseekMTPMemoryRecycler(LLMMemoryRecycler):
def __init__(self):
super().__init__()
self.dynamic_output = None
self.deepseek_mtp_layer_input = None
@property
def DYNAMIC_OUTPUT_BUFFER(self):
return self.dynamic_output
@property
def DEEPSEEK_MTP_LAYER_INPUT(self):
return self.deepseek_mtp_layer_input
def alloc_memory_recycler_deepseek_v3(tokens,
alloctor:VaccHugeMemoryAllocator,
recycler:DeepseekV3MemoryRecycler,
dtype:torch.dtype = torch.bfloat16):
if not alloctor.enable:
print("memory allocator is not initialized.")
return False
recycler.clear()
out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
if out_buffers is None:
print("deepseek_v3 memory recycler buffer allocation failed; disabling it")
return False
if len(out_buffers) != recycler.count:
print("memory recycler buffer count does not match the deepseek_v3 recycler.")
return False
recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
#recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
return True
def alloc_memory_recycler_deepseek_mtp(tokens,
alloctor:VaccHugeMemoryAllocator,
world_size:int,
recycler:DeepseekMTPMemoryRecycler,
dtype:torch.dtype = torch.bfloat16):
if not alloctor.enable:
print("memory allocator is not initialized.")
return False
recycler.clear()
out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
if out_buffers is None:
print("deepseek_mtp memory recycler buffer allocation failed; disabling it")
return False
if len(out_buffers) != recycler.count:
print("memory recycler buffer count does not match the deepseek_mtp recycler.")
return False
# MTP memory layout:
# 1. decoder_layer in the deepseek main model and the draft model
# a. embedding_output
# b. mla_oproj_output
# c. moe_shared_mlp_output
# Rationale: the moe output is re-packed once as previous_hidden_states, and the
# re-packed buffer is placed in buffer0 [a.]. Putting embedding_output first would
# also work, but all related addresses would then need a constant offset; for
# clarity, embedding stays last and never takes part in buffer re-partitioning.
# 2. deepseek_mtp draft model (this strategy is not enabled)
# a. dynamic_output takes 1/2
# b. mtp_input takes 1/6, located right after the dynamic_output buffer
# dynamic_buffer = alloctor.alloc_1_div_N_buffers(2, [0,])
# mtp_input_buffer = alloctor.alloc_1_div_N_buffers(6, [3,])
recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
# 1/2 is freely allocated for broadcast; MTP broadcasts previous_hidden_states
memory_buffers = alloctor.memory_buffers
recycler.dynamic_output = memory_buffers[1] # shared with mla_oproject
# 1/6 is used for the MTP decoder-layer input; the op is special-cased so only 1/tp of the output buffer is needed
# recycler.deepseek_mtp_layer_input = mtp_input_buffer[0]
mtp_layer_input_dims = alloctor.hiddens * 2 // world_size
mtp_layer_input_numels = tokens * mtp_layer_input_dims
from vllm import envs
if envs.VLLM_USE_V1:
# v1 does not need to re-cache previous_hidden_states, so the moe buffer is still in use; stage the MTP preprocessing in the attention o-buffer first
recycler.deepseek_mtp_layer_input = memory_buffers[1].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
else:
# shared with the moe output
recycler.deepseek_mtp_layer_input = memory_buffers[2].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
return True
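# Worked example (illustrative): with hiddens=7168 and world_size=16,
# mtp_layer_input_dims = 7168 * 2 // 16 = 896, so for 4096 tokens the MTP layer
# input is a 4096 x 896 bf16 view carved out of an already-allocated buffer.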

View File

@@ -0,0 +1,132 @@
import os
import torch
from .allocator import VaccHugeMemoryAllocator
from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler
VLLM_MODEL_MODE = os.environ.get("VLLM_MODEL_MODE", "deepseek")
huge_memory_alloctor = None
memory_recycler = None
# call this whenever a new request arrives
def alloc_memory_recycler(tokens, dtype=torch.bfloat16, **argv):
global huge_memory_alloctor
global memory_recycler
vllm_model = argv.get('vllm_model')
if vllm_model is None:
print("model infos is empty, now using VLLM_MODEL_MODE")
vllm_model = VLLM_MODEL_MODE
# TODO: use default memory-recycle schedule
# if vllm_model in ['xxx']:
# vllm_model = "llm_default"
memory_recycler = None
if vllm_model == "deepseek":
from .deepseek_v3_memory_recycler import DeepseekV3MemoryRecycler, alloc_memory_recycler_deepseek_v3
memory_recycler = DeepseekV3MemoryRecycler()
state = alloc_memory_recycler_deepseek_v3(tokens, huge_memory_alloctor, memory_recycler, dtype)
if not state:
del memory_recycler
memory_recycler = None
return state
if vllm_model == "deepseek_mtp":
from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler, alloc_memory_recycler_deepseek_mtp
memory_recycler = DeepseekMTPMemoryRecycler()
if argv.get('world_size') is None:
print("mtp should have TP world size, memory recycler allock fail")
return False
state = alloc_memory_recycler_deepseek_mtp(tokens, huge_memory_alloctor, argv['world_size'], memory_recycler, dtype)
if not state:
del memory_recycler
memory_recycler = None
return state
if vllm_model == "qwen3_moe":
from .qwen3_moe_memory_recycler import QWen3MoeMemoryRecycler, alloc_memory_recycler_qwen3_moe
memory_recycler = QWen3MoeMemoryRecycler()
state = alloc_memory_recycler_qwen3_moe(tokens, huge_memory_alloctor, memory_recycler, dtype)
if not state:
del memory_recycler
memory_recycler = None
return state
if vllm_model == "llm_default":
from .allocator import LLMMemoryRecycler, alloc_memory_recycler_llm
memory_recycler = LLMMemoryRecycler()
state = alloc_memory_recycler_llm(tokens, huge_memory_alloctor, memory_recycler, dtype)
if not state:
del memory_recycler
memory_recycler = None
return False
# Under the LLM pipeline-parallel scheme, stages other than stage 0 can also
# reuse memory when receiving hiddens/residual from the PREVIOUS stage.
# This path is independent of the llm forward pass and is maintained separately:
# hiddens maps to the moe/mlp buffer used during llm forward,
# residual maps to the embedding buffer used during llm forward.
def alloc_pipeline_parallel_recycler_buffer(size:torch.Size, dtype:torch.dtype, key:str):
global huge_memory_alloctor
if huge_memory_alloctor is None:
return None
intermediate_tensor_dict = {
"hidden_states": 2,
"attention": 1,
"residual": 0
}
if key not in intermediate_tensor_dict:
return None
src_tensors = huge_memory_alloctor.memory_buffers[intermediate_tensor_dict[key]]
all_bytes = size.numel() * huge_memory_alloctor.get_dtype_bytes(dtype)
return src_tensors[:all_bytes].view(dtype).view(size)
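# Illustrative call (hypothetical sizes): a non-first pipeline stage reusing the
# embedding buffer to receive `residual` from the previous stage:
#   buf = alloc_pipeline_parallel_recycler_buffer(
#       torch.Size([4096, 7168]), torch.bfloat16, "residual")
#   # buf is None if the allocator is uninitialized or the key is unknown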
# call this once per worker at server startup
def init_huge_memory_allocator(max_tokens, hidden_size, vllm_model = None):
global huge_memory_alloctor
if vllm_model is None:
print("model infos is empty, now using VLLM_MODEL_MODE")
vllm_model = VLLM_MODEL_MODE
if huge_memory_alloctor is not None:
del huge_memory_alloctor
torch.vacc.empty_cache()
huge_memory_alloctor = None
if vllm_model == "deepseek":
huge_memory_alloctor = VaccHugeMemoryAllocator(3)
huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
return True
# deepseek_mtp buffer recycling (separate buffers, one of them dynamic):
# buffer[0]: normal_buffer -> embedding_output
# buffer[1]: dynamic_buffer -> mla_oproj_output, dynamic_output
# buffer[2]: normal_buffer -> moe_shared_mlp_output, deepseek_mtp_layer_input
if vllm_model == "deepseek_mtp":
# the last block of the deepseek dynamic tokens is reserved for mtp weights
# dynamic_buffer_max_tokens = max_tokens + 128; MTP is limited to 48K for now
dynamic_buffer_max_tokens = max_tokens
# the dynamic buffer spends 128 extra tokens on broadcasting
# positions, input tokens
deepseek_mtp_max_tokens = max_tokens
huge_memory_alloctor = VaccHugeMemoryAllocator(3)
# huge_memory_alloctor.init_buffers(max_tokens, 7168)
huge_memory_alloctor.init_buffers_with_dynamic(deepseek_mtp_max_tokens, dynamic_buffer_max_tokens, hidden_size, [False, True, False])
return True
if vllm_model == "qwen3_moe":
huge_memory_alloctor = VaccHugeMemoryAllocator(3)
huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
return True
return False
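# Startup sketch (illustrative sizes): every worker calls this once before
# serving, as load_weights does with LLM_MAX_PREFILL_SEQ_LEN:
#   if not init_huge_memory_allocator(48 * 1024, 7168, vllm_model="deepseek"):
#       pass  # recycling stays disabled; regular allocation is used instead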

View File

@@ -0,0 +1,38 @@
import torch
from .allocator import VaccHugeMemoryAllocator, LLMMemoryRecycler
'''
QWen3-Moe supports 56K input tokens; there are 3 recyclable buffers:
1. parallel_embedding output buffer
2. mla_oproject output buffer
3. moe_shared_mlp output buffer
'''
class QWen3MoeMemoryRecycler(LLMMemoryRecycler):
def __init__(self):
super().__init__()
def alloc_memory_recycler_qwen3_moe(tokens,
alloctor:VaccHugeMemoryAllocator,
recycler:QWen3MoeMemoryRecycler,
dtype:torch.dtype = torch.bfloat16):
if not alloctor.enable:
print("memory allocator is not initialized.")
return False
recycler.clear()
out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
if out_buffers is None:
print("qwen3_moe memory recycler buffer allocation failed; disabling it")
return False
if len(out_buffers) != recycler.count:
print("memory recycler buffer count does not match the qwen3_moe recycler.")
return False
recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
return True

File diff suppressed because it is too large

View File

@@ -0,0 +1,33 @@
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vacc_tools.trace_logger import get_trace_api
trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
get_trace_api("deepseek")
)
logger = init_logger(__name__)
class Qwen2_5_VisionAttention(nn.Module):
# @trace_time('Qwen2_5_VisionAttention_vacc_split_qkv')
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
q1, k1, v1 = qkv.chunk(3, dim=-1)
q1, k1, v1 = (x.view(*new_shape) for x in (q1, k1, v1))
return q1, k1, v1
class Qwen2_5_VisionPatchEmbed(nn.Module):
# convert conv3d to matmul
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.matmul(x, self.proj.weight.view(self.hidden_size, -1).T)
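# Sanity sketch for the conv3d-to-matmul rewrite above: a patch-embed Conv3d
# whose kernel_size equals its stride is a matmul over flattened patches
# (toy sizes below are illustrative, not the real model config).
def _conv3d_as_matmul_check():
    hidden, c, t, p = 64, 3, 2, 14
    conv = nn.Conv3d(c, hidden, kernel_size=(t, p, p), stride=(t, p, p),
                     bias=False)
    x = torch.randn(5, c * t * p * p)  # five flattened patches
    ref = conv(x.view(5, c, t, p, p)).view(5, hidden)
    out = torch.matmul(x, conv.weight.view(hidden, -1).T)
    assert torch.allclose(ref, out, atol=1e-4)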

View File

@@ -0,0 +1,285 @@
"""Inference-only Qwen2VL model compatible with HuggingFace weights."""
import torch
import torch.nn as nn
from typing import Optional
from vllm.attention.layer import check_upstream_fa_availability
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention as Qwen2VisionAttentionOrg
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend
from vllm.logger import init_logger
from .hf_processor.qwenvl_processor import Qwen2VLProcessorWithVacc
from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm_vacc.vllm.model_executor.models.vars import USE_FUSED_QWEN_ATTENTION
logger = init_logger(__name__)
class Qwen2VisionPatchEmbed(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self.proj, 'bias') and self.proj.bias is not None:
return torch.nn.functional.linear(x, self.proj.weight.view(self.hidden_size, -1), self.proj.bias)
return torch.matmul(x, self.proj.weight.view(self.embed_dim, -1).T)
class Qwen2VLProcessingInfo():
def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessorWithVacc:
return self.ctx.get_hf_processor(
Qwen2VLProcessorWithVacc,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFastWithVacc:
return self.get_hf_processor(**kwargs).image_processor
import torch.nn.functional as F
class Qwen2VisionTransformer(nn.Module):
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
if USE_FUSED_QWEN_ATTENTION:
try:
from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl
sin_cache, cos_cache = rot_pos_emb_qwenvl(grid_thw, self.embed_dim, self.num_heads, self.spatial_merge_size, self.dtype, self.device)
except Exception as e:
logger.error(f"rot_pos_emb fused ops run fail, e:{e}")
rotary_pos_emb = None
else:
rotary_pos_emb = self.rot_pos_emb(grid_thw)
sin_cache, cos_cache = None, None
# tmp_rotary_pos_emb = self.transformer_rot_pos_emb(grid_thw)
# qwen3_rotary_pos_emb = self.qwen3_rot_pos_emb(grid_thw)
# compute cu_seqlens
grid_thw_ = torch.tensor(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
grid_thw_[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
if USE_FUSED_QWEN_ATTENTION:
cu_seqlens = cu_seqlens.tolist()
max_seqlen, seqlens = None, None
else:
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
sin_cache=sin_cache,
cos_cache=cos_cache,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
# adapter
x = self.merger(x)
return x
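# Worked sketch of the cu_seqlens computation in forward above: for
# grid_thw = [[1, 4, 6], [2, 2, 2]] each frame contributes h*w patches,
# repeated t times, and the cumulative sum is front-padded with a zero.
def _cu_seqlens_example():
    grid = torch.tensor([[1, 4, 6], [2, 2, 2]])
    cu = torch.repeat_interleave(grid[:, 1] * grid[:, 2],
                                 grid[:, 0]).cumsum(dim=0, dtype=torch.int32)
    cu = F.pad(cu, (1, 0), "constant", 0)
    assert cu.tolist() == [0, 24, 28, 32]  # one frame of 24, two frames of 4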
class Qwen2VisionAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
projection_size: int,
quant_config: Optional["QuantizationConfig"] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super(Qwen2VisionAttentionOrg, self).__init__()
# Per attention head and per partition values.
self.tp_size = (1 if use_data_parallel else
parallel_state.get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)
# self.qkv = ColumnParallelLinear(input_size=embed_dim,
# output_size=3 * projection_size,
# quant_config=quant_config,
# prefix=f"{prefix}.qkv",
# disable_tp=use_data_parallel)
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
disable_tp=use_data_parallel)
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now.")
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
q1, k1, v1 = qkv.chunk(3, dim=-1)
q1, k1, v1 = (x.view(*new_shape) for x in (q1, k1, v1))
return q1, k1, v1
class Qwen2VisionBlock(nn.Module):
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
sin_cache: torch.Tensor,
cos_cache: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
if USE_FUSED_QWEN_ATTENTION:
total_bytes = x.numel() * x.element_size() * get_tp_group().world_size
reduce_result = get_tp_group().world_size > 1 and total_bytes < 4194304
# hidden_states = self.norm1(x)
attn_outs = torch.vacc.fuse_atten_vit(
hidden_states=x.view(-1, x.shape[-1]),
hidden_states_norm_weight = self.norm1.weight,
hidden_states_norm_bias = self.norm1.bias,
# hidden_states_norm_weight = torch.Tensor(),
# hidden_states_norm_bias = torch.Tensor(),
qkv_proj_weight=self.attn.qkv.weight,
qkv_proj_bias=self.attn.qkv.bias,
sin_cache=sin_cache,
cos_cache=cos_cache,
o_proj_weight=self.attn.proj.weight,
o_proj_bias=self.attn.proj.bias if self.attn.proj.tp_rank == 0 else torch.Tensor(),
seq_lens=cu_seqlens,
sm_scale=-1,
num_attention_heads=self.attn.num_attention_heads_per_partition * get_tp_group().world_size,
flash_attention=True,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos
)
attn_out = attn_outs[0] if reduce_result else tensor_model_parallel_all_reduce(attn_outs[0])
attn_out = attn_out.view(x.shape)
x = x + attn_out
else:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
return x
class Qwen2VisionMLP():
def forward(self, x: torch.Tensor):
try:
from torch_vacc.vacc import fuse_mlp_vision
hiddens_shape = x.shape
tp_rank_id = get_tp_group().rank_in_group
fc2_bias = None if tp_rank_id > 0 else self.fc2.bias
hidden_states = fuse_mlp_vision(x.view(-1, hiddens_shape[-1]),
self.fc1.weight, # nk
self.fc2.weight, # nk
self.fc1.bias,
fc2_bias,
2) # 0 is gelu, 1 is relu, 2 is quick_gelu
vacc_res = tensor_model_parallel_all_reduce(hidden_states).view(hiddens_shape)
return vacc_res
except Exception as e:
logger.error(f"mlp fused ops run fail, e:{e}")
x_parallel, _ = self.fc1(x)
x_parallel = self.act(x_parallel)
x, _ = self.fc2(x_parallel)
return x
class Qwen2VisionPatchMerger():
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x)
x = x.view(-1, self.hidden_size)
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
try:
from torch_vacc.vacc import patch_merger_vision
tp_rank_id = get_tp_group().rank_in_group
fc2_bias = None if tp_rank_id > 0 else mlp_fc2.bias
hidden_states = patch_merger_vision(x,
mlp_fc1.weight,
mlp_fc2.weight,
mlp_fc1.bias,
fc2_bias,
0) #0 is gelu, 1 is silu
vacc_res = tensor_model_parallel_all_reduce(hidden_states)
return vacc_res
except Exception as e:
logger.error(f"merge patch fused vision mlp run fail, cased by:{e}")
x_parallel, _ = mlp_fc1(x)
x_parallel = mlp_act(x_parallel)
out, _ = mlp_fc2(x_parallel)
return out

View File

@@ -0,0 +1,194 @@
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union, Any, Dict
import torch
from torch import nn
from vllm.logger import init_logger
from .vars import *
from vllm.model_executor.layers.linear import UnquantizedLinearMethod as UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
# unify the param names across the different quantize methods
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
if isinstance(quant_method, UnquantizedLinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = torch.Tensor()
fused_params[name + '_bias'] = None
fused_params[name + '_qzeros'] = None
elif isinstance(quant_method, Fp8LinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = layer.weight_scale_inv
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
elif isinstance(quant_method, GPTQLinearMethod):
fused_params[name + '_weight'] = layer.qweight
fused_params[name + '_weight_scale'] = layer.scales
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
elif isinstance(quant_method, AWQLinearMethod):
fused_params[name + '_weight'] = layer.qweight
fused_params[name + '_weight_scale'] = layer.scales
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
else:
raise ValueError(f"Unsupported quant_method: {quant_method}")
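# Usage sketch (mirrors the call sites in Qwen3DecoderLayer.forward below; a
# bare nn.Linear stands in for the real parallel layer, and
# UnquantizedLinearMethod is assumed to be constructible with no arguments):
def _fused_params_demo():
    layer = nn.Linear(16, 48, bias=False)
    fused: Dict[str, Any] = {}
    set_fused_params(fused, UnquantizedLinearMethod(), layer, 'qkv_proj')
    # keys: 'qkv_proj_weight', 'qkv_proj_weight_scale', 'qkv_proj_bias',
    # 'qkv_proj_qzeros'
    assert fused['qkv_proj_weight'] is layer.weight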
class Qwen3Attention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None # new added params
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
# reshape kvcache
num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)
if isinstance(attn_metadata_all, dict):
attn_metadata = next(iter(attn_metadata_all.values()))
is_decode = attn_metadata.prefill_metadata is None
else:
is_decode = attn_metadata_all.prefill_metadata is None
attn_metadata = attn_metadata_all
reduce_result = is_decode
# total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
# # only support 4M now
# if total_bytes < 4194304:
# reduce_result = True
if USE_FUSED_QWEN_ATTENTION:
if is_decode:
positions = [i - 1 for i in attn_metadata.seq_lens]
cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
else:
cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
if residual is None:
res_out = hidden_states
#from torch_vacc.vacc import fuse_atten_qwen3
attn_outs = None
if not is_decode:
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
if memory_recycler is not None:
attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER
total_num_kv_heads = self.total_num_kv_heads
if self.total_num_kv_heads < get_tp_group().world_size:
assert get_tp_group().world_size % self.total_num_kv_heads == 0
total_num_kv_heads = get_tp_group().world_size
attn_outs = torch.vacc.fuse_atten_qwen3(
# attn_outs = vacc_fused_attn_qwen3_naive(
hidden_states=hidden_states,
residual=residual,
hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
qkv_proj_weight=self.fused_params['qkv_proj_weight'],
qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
qkv_proj_bias=self.fused_params['qkv_proj_bias'],
qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
q_layernorm_weight=self.fused_params['q_norm_weight'],
k_layernorm_weight=self.fused_params['k_norm_weight'],
sin_cache=sin_cache,
cos_cache=cos_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache=kv_cache,
block_tables=attn_metadata.block_tables, # tensor
block_group_size=env_blk_grp_size,
o_proj_weight=self.fused_params['o_proj_weight'],
o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
o_proj_bias=self.fused_params['o_proj_bias'],
o_proj_qzeros=self.fused_params['o_proj_qzeros'],
seq_lens=attn_metadata.seq_lens,
sm_scale=self.scaling,
num_attention_heads=self.total_num_heads,
num_key_value_heads=total_num_kv_heads,
flash_attention=is_decode, # decode uses flash attention by default
is_decode=is_decode,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos,
output_opt=attn_outs,
res_opt=residual)
if residual is None:
attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
else:
res_out = attn_outs[1]
attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
return attn_out, res_out
else:
# orig code
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Qwen3DecoderLayer(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
# NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
if USE_FUSED_QWEN_ATTENTION:
if not hasattr(self.self_attn, "fused_params"):
self.self_attn.fused_params = {}
self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')
hidden_states, residual = self.self_attn(
positions=positions,
hidden_states=hidden_states,
residual=residual)
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual

View File

@@ -0,0 +1,790 @@
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List
import itertools
import os
import torch
from torch import nn
from transformers import PretrainedConfig
from torch_vacc.vacc.custom_ops_cpu import (
w8a8_block_fp8_linear as w8a8_block_fp8_linear_cpu,
)
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock, Qwen3MoeMLP
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from ..ops.mrope_op import get_sin_cos_mrope
from ..ops.qwen3_fused_moe import vacc_fused_prefill_moe_fp8, vacc_fused_decode_moe_fp8, recompute_moe_layer_blocksize
from .vars import *
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
# unify the param names across the different quantize methods
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
if isinstance(quant_method, UnquantizedLinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = None
fused_params[name + '_bias'] = None
fused_params[name + '_qzeros'] = None
elif isinstance(quant_method, Fp8LinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = layer.weight_scale_inv
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
elif isinstance(quant_method, GPTQLinearMethod):
fused_params[name + '_weight'] = layer.qweight
fused_params[name + '_weight_scale'] = layer.scales
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
elif isinstance(quant_method, AWQLinearMethod):
fused_params[name + '_weight'] = layer.qweight
fused_params[name + '_weight_scale'] = layer.scales
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
else:
raise ValueError(f"Unsupported quant_method: {quant_method}")
def apply_w8a8_block_fp8_linear_v2(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_scale = None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
block_size = [
weight.shape[-2] // weight_scale.shape[-2],
weight.shape[-1] // weight_scale.shape[-1],
]
if input.device.type == "vacc":
output = torch.vacc.w8a8_block_fp8_linear(
input_2d, weight, input_scale, weight_scale, block_size
)
else:
output = w8a8_block_fp8_linear_cpu(
input_2d, weight, input_scale, weight_scale, block_size
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
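# Reference sketch of the block-quantized linear above: every
# [block_n, block_k] tile of `weight` shares one scale (block_size is exactly
# the weight/scale shape ratio), so a dequantize-then-matmul reproduces what
# the fused kernel computes (toy fp32 tensors stand in for fp8 here).
def _block_quant_linear_reference():
    n, k, bn, bk = 8, 16, 4, 8
    w = torch.randn(n, k)
    scale = torch.rand(n // bn, k // bk) + 0.5
    block_size = [w.shape[-2] // scale.shape[-2], w.shape[-1] // scale.shape[-1]]
    assert block_size == [bn, bk]
    # expand each scale over its [bn, bk] tile, then run a plain matmul
    s_full = scale.repeat_interleave(bn, 0).repeat_interleave(bk, 1)
    x = torch.randn(3, k)
    out = x @ (w * s_full).T  # dequantized reference
    assert out.shape == (3, n)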
def vacc_fused_attn_qwen3_naive(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
hidden_states_norm_weight: torch.Tensor,
qkv_proj_weight: torch.Tensor,
qkv_proj_weight_scale: torch.Tensor,
qkv_proj_bias: Optional[torch.Tensor],
qkv_proj_qzeros: Optional[torch.Tensor],
q_layernorm_weight: torch.Tensor,
k_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
o_proj_weight: torch.Tensor,
o_proj_weight_scale: torch.Tensor,
o_proj_bias: Optional[torch.Tensor],
o_proj_qzeros: Optional[torch.Tensor],
seq_lens: List[int],
sm_scale: float,
num_attention_heads: int,
num_key_value_heads: int,
flash_attention: bool,
is_decode: bool,
reduce_result: bool,
world_size: int,
rank: int,
group_id: int,
dev_info: Union[List[int], Tuple[int, ...]],
block_size: int = 16
):
if residual is not None:
hidden_states = hidden_states + residual
residual_out = hidden_states
hidden_states = torch.vacc.rms_norm(
hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)
# NOTE: for qwen3 and qwen2.5, head_dim is always 128
head_dim = 128
# qkv gen
qkv = apply_w8a8_block_fp8_linear_v2(
input=hidden_states,
weight=qkv_proj_weight,
weight_scale=qkv_proj_weight_scale)
num_q_heads = num_attention_heads // world_size
num_kv_heads = num_key_value_heads // world_size
q_size = head_dim * num_q_heads
kv_size = head_dim * num_kv_heads
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
# q_by_head = self.q_norm.forward_native(q_by_head)
q_norm = torch.vacc.rms_norm(q_by_head, q_layernorm_weight, 1e-6)
# q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
# k_by_head = k_norm.forward_native(k_by_head)
k_norm = torch.vacc.rms_norm(k_by_head, k_layernorm_weight, 1e-6)
# k = k_by_head.view(k.shape)
v = v.view(-1, num_kv_heads, head_dim)
# q, k = self.rotary_emb(positions, q, k)
start = 0
attn_outs = []
if is_decode:
# convert block_tables to 8K group index
block_per_group = block_group_size // block_size
block_tables = (block_tables // block_per_group).to(torch.int32)
# logger.warning(f"decode block table: {block_tables}")
num_blocks = kv_cache.shape[1]
key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)
# bs loop
for i in range(len(seq_lens)):
if not is_decode:
# prefill
end = start + seq_lens[i]
else:
# decode
end = start + 1
cos = cos_cache[i].unsqueeze(-2)
sin = sin_cache[i].unsqueeze(-2)
q, k = torch.vacc.RotaryPosEmbedding(
q_norm[start : end, ...], k_norm[start : end, ...], cos, sin, 0, "neox")
# cache concat
torch.vacc.reshape_and_cache_attention(k, key_cache_split, slot_mapping[start : end, ...])
torch.vacc.reshape_and_cache_attention(v[start : end, ...], value_cache_split, slot_mapping[start : end, ...])
# attn_output = self.attn(q, k, v)
if not is_decode:
# prefill
attn_out = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v[start : end, ...],
attn_mask = None,
dropout_p = 0.0,
is_causal = True, #causal_attn and not self.need_mask,
is_train = False,
recompute = False,
flash_attention = False,
sm_scale=sm_scale)
else:
# decode
key_cache = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
value_cache = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
k_slices = key_cache[block_tables[i], ...]
k_cached = torch.cat(
[k_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
dim=0,
)
k_cached = k_cached.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
v_slices = value_cache[block_tables[i], ...]
v_cached = torch.cat(
[v_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
dim=0,
)
v_cached = v_cached.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=q,
key=k_cached,
value=v_cached,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,#flash_attention,
sm_scale=sm_scale)
attn_outs.append(attn_out)
# update start
start = end
attn_out = torch.cat(attn_outs, dim=0)
# output, _ = self.o_proj(attn_output)
o_proj = apply_w8a8_block_fp8_linear_v2(
input = attn_out.reshape(hidden_states.shape[0], -1),
weight = o_proj_weight,
weight_scale = o_proj_weight_scale,
)
if reduce_result:
o_proj = tensor_model_parallel_all_reduce(o_proj)
if residual is not None:
return o_proj, residual_out
return o_proj
def Qwen3MoeSparseMoeBlock__init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super(Qwen3MoeSparseMoeBlock, self).__init__()
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE(num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=True,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate")
# patch here to transpose the data layout of w2/w2_scale; only for block quant
if hasattr(self.experts.quant_method, 'quant_config') and hasattr(self.experts.quant_method.quant_config, 'weight_block_size'):
self.experts.w2_weight.data = self.experts.w2_weight.data.transpose(-1,-2).contiguous().transpose(-1,-2)
self.experts.w2_weight_scale_inv.data = self.experts.w2_weight_scale_inv.data.transpose(-1,-2).contiguous().transpose(-1,-2)
if hasattr(self.experts, 'w2_weight_scale_inv_prefill'):
self.experts.w2_weight_scale_inv_prefill.data = self.experts.w2_weight_scale_inv_prefill.data.transpose(-1,-2).contiguous().transpose(-1,-2)
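# What the layout patch above does: transpose -> contiguous -> transpose keeps
# the logical shape but rewrites storage so the second-to-last dim becomes
# unit-stride (the layout the vacc MoE kernels are assumed to expect).
def _w2_layout_trick_demo():
    t = torch.arange(6).reshape(2, 3)
    u = t.transpose(-1, -2).contiguous().transpose(-1, -2)
    assert torch.equal(t, u)     # same values and shape
    assert u.stride() == (1, 2)  # but column-major storage (was (3, 1))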
def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
positions: Union[torch.Tensor, list],
is_decode: bool):
if isinstance(rotary_emb, MRotaryEmbedding):
# get mrope sin/cos
cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
if len(attn_metadata.seq_lens) > 1:
if is_decode:
cos_cache = torch.chunk(cos_cache, len(attn_metadata.seq_lens))
sin_cache = torch.chunk(sin_cache, len(attn_metadata.seq_lens))
else:
cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
else:
cos_cache = [cos_cache]
sin_cache = [sin_cache]
else:
if is_decode:
positions = [i - 1 for i in attn_metadata.seq_lens]
cos_cache = [rotary_emb.cos_cache[i:i+1, ...] for i in positions]
sin_cache = [rotary_emb.sin_cache[i:i+1, ...] for i in positions]
else:
cos_cache = [rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
sin_cache = [rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
return cos_cache, sin_cache
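# Sketch of the decode-path slicing in get_cos_sin_cache above: each running
# request reads the rotary row at position seq_len - 1 (toy cache below).
def _decode_rope_slice_demo():
    cos_cache = torch.arange(12, dtype=torch.float32).view(12, 1)
    seq_lens = [5, 9]
    cos = [cos_cache[i - 1:i, ...] for i in seq_lens]
    assert [c.item() for c in cos] == [4.0, 8.0]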
class Qwen3MoeDecoderLayer(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
cos_cache: list[torch.Tensor],
sin_cache: list[torch.Tensor]
) -> torch.Tensor:
# Self Attention
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
# NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
if USE_FUSED_QWEN_ATTENTION:
if not hasattr(self.self_attn, "fused_params"):
self.self_attn.fused_params = {}
self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')
hidden_states, residual = self.self_attn(
positions=positions,
hidden_states=hidden_states,
residual=residual,
cos_cache=cos_cache,
sin_cache=sin_cache)
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
cos_cache=cos_cache,
sin_cache=sin_cache
)
# # Fully Connected
# hidden_states, residual = self.post_attention_layernorm(
# hidden_states, residual)
# hidden_states = self.mlp(hidden_states)
# return hidden_states, residual
# TODO: handle unquantized and non-block-quant models
if not hasattr(self.mlp.experts.quant_method, 'quant_config') or \
not hasattr(self.mlp.experts.quant_method.quant_config, 'weight_block_size'):
if not isinstance(self.mlp.experts.quant_method, MoeWNA16Method):
logger.warning('TODO: unquantized or other quant methods use the unfused path')
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
if isinstance(attn_metadata, dict):
# is_prefill = get_forward_context().attn_metadata['test'].prefill_metadata
attn_metadata_0 = next(iter(get_forward_context().attn_metadata.values()))
is_prefill = attn_metadata_0.prefill_metadata
else:
is_prefill = get_forward_context().attn_metadata.prefill_metadata
quant_method = self.mlp.experts.quant_method if isinstance(self.mlp, Qwen3MoeSparseMoeBlock) \
else self.mlp.down_proj.quant_method
if is_prefill is not None:
if isinstance(quant_method, MoeWNA16Method):
try:
from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_prefill_moe_gptq_int4
return vacc_fused_prefill_moe_gptq_int4(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
except Exception as e:
print(f'vacc_fused_prefill_moe_gptq_int4 fail: {e}')
else:
recompute_moe_layer_blocksize(self.mlp.experts)
try:
return vacc_fused_prefill_moe_fp8(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
except Exception as e:
print(f'vacc_fused_prefill_moe_fp8 fail: {e}')
else:
if isinstance(quant_method, MoeWNA16Method):
try:
from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_decode_moe_gptq_int4
return vacc_fused_decode_moe_gptq_int4(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
except Exception as e:
print(f'vacc_fused_decode_moe_gptq_int4 fail: {e}')
else:
try:
return vacc_fused_decode_moe_fp8(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
except Exception as e:
print(f'vacc_fused_decode_moe_fp8 fail: {e}')
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen3MoeAttention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None, # new added params
cos_cache: list[torch.Tensor] = None,
sin_cache: list[torch.Tensor] = None,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
# reshape kvcache
num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)
if isinstance(attn_metadata_all, dict):
attn_metadata = next(iter(attn_metadata_all.values()))
is_decode = attn_metadata.prefill_metadata is None
else:
is_decode = attn_metadata_all.prefill_metadata is None
attn_metadata = attn_metadata_all
reduce_result = is_decode
# total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
# # only support 4M now
# if total_bytes < 4194304:
# reduce_result = True
if USE_FUSED_QWEN_ATTENTION:
if cos_cache is None or sin_cache is None:
cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)
if residual is None:
res_out = hidden_states
#from torch_vacc.vacc import fuse_atten_qwen3
attn_outs = None
if not is_decode:
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
if memory_recycler is not None:
attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER
total_num_kv_heads = self.total_num_kv_heads
if self.total_num_kv_heads < get_tp_group().world_size:
assert get_tp_group().world_size % self.total_num_kv_heads == 0
total_num_kv_heads = get_tp_group().world_size
attn_outs = torch.vacc.fuse_atten_qwen3(
# attn_outs = vacc_fused_attn_qwen3_naive(
hidden_states=hidden_states,
residual=residual,
hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
qkv_proj_weight=self.fused_params['qkv_proj_weight'],
qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
qkv_proj_bias=self.fused_params['qkv_proj_bias'],
qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
q_layernorm_weight=self.fused_params['q_norm_weight'],
k_layernorm_weight=self.fused_params['k_norm_weight'],
sin_cache=sin_cache,
cos_cache=cos_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache=kv_cache,
block_tables=attn_metadata.block_tables,
block_group_size=env_blk_grp_size,
o_proj_weight=self.fused_params['o_proj_weight'],
o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
o_proj_bias=self.fused_params['o_proj_bias'],
o_proj_qzeros=self.fused_params['o_proj_qzeros'],
seq_lens=attn_metadata.seq_lens,
sm_scale=self.scaling,
num_attention_heads=self.total_num_heads,
num_key_value_heads=total_num_kv_heads,
flash_attention=is_decode, # decode uses flash attention by default
is_decode=is_decode,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos,
output_opt=attn_outs,
res_opt=residual)
# debug_qwen3_moe_attention_prefill(hidden_states=hidden_states,
# residual=residual,
# attn_outs=attn_outs,
# fused_params=self.fused_params,
# attn_metadata=attn_metadata,
# is_decode=is_decode,
# sin_cache=sin_cache,
# cos_cache=cos_cache,
# kv_cache=kv_cache,
# env_blk_grp_size=env_blk_grp_size,
# scaling=self.scaling,
# total_num_heads=self.total_num_heads,
# total_num_kv_heads=self.total_num_kv_heads,
# world_size=get_tp_group().world_size,
# rank=get_tp_group().rank_in_group,
# group_id=get_tp_group().group_id,
# dev_info=get_tp_group().rank_device_infos)
if residual is None:
attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
else:
res_out = attn_outs[1]
attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
return attn_out, res_out
else:
# orig code
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm.forward_native(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm.forward_native(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Qwen3MoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
if not hasattr(self, "weight_capture"):
from vllm_vacc.vllm.model_executor.models.weight_capture.qwen3_moe_weight_capture import Qwen3Moe_WeightCapture
self.weight_capture = Qwen3Moe_WeightCapture(self.layers, self.start_layer, self.end_layer)
self.layer_nums = self.end_layer - self.start_layer
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
# the fused decoder layer only supports fp8-quantized models for now
use_fused_layer = self.weight_capture.support_fused_weights and USE_DECODER_LAYER_FUSE_MODE
# print('Qwen3MoeModel attn_metadata', attn_metadata)
if isinstance(attn_metadata_all, dict):
# is_decode = attn_metadata_all['test'].prefill_metadata is None
# attn_metadata = attn_metadata_all['test']
attn_metadata = next(iter(attn_metadata_all.values()))
is_decode = attn_metadata.prefill_metadata is None
else:
is_decode = attn_metadata_all.prefill_metadata is None
attn_metadata = attn_metadata_all
if use_fused_layer and is_decode:
from torch_vacc.vacc.custom_ops import qwen3_fuse_attention_moe_decode
layer0 = self.layers[self.start_layer]
cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode=True)
for i in range(0, self.layer_nums):
layer = self.layers[i + self.start_layer]
kv_cache = layer.self_attn.attn.kv_cache[forward_context.virtual_engine]
num_kv_heads = max(1, layer.self_attn.total_num_kv_heads // get_tp_group().world_size)
kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, layer.self_attn.head_dim)
total_num_kv_heads = layer.self_attn.total_num_kv_heads
if layer.self_attn.total_num_kv_heads < get_tp_group().world_size:
assert get_tp_group().world_size % layer.self_attn.total_num_kv_heads == 0
total_num_kv_heads = get_tp_group().world_size
hidden_states, residual = qwen3_fuse_attention_moe_decode(hidden_states, residual,
hidden_states_norm_weight=self.weight_capture.layer_mapper.attn_args._0_input_layernorm_weight[i],
qkv_proj_weight=self.weight_capture.layer_mapper.attn_args._1_qkv_proj_weight[i],
qkv_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._2_qkv_proj_weight_scale[i],
qkv_proj_bias=self.weight_capture.layer_mapper.attn_args._3_qkv_proj_bias[i],
qkv_proj_qzeros=self.weight_capture.layer_mapper.attn_args._4_qkv_proj_qzeros[i],
q_layernorm_weight=self.weight_capture.layer_mapper.attn_args._5_q_norm_weight[i],
k_layernorm_weight=self.weight_capture.layer_mapper.attn_args._6_k_norm_weight[i],
sin_cache=sin_cache,
cos_cache=cos_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache=kv_cache,
block_tables=attn_metadata.block_tables,
block_group_size=env_blk_grp_size,
o_proj_weight=self.weight_capture.layer_mapper.attn_args._13_o_proj_weight[i],
o_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._14_o_proj_weight_scale[i],
o_proj_bias=self.weight_capture.layer_mapper.attn_args._15_o_proj_bias[i],
o_proj_qzeros=self.weight_capture.layer_mapper.attn_args._16_o_proj_qzeros[i],
seq_lens_num=attn_metadata.seq_lens,
sm_scale=layer.self_attn.scaling,
num_attention_heads=layer.self_attn.total_num_heads,
num_key_value_heads=total_num_kv_heads,
flash_attention=True,
is_decode=True,
reduce_result=True,
# moe
rms_weight=self.weight_capture.layer_mapper.moe_args._0_rms_norm_weight[i],
moe_weight_13=self.weight_capture.layer_mapper.moe_args._1_w13_weight[i],
moe_weight_2=self.weight_capture.layer_mapper.moe_args._2_w2_weight[i],
moe_weight_13_dequat=self.weight_capture.layer_mapper.moe_args._3_w13_weight_scale_inv[i],
moe_weight_2_dequant=self.weight_capture.layer_mapper.moe_args._4_w2_weight_scale_inv[i],
gate_weight=self.weight_capture.layer_mapper.moe_args._5_gate_weight[i],
block_size_13=self.weight_capture.layer_mapper.moe_args._6_w13_block_size,
block_size_2=self.weight_capture.layer_mapper.moe_args._7_w2_block_size,
# dist
world_size=self.weight_capture.layer_mapper.dist_args._0_world_size,
rank=self.weight_capture.layer_mapper.dist_args._1_rank,
group_id=self.weight_capture.layer_mapper.dist_args._2_group_id,
dev_info=self.weight_capture.layer_mapper.dist_args._3_dev_info)
else:
layer0 = self.layers[self.start_layer]
cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode)
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual, cos_cache, sin_cache )
if deepstack_input_embeds is not None and i in range(0, len(deepstack_input_embeds)):
if isinstance(deepstack_input_embeds, IntermediateTensors):
hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{i}"]
elif isinstance(deepstack_input_embeds, torch.Tensor):
hidden_states = hidden_states + deepstack_input_embeds[i]
else:
raise ValueError(f'unsupported type: {type(deepstack_input_embeds)}')
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states, residual)
return hidden_states
class Qwen3MoeForCausalLM(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
deepstack_input_embeds = None,
) -> Union[torch.Tensor, IntermediateTensors]:
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
if attn_metadata.prefill_metadata is not None:
from .memory.memory_recycling import alloc_memory_recycler
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
if hasattr(attn_metadata, 'num_prefill_tokens'):
tokens = attn_metadata.num_prefill_tokens
else:
tokens = attn_metadata.prefill_metadata.num_prefill_tokens
vllm_model_mode = "qwen3_moe"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
vllm_model_mode = config_infos
if get_tp_group().rank_in_group == 0:
memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
logger.info(memory_infos)
if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size, dtype=self.lm_head.weight.dtype):
logger.warning("deepseek memory recycler allock fail. current request may inefficient %s", tokens)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, deepstack_input_embeds)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
from .memory.memory_recycling import init_huge_memory_allocator
from .vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
# default is qwen3_moe; the config can override the model name
model_name = "qwen3_moe"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
model_name = config_infos
if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
logger.warning("init huge memory allocator fail. prefill memory recycling will disable")
from vllm.model_executor.models.utils import AutoWeightsLoader
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

View File

@@ -0,0 +1,362 @@
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings)
from vllm.logger import init_logger
from .hf_processor.qwenvl_processor import Qwen3VLProcessorWithVacc
from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
from .vars import USE_FUSED_QWEN_ATTENTION
# from vacc_tools.trace_logger import get_trace_api
# trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
# get_trace_api("Qwen3vl")
# )
logger = init_logger(__name__)
class Qwen3_VisionPatchEmbed(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self.proj, 'bias') and self.proj.bias is not None:
return torch.nn.functional.linear(x, self.proj.weight.view(self.hidden_size, -1), self.proj.bias)
return torch.matmul(x, self.proj.weight.view(self.hidden_size, -1).T)
class Qwen3_VisionBlock(nn.Module):
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | list[torch.Tensor],
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
if USE_FUSED_QWEN_ATTENTION:
assert isinstance(rotary_pos_emb, list), "qwen3vl vit-attention requires rotary_pos_emb to be a list[torch.Tensor]"
total_bytes = x.numel() * x.element_size() * get_tp_group().world_size
reduce_result = get_tp_group().world_size > 1 and total_bytes < 4194304
# hidden_states = self.norm1(x)
attn_outs = torch.vacc.fuse_atten_vit(
hidden_states=x.view(-1, x.shape[-1]),
hidden_states_norm_weight = self.norm1.weight,
hidden_states_norm_bias = self.norm1.bias,
# hidden_states_norm_weight = torch.Tensor(),
# hidden_states_norm_bias = torch.Tensor(),
qkv_proj_weight=self.attn.qkv.weight,
qkv_proj_bias=self.attn.qkv.bias,
sin_cache=rotary_pos_emb[0],
cos_cache=rotary_pos_emb[1],
o_proj_weight=self.attn.proj.weight,
o_proj_bias=self.attn.proj.bias if self.attn.proj.tp_rank == 0 else torch.Tensor(),
seq_lens=cu_seqlens,
sm_scale=-1,
num_attention_heads=self.attn.num_attention_heads_per_partition * get_tp_group().world_size,
flash_attention=True,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos
)
attn_out = attn_outs[0] if reduce_result else tensor_model_parallel_all_reduce(attn_outs[0])
attn_out = attn_out.view(x.shape)
x = x + attn_out
else:
x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens)
x = x + self.mlp(self.norm2(x))
return x
class Qwen3_VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw):
if USE_FUSED_QWEN_ATTENTION:
try:
from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl
return rot_pos_emb_qwenvl(grid_thw, self.hidden_size, self.num_heads, self.spatial_merge_size, self.dtype, self.device)
except Exception as e:
logger.error(f"rot_pos_emb fused ops run fail, e:{e}")
pos_ids = []
# Support both Tensor and list inputs for DP path
if isinstance(grid_thw, list):
grid_list = grid_thw
max_grid_size = max(max(h, w) for _, h, w in grid_list)
else:
grid_list = grid_thw.tolist()
max_grid_size = int(grid_thw[:, 1:].max().item())
for t, h, w in grid_list:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def fast_pos_embed_interpolate(self,
grid_thw: list[list[int]]) -> torch.Tensor:
num_grid_per_side = self.num_grid_per_side
m_size = self.spatial_merge_size
hidden_dim = self.pos_embed.embedding_dim
try:
from torch_vacc.vacc.custom_qwen3_ops import fast_pos_embed_interpolate_qwenvl
return fast_pos_embed_interpolate_qwenvl(self.pos_embed.weight, grid_thw, num_grid_per_side, m_size, hidden_dim)
except Exception as e:
logger.error(f"fast_pos_embed_interpolate fused ops run fail, e:{e}")
outputs = []
for t, h, w in grid_thw:
h_idxs = torch.linspace(0,
num_grid_per_side - 1,
h,
dtype=torch.float32,
device=self.device)
w_idxs = torch.linspace(0,
num_grid_per_side - 1,
w,
dtype=torch.float32,
device=self.device)
h_floor = h_idxs.to(torch.long)
w_floor = w_idxs.to(torch.long)
h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
dh = h_idxs - h_floor
dw = w_idxs - w_floor
# Create meshgrid view for all h, w vars
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor,
w_floor,
indexing='ij')
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil,
w_ceil,
indexing='ij')
h_floor_grid_idx = h_floor_grid * num_grid_per_side
h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
w11 = dh_grid * dw_grid
w10 = dh_grid - w11
w01 = dw_grid - w11
w00 = 1 - dh_grid - dw_grid + w11
idx00 = h_floor_grid_idx + w_floor_grid
idx01 = h_floor_grid_idx + w_ceil_grid
idx10 = h_ceil_grid_idx + w_floor_grid
idx11 = h_ceil_grid_idx + w_ceil_grid
indices = torch.stack([idx00, idx01, idx10, idx11],
dim=0).reshape(4, -1)
weights = torch.stack([w00, w01, w10, w11],
dim=0).reshape(4, -1, 1)
weights = weights.to(dtype=self.dtype, device=self.device)
embeds = self.pos_embed(indices)
weighted_embeds = embeds * weights
p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
combined = p0 + p1 + p2 + p3
combined = combined.view(h * w, hidden_dim)
repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
repeated = repeated.view(t, h // m_size, m_size, w // m_size,
m_size, hidden_dim)
repeated = repeated.permute(0, 1, 3, 2, 4,
5).reshape(-1, hidden_dim)
outputs.append(repeated)
return torch.cat(outputs, dim=0)
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype)
hidden_states = self.patch_embed(hidden_states)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
grid_thw_tensor = torch.tensor(grid_thw,
dtype=torch.int32)
cu_seqlens = torch.repeat_interleave(
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
grid_thw_tensor[:, 0]).cumsum(
dim=0,
dtype=grid_thw_tensor.dtype
if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
hidden_states = hidden_states.unsqueeze(1)
if isinstance(rotary_pos_emb, torch.Tensor):
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
if USE_FUSED_QWEN_ATTENTION:
max_seqlen, seqlens = None, None
cu_seqlens = cu_seqlens.tolist()
else:
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(
layer_num)
deepstack_feature = self.deepstack_merger_list[
deepstack_merger_idx](hidden_states)
deepstack_feature_lists.append(deepstack_feature)
hidden_states = self.merger(hidden_states)
hidden_states = torch.cat(
[hidden_states] + deepstack_feature_lists,
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states
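# Sanity sketch for the reformulated bilinear weights in
# fast_pos_embed_interpolate above: w00/w01/w10/w11 still sum to 1 at every
# grid point, matching the original (1 - dh)(1 - dw) expansion noted in the
# comments there.
def _bilinear_weight_check():
    dh, dw = torch.rand(4, 5), torch.rand(4, 5)
    w11 = dh * dw
    w10, w01 = dh - w11, dw - w11
    w00 = 1 - dh - dw + w11
    assert torch.allclose(w00 + w01 + w10 + w11, torch.ones_like(dh))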
class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
deepstack_input_embeds = None
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if self.use_deepstack:
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501
input_ids, inputs_embeds, multimodal_embeddings)
self._set_deepstack_input_embeds(deepstack_input_embeds)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
# commented out here to remove the deepstack_input_embeds copy
# if self.use_deepstack:
# if deepstack_input_embeds is None:
# deepstack_input_embeds = torch.zeros_like(
# inputs_embeds).unsqueeze(0).repeat(
# self.deepstack_num_level, 1, 1).contiguous()
# self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
return  # patched here to skip the clear and optimize deepstack_input_embeds
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
class Qwen3VLProcessingInfo():
def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessorWithVacc:
processor = self.ctx.get_hf_processor(
Qwen3VLProcessorWithVacc,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
return processor
def get_image_processor(self,
**kwargs: object) -> Qwen2VLImageProcessorFastWithVacc:
return self.get_hf_processor(**kwargs).image_processor
# def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
# return self.get_hf_processor(**kwargs).video_processor
class Qwen3_VisionPatchMerger():
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
x = self.norm(x.view(-1, self.hidden_size))
else:
x = self.norm(x).view(-1, self.hidden_size)
try:
from torch_vacc.vacc import patch_merger_vision
tp_rank_id = get_tp_group().rank_in_group
fc2_bias = None if tp_rank_id > 0 else self.linear_fc2.bias
hidden_states = patch_merger_vision(x,
self.linear_fc1.weight, self.linear_fc2.weight,
self.linear_fc1.bias, fc2_bias,
0) #0 is gelu, 1 is silu
return tensor_model_parallel_all_reduce(hidden_states)
except Exception as e:
logger.error(f"merge patch fused vision mlp run fail, cased by:{e}")
x_parallel, _ = self.linear_fc1(x)
x_parallel = self.act_fn(x_parallel)
out, _ = self.linear_fc2(x_parallel)
return out
class Qwen3_VisionMLP():
def forward(self, x: torch.Tensor):
try:
from torch_vacc.vacc import fuse_mlp_vision
hiddens_shape = x.shape
tp_rank_id = get_tp_group().rank_in_group
fc2_bias = None if tp_rank_id > 0 else self.linear_fc2.bias
hidden_states = fuse_mlp_vision(x.view(-1, hiddens_shape[-1]),
self.linear_fc1.weight, self.linear_fc2.weight,
self.linear_fc1.bias, fc2_bias,
                0)  # activation flag: 0 = GELU, 1 = SiLU
return tensor_model_parallel_all_reduce(hidden_states).view(hiddens_shape)
        except Exception as e:
            logger.error(f"qwen3vl fused vision MLP failed, falling back to the unfused path, caused by: {e}")
            x_parallel, _ = self.linear_fc1(x)
            out, _ = self.linear_fc2(self.act_fn(x_parallel))
            return out
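# Design note (inferred from the fallback path): the fused kernel produces
# rank-local partial sums, hence the explicit tensor_model_parallel_all_reduce,
# while the unfused fallback relies on linear_fc2's row-parallel wrapper to
# perform the same reduction internally.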

View File

@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from vllm.model_executor.models.bert import _decode_token_type_ids
class RobertaEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
) -> torch.Tensor:
token_type_ids = _decode_token_type_ids(input_ids)
inputs_embeds = self.word_embeddings(input_ids)
# position_embeddings = self.position_embeddings(position_ids)
# token_type_embeddings = self.token_type_embeddings(token_type_ids)
# embeddings = inputs_embeds + token_type_embeddings + position_embeddings
# embeddings = self.LayerNorm(embeddings)
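        # The fused vacc kernel below computes the same three-way embedding sum
        # plus LayerNorm as the commented-out reference lines above, in one op.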
        embeddings = torch.vacc.fuse_bge_embedding_stage1(
            inputs_embeds, position_ids, self.position_embeddings.weight,
            token_type_ids, self.token_type_embeddings.weight,
            self.LayerNorm.weight, self.LayerNorm.bias, self.LayerNorm.eps)
return embeddings

View File

@@ -0,0 +1,58 @@
import os
# support Q,KV generation with TP
USE_PARALLEL_Q_KV_GEN = True
# support merged Q,KV generation and merged Q,QR weights
USE_MERGE_Q_KV_GEN_AND_Q_QR = True
# support FP8 weights for W_Q, W_QR, W_UV, W_UK
W_Q_W_QR_WUV_WUK_USE_FP8 = True
# fused prefill
USE_FUSED_PREFILL = True
# fused prefill stage1
USE_FUSED_PREFILL_STAGE1 = True
# All Request Seq Lens
DO_SEQ_LENS = 0
def update_sequence_length(seq_num):
global DO_SEQ_LENS
DO_SEQ_LENS = seq_num
USE_DS3_SAMPLER = int(os.getenv("USE_DS3_SAMPLER", 1))
USE_DS3_SAMPLER_OP = int(os.getenv("USE_DS3_SAMPLER_OP", 1))
# cut prefill seq len
CUT_PREFILL_SEQ_LEN = int(os.getenv("CUT_PREFILL_SEQ_LEN", -1))
# llm max prefill seq len
LLM_MAX_PREFILL_SEQ_LEN = int(os.getenv("LLM_MAX_PREFILL_SEQ_LEN", 56 * 1024))
# All Fused Decode, default is cpu loop
USE_DECODER_LAYER_FUSE_MODE = int(os.getenv("USE_DECODER_LAYER_FUSE_MODE", 1))
# Fused all layers, use cmcu loop
FUSE_ALL_DECODER_LAYERS = int(os.getenv("FUSE_ALL_DECODER_LAYERS", 1))
# where to use flash attention (default: 1)
USE_FLASH_ATTENTION = int(os.getenv("USE_FLASH_ATTENTION", 1))
# transpose gptq weight KN => NK
TRANSPOSE_GPTQ_WEIGHT = True
# qwen fused attention
USE_FUSED_QWEN_ATTENTION = int(os.getenv("USE_FUSED_QWEN_ATTENTION", 1))
# support MTP eh_proj with TP
USE_PARALLEL_MTP_EH_PROJ = int(os.getenv("USE_PARALLEL_MTP_EH_PROJ", 1))
# kv_cache group size
BLOCK_GROUP_SIZE = int(os.getenv("BLOCK_GROUP_SIZE", 8192))
# bert fused attention
USE_FUSED_BERT_ATTENTION = int(os.getenv("USE_FUSED_BERT_ATTENTION", 1))
# fused mlp vision
USE_FUSED_MLP_VISION = int(os.getenv("USE_FUSED_MLP_VISION", 1))
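# Minimal usage sketch (illustrative, not part of the diff): these integer
# flags are read once at import time, so overrides must be exported before
# this module is first imported, e.g.:
#
#   import os
#   os.environ["USE_FUSED_QWEN_ATTENTION"] = "0"  # force the unfused path
#   os.environ["CUT_PREFILL_SEQ_LEN"] = "8192"
#   from vllm_vacc.vllm.model_executor.models import vars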

View File

@@ -0,0 +1,491 @@
import torch
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
DeepseekV2MLAAttention,
DeepseekV2MLP,
DeepseekV2MoE)
from ..vars import *
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
OUTPUT_ARGS_LOGS = False
class DistributedArgs():
def __init__(self):
self._0_world_size = 32
self._1_rank = -1
self._2_group_id = 0
self._3_dev_info = []
def logs(self):
print("dist self._0_world_size = " , self._0_world_size)
print("dist self._1_rank = " , self._1_rank)
print("dist self._2_group_id = " , self._2_group_id)
print("dist self._3_dev_info = " , self._3_dev_info)
class AttenArgs():
def __init__(self):
self._a_hidden_states_norm_weight = []
        self._0_merge_q_kv_weights = []  # fused Q,KV weights
        self._1_merge_q_kv_scale_inv = []  # fused Q,KV scales
self._2_q_a_layernorm_weight = []
self._3_W_Q = []
self._4_W_Q_scales = []
self._5_W_UK = []
self._6_W_UK_scales = []
self._7_W_QR = []
self._8_W_QR_scales = []
self._9_kv_a_layernorm_weight = []
self._10_sin_cache = None
self._11_cos_cache = None
self._12_slot_mapping = None
self._13_kv_cache = None
self._14_block_tables = None
self._15_env_blk_grp_size = env_blk_grp_size
self._16_W_UV = []
self._17_W_UV_scales = []
        self._18_o_proj_weight = []
self._19_o_proj_weight_scale_inv = []
# mla params
self._20_seq_lens = []
self._21_sm_scale = 0.0
self._22_head_num = 128
def logs(self):
print("mla _20_seq_lens block size is:", self._20_seq_lens)
print("mla _21_sm_scale block size is:", self._21_sm_scale)
print("mla _22_head_num block size is:", self._22_head_num)
class MlpArgs():
def __init__(self):
#mlp params
self._0_mlp_rms_weight = []
self._1_mlp_w13 = []
self._2_mlp_w2 = []
self._3_mlp_w13_scale = []
self._4_mlp_w2_scale = []
self._5_mlp_w13_block_size = []
self._6_mlp_w2_block_size = []
def logs(self):
print("mlp _5_mlp_w13_block_size block size is:", self._5_mlp_w13_block_size)
print("mlp _6_mlp_w2_block_size block size is:", self._6_mlp_w2_block_size)
class MoeArgs():
def __init__(self):
#moe params
self._0_moe_rms_weight = []
self._1_moe_share_mlp_w13 = []
self._2_moe_share_mlp_w2 = []
self._3_moe_share_mlp_w13_scale = []
self._4_moe_share_mlp_w2_scale = []
self._5_moe_w13 = []
self._6_moe_w2 = []
self._7_moe_w13_scale = []
self._8_moe_w2_scale = []
self._9_gate_weight = []
self._10_moe_bias = []
self._11_moe_mlp_w13_block_size = []
self._12_moe_mlp_w2_block_size = []
self._13_moe_w13_block_size = []
self._14_moe_w2_block_size = []
def logs(self):
print("moe _11_moe_mlp_w13_block_size block size is:", self._11_moe_mlp_w13_block_size)
print("moe _12_moe_mlp_w2_block_size block size is:", self._12_moe_mlp_w2_block_size)
print("moe _13_moe_w13_block_size block size is:", self._13_moe_w13_block_size)
print("moe _14_moe_w2_block_size block size is:", self._14_moe_w2_block_size)
class WeightMapper():
def __init__(self):
self.attn_args = AttenArgs()
        self.mlp_args = MlpArgs()  # the first 3 layers: MLA + dense MLP
        self.moe_args = MoeArgs()  # the remaining 58 layers: MLA + MoE
self.dist_args = DistributedArgs()
# 1. load the weights
# 2. precompute the dequant blocks
# 3. cache & extract the parameters
class DeepseekWeightCapture():
def __init__(self, layer: torch.nn.ModuleList,
start: int,
end: int):
self.layer_mlp = WeightMapper()
self.layer_moe = WeightMapper()
self.sin_cache_all = None
self.cos_cache_all = None
self.mlp_nums = 3
self.moe_nums = end - self.mlp_nums
self.start_idx = start
self.end_idx = end
for i in range(start, end):
if i < self.mlp_nums:
self.capture_deepseek_mla_attn_weights(layer[i], self.layer_mlp.attn_args)
self.capture_deepseek_mlp_weights(layer[i])
else:
self.capture_deepseek_mla_attn_weights(layer[i], self.layer_moe.attn_args)
self.capture_deepseek_moe_weights(layer[i])
if OUTPUT_ARGS_LOGS:
self.layer_mlp.attn_args.logs()
self.layer_mlp.mlp_args.logs()
self.layer_moe.attn_args.logs()
self.layer_moe.moe_args.logs()
from vllm.distributed import get_tp_group
tp_group = get_tp_group()
self.layer_mlp.dist_args._0_world_size = tp_group.world_size
self.layer_mlp.dist_args._1_rank = tp_group.rank_in_group
self.layer_mlp.dist_args._2_group_id = tp_group.group_id
self.layer_mlp.dist_args._3_dev_info = tp_group.rank_device_infos
self.layer_moe.dist_args._0_world_size = tp_group.world_size
self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
self.layer_moe.dist_args._2_group_id = tp_group.group_id
self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos
def capture_deepseek_mlp_weights(self, module: DeepseekV2DecoderLayer):
assert isinstance(module.mlp, DeepseekV2MLP)
mlp = module.mlp
rms_norm = module.post_attention_layernorm
w13_weight = mlp.gate_up_proj.weight
w2_weight = mlp.down_proj.weight
w13_wscale = mlp.gate_up_proj.weight_scale_inv
w13_iscale = mlp.gate_up_proj.input_scale
w2_wscale = mlp.down_proj.weight_scale_inv
w2_iscale = mlp.down_proj.input_scale
w13_block_size0, w13_block_size1 = mlp.gate_up_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = mlp.gate_up_proj.quant_method.scale_n, mlp.gate_up_proj.quant_method.scale_k
assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
w13_block_size0 = w13_block_size0 // scale_n
w13_block_size1 = w13_block_size1 // scale_k
w2_block_size0, w2_block_size1 = mlp.down_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = mlp.down_proj.quant_method.scale_n, mlp.down_proj.quant_method.scale_k
assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
w2_block_size0 = w2_block_size0 // scale_n
w2_block_size1 = w2_block_size1 // scale_k
self.layer_mlp.mlp_args._0_mlp_rms_weight.append(rms_norm.weight)
self.layer_mlp.mlp_args._1_mlp_w13.append(w13_weight)
self.layer_mlp.mlp_args._2_mlp_w2.append(w2_weight)
self.layer_mlp.mlp_args._3_mlp_w13_scale.append(w13_wscale)
self.layer_mlp.mlp_args._4_mlp_w2_scale.append(w2_wscale)
self.layer_mlp.mlp_args._5_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
self.layer_mlp.mlp_args._6_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
assert isinstance(module.mlp, DeepseekV2MoE)
share_expert_layer = module.mlp.shared_experts
experts_layer = module.mlp.experts
rms_norm = module.post_attention_layernorm
gate = module.mlp.gate
w13_weight = share_expert_layer.gate_up_proj.weight
w2_weight = share_expert_layer.down_proj.weight
w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
# w13_iscale = share_expert_layer.gate_up_proj.input_scale
w2_wscale = share_expert_layer.down_proj.weight_scale_inv
# w2_iscale = share_expert_layer.down_proj.input_scale
w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
w13_block_size0 = w13_block_size0 // scale_n
w13_block_size1 = w13_block_size1 // scale_k
w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
w2_block_size0 = w2_block_size0 // scale_n
w2_block_size1 = w2_block_size1 // scale_k
hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
block_size0, block_size1 = (
hidden_dims // hidden_blocks,
inter_dims // inter_blocks,
)
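        # Worked example (illustrative shapes): w13_weight [E, 4096, 1536] with
        # w13_weight_scale_inv [E, 32, 12] gives block_size0 = 4096 // 32 = 128
        # and block_size1 = 1536 // 12 = 128, i.e. one scale per 128x128 block.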
self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]
def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
weight_mapper: AttenArgs):
        if self.sin_cache_all is None:
self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache
weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)
fused_params = {}
if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
for name, param in module.self_attn.q_a_proj.named_parameters():
fused_params['q_a_proj_' + name] = param
for name, param in module.self_attn.q_a_layernorm.named_parameters():
fused_params['q_a_layernorm_' + name] = param
if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
fused_params['kv_a_proj_' + name] = param
for name, param in module.self_attn.kv_a_layernorm.named_parameters():
fused_params['kv_a_layernorm_' + name] = param
for name, param in module.self_attn.o_proj.named_parameters():
fused_params['o_proj_' + name] = param
        weight_mapper._15_env_blk_grp_size = env_blk_grp_size
        # extract the fused MLA weights precomputed by the attention impl
mla_params = module.self_attn.mla_attn.impl.extract_weights()
fused_params = {**fused_params, **mla_params}
weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
weight_mapper._3_W_Q.append(fused_params['W_Q'])
weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
weight_mapper._5_W_UK.append(fused_params['W_UK'])
weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
weight_mapper._7_W_QR.append(fused_params['W_QR'])
weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
#weight_mapper._10_sin_cache.append(None)
#weight_mapper._11_cos_cache.append(None)
#weight_mapper._12_slot_mapping.append(None)
#weight_mapper._13_kv_cache.append(None)
#weight_mapper._14_block_tables.append(None)
# weight_mapper._15_env_blk_grp_size.append(None)
weight_mapper._16_W_UV.append(fused_params['W_UV'])
weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
weight_mapper._20_seq_lens = None
weight_mapper._21_sm_scale = module.self_attn.scaling
weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size
    # Optimization opportunity: pass only Tensors into the C++ side.
def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
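        # Decode step: every sequence is generating exactly one token, whose
        # position is seq_len - 1, so the matching rotary cos/sin rows are
        # gathered up front for all requests.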
positions = [i - 1 for i in seq_lens]
cos_cache = [self.cos_cache_all[i] for i in positions]
sin_cache = [self.sin_cache_all[i] for i in positions]
self.layer_mlp.attn_args._10_sin_cache = sin_cache
self.layer_mlp.attn_args._11_cos_cache = cos_cache
self.layer_moe.attn_args._10_sin_cache = sin_cache
self.layer_moe.attn_args._11_cos_cache = cos_cache
self.layer_mlp.attn_args._20_seq_lens = seq_lens
self.layer_moe.attn_args._20_seq_lens = seq_lens
self.layer_mlp.attn_args._13_kv_cache = kv_caches_dense_layer
self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer
self.layer_mlp.attn_args._12_slot_mapping = slot_mapping
self.layer_mlp.attn_args._14_block_tables = block_tables
self.layer_moe.attn_args._12_slot_mapping = slot_mapping
self.layer_moe.attn_args._14_block_tables = block_tables
# for i in range(self.mlp_nums):
# if i < self.end_idx:
# self.layer_mlp.attn_args._12_slot_mapping[i] = slot_mapping
# self.layer_mlp.attn_args._14_block_tables[i] = block_tables
# for i in range(self.moe_nums):
# if i < self.end_idx:
# self.layer_moe.attn_args._12_slot_mapping[i] = slot_mapping
# self.layer_moe.attn_args._14_block_tables[i] = block_tables
def logs(self):
print("current layer mlp attn: \n")
self.layer_mlp.attn_args.logs()
self.layer_mlp.dist_args.logs()
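# Minimal usage sketch (illustrative; `model` is assumed to be a loaded
# DeepseekV2/V3-style vLLM model):
#
#   layers = model.model.layers
#   capture = DeepseekWeightCapture(layers, 0, len(layers))
#   # refresh the dynamic arguments before every fused decode step:
#   capture.update_attn_args(seq_lens, slot_mapping,
#                            kv_caches_dense, kv_caches_moe, block_tables)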
class DeepseekMTPWeightCapture():
    # Unlike DeepseekWeightCapture, the MTP module has only one
    # DeepseekDecoderLayer, and it is a MoE layer.
def __init__(self, layer: torch.nn.Module):
self.layer_moe = WeightMapper()
self.sin_cache_all = None
self.cos_cache_all = None
self.capture_deepseek_mla_attn_weights(layer, self.layer_moe.attn_args)
self.capture_deepseek_moe_weights(layer)
if OUTPUT_ARGS_LOGS:
self.layer_moe.attn_args.logs()
self.layer_moe.moe_args.logs()
from vllm.distributed import get_tp_group
tp_group = get_tp_group()
self.layer_moe.dist_args._0_world_size = tp_group.world_size
self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
self.layer_moe.dist_args._2_group_id = tp_group.group_id
self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos
def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
assert isinstance(module.mlp, DeepseekV2MoE)
share_expert_layer = module.mlp.shared_experts
experts_layer = module.mlp.experts
rms_norm = module.post_attention_layernorm
gate = module.mlp.gate
w13_weight = share_expert_layer.gate_up_proj.weight
w2_weight = share_expert_layer.down_proj.weight
w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
w2_wscale = share_expert_layer.down_proj.weight_scale_inv
w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
w13_block_size0 = w13_block_size0 // scale_n
w13_block_size1 = w13_block_size1 // scale_k
w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
w2_block_size0 = w2_block_size0 // scale_n
w2_block_size1 = w2_block_size1 // scale_k
hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
block_size0, block_size1 = (
hidden_dims // hidden_blocks,
inter_dims // inter_blocks,
)
self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]
def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
weight_mapper: AttenArgs):
        if self.sin_cache_all is None:
self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache
weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)
fused_params = {}
if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
for name, param in module.self_attn.q_a_proj.named_parameters():
fused_params['q_a_proj_' + name] = param
for name, param in module.self_attn.q_a_layernorm.named_parameters():
fused_params['q_a_layernorm_' + name] = param
if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
fused_params['kv_a_proj_' + name] = param
for name, param in module.self_attn.kv_a_layernorm.named_parameters():
fused_params['kv_a_layernorm_' + name] = param
for name, param in module.self_attn.o_proj.named_parameters():
fused_params['o_proj_' + name] = param
        weight_mapper._15_env_blk_grp_size = env_blk_grp_size
        # extract the fused MLA weights precomputed by the attention impl
mla_params = module.self_attn.mla_attn.impl.extract_weights()
fused_params = {**fused_params, **mla_params}
weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
weight_mapper._3_W_Q.append(fused_params['W_Q'])
weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
weight_mapper._5_W_UK.append(fused_params['W_UK'])
weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
weight_mapper._7_W_QR.append(fused_params['W_QR'])
weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
#weight_mapper._10_sin_cache.append(None)
#weight_mapper._11_cos_cache.append(None)
#weight_mapper._12_slot_mapping.append(None)
#weight_mapper._13_kv_cache.append(None)
#weight_mapper._14_block_tables.append(None)
# weight_mapper._15_env_blk_grp_size.append(None)
weight_mapper._16_W_UV.append(fused_params['W_UV'])
weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
weight_mapper._20_seq_lens = None
weight_mapper._21_sm_scale = module.self_attn.scaling
weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size
    # Optimization opportunity: pass only Tensors into the C++ side.
def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
positions = [i - 1 for i in seq_lens]
cos_cache = [self.cos_cache_all[i] for i in positions]
sin_cache = [self.sin_cache_all[i] for i in positions]
self.layer_moe.attn_args._10_sin_cache = sin_cache
self.layer_moe.attn_args._11_cos_cache = cos_cache
self.layer_moe.attn_args._20_seq_lens = seq_lens
self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer
self.layer_moe.attn_args._12_slot_mapping = slot_mapping
self.layer_moe.attn_args._14_block_tables = block_tables
def logs(self):
print("current layer mlp attn: \n")
self.layer_mlp.attn_args.logs()
self.layer_mlp.dist_args.logs()

View File

@@ -0,0 +1,154 @@
import torch
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer,
Qwen3MoeMLP)
class Qwen3Moe_DistributedArgs():
def __init__(self):
self._0_world_size = 32
self._1_rank = -1
self._2_group_id = 0
self._3_dev_info = []
def __repr__(self):
dist_infos = f"[dist] world_size = {self._0_world_size} \n" \
+ f"[dist] rank = {self._1_rank} \n" \
+ f"[dist] group_id = {self._2_group_id} \n" \
+ f"[dist] dev_info = {self._3_dev_info}"
return dist_infos
class Qwen3Moe_AttenArgs():
def __init__(self):
self._0_input_layernorm_weight = []
        self._1_qkv_proj_weight = []
self._2_qkv_proj_weight_scale = []
self._3_qkv_proj_bias = []
self._4_qkv_proj_qzeros = []
self._5_q_norm_weight = []
self._6_k_norm_weight = []
self._7_sin_cache = None
self._8_cos_cache = None
self._9_slot_mapping = None
self._10_kv_cache = None
self._11_block_tables = None
self._12_block_group_size = None
self._13_o_proj_weight = []
self._14_o_proj_weight_scale = []
self._15_o_proj_bias = []
self._16_o_proj_qzeros = []
self._17_seq_lens = None
        self._18_sm_scale = None
self._19_num_attention_heads = None
self._20_num_key_value_heads = None
def __repr__(self):
attn_infos = "[qwen attn] 21 args" \
+ f"[qwen attn] weight counts: {len(self._0_input_layernorm_weight)}"
return attn_infos
class Qwen3Moe_MoeArgs():
def __init__(self):
#moe params
self._0_rms_norm_weight = []
self._1_w13_weight = []
self._2_w2_weight = []
self._3_w13_weight_scale_inv = []
self._4_w2_weight_scale_inv = []
self._5_gate_weight = []
self._6_w13_block_size = None
self._7_w2_block_size = None
def __repr__(self):
moe_infos = f"[moe] w13_block_size: {self._6_w13_block_size}" \
+ f"[moe] w2_block_size: {self._7_w2_block_size}" \
+ f"[moe] weight counts: {len(self._1_w13_weight)}"
return moe_infos
class Qwen3Moe_WeightMapper:
def __init__(self):
self.attn_args = Qwen3Moe_AttenArgs()
self.moe_args = Qwen3Moe_MoeArgs()
self.dist_args = Qwen3Moe_DistributedArgs()
class Qwen3Moe_WeightCapture():
def __init__(self, layers: torch.nn.ModuleList,
start: int,
end: int):
self.layer_mapper = Qwen3Moe_WeightMapper()
        # qwen3 only supports fused weights for fp8 for now
self.support_fused_weights = False
for i in range(start, end):
layer = layers[i]
self.capture_attn_weights(layer)
self.capture_moe_weights(layer)
        # register the multi-device (TP group) environment info
from vllm.distributed import get_tp_group
tp_group = get_tp_group()
self.layer_mapper.dist_args._0_world_size = tp_group.world_size
self.layer_mapper.dist_args._1_rank = tp_group.rank_in_group
self.layer_mapper.dist_args._2_group_id = tp_group.group_id
self.layer_mapper.dist_args._3_dev_info = tp_group.rank_device_infos
def capture_attn_weights(self, layer):
from vllm_vacc.vllm.model_executor.models.qwen3_moe import set_fused_params
        # collect parameters for the fused operators
fused_params = {}
fused_params['input_layernorm_weight'] = layer.input_layernorm.weight
fused_params['q_norm_weight'] = layer.self_attn.q_norm.weight
fused_params['k_norm_weight'] = layer.self_attn.k_norm.weight
set_fused_params(fused_params, layer.self_attn.qkv_proj.quant_method, layer.self_attn.qkv_proj, 'qkv_proj')
set_fused_params(fused_params, layer.self_attn.o_proj.quant_method, layer.self_attn.o_proj, 'o_proj')
        self.support_fused_weights = (
            hasattr(layer.mlp.experts.quant_method, 'quant_config')
            and hasattr(layer.mlp.experts.quant_method.quant_config,
                        'weight_block_size'))
if not hasattr(layer.self_attn, "fused_params"):
layer.self_attn.fused_params = fused_params
self.layer_mapper.attn_args._0_input_layernorm_weight.append(fused_params['input_layernorm_weight'])
self.layer_mapper.attn_args._1_qkv_proj_weight.append(fused_params['qkv_proj_weight'])
self.layer_mapper.attn_args._2_qkv_proj_weight_scale.append(fused_params['qkv_proj_weight_scale'])
self.layer_mapper.attn_args._3_qkv_proj_bias.append(fused_params['qkv_proj_bias'])
self.layer_mapper.attn_args._4_qkv_proj_qzeros.append(fused_params['qkv_proj_qzeros'])
self.layer_mapper.attn_args._5_q_norm_weight.append(fused_params['q_norm_weight'])
self.layer_mapper.attn_args._6_k_norm_weight.append(fused_params['k_norm_weight'])
# self.layer_mapper.attn_args._7_sin_cache
# self.layer_mapper.attn_args._8_cos_cache
# self.layer_mapper.attn_args._9_slot_mapping
# self.layer_mapper.attn_args._10_kv_cache
# self.layer_mapper.attn_args._11_block_tables
# self.layer_mapper.attn_args._12_block_group_size
self.layer_mapper.attn_args._13_o_proj_weight.append(fused_params['o_proj_weight'])
self.layer_mapper.attn_args._14_o_proj_weight_scale.append(fused_params['o_proj_weight_scale'])
self.layer_mapper.attn_args._15_o_proj_bias.append(fused_params['o_proj_bias'])
self.layer_mapper.attn_args._16_o_proj_qzeros.append(fused_params['o_proj_qzeros'])
# self.layer_mapper.attn_args._17_seq_lens
self.layer_mapper.attn_args._18_sm_scale = layer.self_attn.scaling
self.layer_mapper.attn_args._19_num_attention_heads = layer.self_attn.total_num_heads
self.layer_mapper.attn_args._20_num_key_value_heads = layer.self_attn.total_num_kv_heads
def capture_moe_weights(self, layer: Qwen3MoeDecoderLayer):
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
quant_method = layer.mlp.experts.quant_method if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock) \
else layer.mlp.down_proj.quant_method
if not isinstance(quant_method, MoeWNA16Method):
from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import recompute_moe_layer_blocksize
recompute_moe_layer_blocksize(layer.mlp.experts)
self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_weight)
self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_weight)
self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_weight_scale_inv)
self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_weight_scale_inv)
self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
else:
self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_qweight)
self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_qweight)
self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_scales)
self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_scales)
self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
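# Minimal usage sketch (illustrative): capture all decoder layers once after
# weight loading, then hand the flattened argument lists to the fused runner.
#
#   capture = Qwen3Moe_WeightCapture(model.model.layers, 0,
#                                    len(model.model.layers))
#   print(capture.layer_mapper.moe_args)  # __repr__ reports block sizes/counts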

Some files were not shown because too many files have changed in this diff