Commit 24df76db9d (parent 80932c96e5), 2026-04-02 04:53:13 +00:00
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
import torch
def SiluAndMul_forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
    """SiLU-and-multiply (SwiGLU) using the fused torch.vacc.swiglu kernel."""
    return torch.vacc.swiglu(x)
def QuickGELU_forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)

View File

@@ -0,0 +1,82 @@
import torch
from typing import Optional, Tuple
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, -1)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
return topk_weights, selected_experts
def grouped_topk_with_itype(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
try:
from torch_vacc.vacc.custom_ops import fused_moe_preprocess
return fused_moe_preprocess(gating_output, e_score_correction_bias)
except Exception as e:
print(f"fused group topk run fail, now use unfused group topk: {e}")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int32)
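# Illustrative usage of fused_topk (not part of the original patch; shapes and
# values are hypothetical): route 4 tokens over 8 experts, keeping 2 per token.
def _fused_topk_demo() -> Tuple[torch.Tensor, torch.Tensor]:
    hidden_states = torch.randn(4, 16, dtype=torch.float16)
    gating_output = torch.randn(4, 8, dtype=torch.float16)
    # weights: [4, 2] renormalized routing weights, experts: [4, 2] expert ids
    weights, experts = fused_topk(hidden_states, gating_output, topk=2,
                                  renormalize=True)
    return weights, experts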

View File

@@ -0,0 +1,715 @@
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
# determine_expert_map is used below; in upstream vLLM it lives in fused_moe.layer.
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
def FusedMoE_init_(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
):
from vllm.model_executor.layers.fused_moe import FusedMoE
super(FusedMoE, self).__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
vllm_config = get_current_vllm_config()
# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
# operation is typically unquantized (i.e. float16/bfloat16).
if vllm_config.model_config is not None:
moe_in_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
moe_in_dtype = params_dtype
tp_size_ = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
tp_size_=tp_size_,
dp_size_=dp_size_,
vllm_parallel_config=vllm_config.parallel_config))
self.global_num_experts = num_experts + num_redundant_experts
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.enable_eplb = enable_eplb
self.expert_load_view: Optional[torch.Tensor] = None
self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None
# Determine expert maps
if self.use_ep:
if self.enable_eplb:
assert self.global_num_experts % self.ep_size == 0, \
"EPLB currently only supports even distribution of " \
"experts across ranks."
else:
assert num_redundant_experts == 0, \
"Redundant experts are only supported with EPLB."
self.local_num_experts, self.expert_map = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts)
else:
self.local_num_experts, self.expert_map = (self.global_num_experts,
None)
self.top_k = top_k
assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
)
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.")
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
# self.scale_n = self.quant_method.scale_n
# self.scale_k = self.quant_method.scale_k
self.scale_n = 1
self.scale_k = 1
self.scale_n_prefill = 1
if hasattr(self.quant_method, "scale_n") and hasattr(self.quant_method, "scale_k"):
self.scale_n = self.quant_method.scale_n
self.scale_k = self.quant_method.scale_k
if hasattr(self.quant_method, "scale_n_prefill"):
self.scale_n_prefill = self.quant_method.scale_n_prefill
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
class FusedMoE(torch.nn.Module):
    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int,
                  load_full: bool = False, expert_id=0):
#print("w13 shape is:", expert_data.shape, loaded_weight.shape)
if self.scale_n > 1 and len(loaded_weight.shape) == 2 and torch.finfo(loaded_weight.dtype).bits > 8:
n_v, k_v = loaded_weight.shape
loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v) #[1, n, 1, k]
if self.scale_n_prefill != self.scale_n and hasattr(self, 'w13_weight_scale_inv_prefill'):
loaded_weight0 = loaded_weight.repeat(self.scale_n_prefill,1,self.scale_k,1).permute(1,0,3,2).reshape([n_v * self.scale_n_prefill, k_v * self.scale_k])
shard_size = self.w13_weight_scale_inv_prefill.data[expert_id].shape[shard_dim] // 2
loaded_weight0 = loaded_weight0.narrow(shard_dim, shard_size * tp_rank, shard_size)
if shard_id == "w1":
self.w13_weight_scale_inv_prefill.data[expert_id, :loaded_weight0.shape[0], :loaded_weight0.shape[1]] = loaded_weight0
elif shard_id == "w3":
self.w13_weight_scale_inv_prefill.data[expert_id, -loaded_weight0.shape[0]:, -loaded_weight0.shape[1]:] = loaded_weight0
else:
raise ValueError('error shard_id: ',shard_id)
loaded_weight = loaded_weight.repeat(self.scale_n,1,self.scale_k,1).permute(1,0,3,2).reshape([n_v * self.scale_n, k_v * self.scale_k])
#print("w13 repeat shape is:", expert_data.shape, loaded_weight.shape)
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                                 shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
def _load_w2(self,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False,
expert_id=0):
#print("w2 shape is:", expert_data.shape, loaded_weight.shape)
if self.scale_n > 1 and len(loaded_weight.shape) == 2 and torch.finfo(loaded_weight.dtype).bits > 8:
n_v, k_v = loaded_weight.shape
loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v) #[1, n, 1, k]
if self.scale_n_prefill != self.scale_n and hasattr(self, 'w2_weight_scale_inv_prefill'):
loaded_weight0 = loaded_weight.repeat(self.scale_k,1,self.scale_n_prefill,1).permute(1,0,3,2).reshape([n_v * self.scale_k, k_v * self.scale_n_prefill])
shard_size = self.w2_weight_scale_inv_prefill.data[expert_id].shape[shard_dim]
if not load_full:
loaded_weight0 = loaded_weight0.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
self.w2_weight_scale_inv_prefill.data[expert_id] = loaded_weight0
#print("loaded_weight:", loaded_weight.shape)
loaded_weight = loaded_weight.repeat(self.scale_k,1,self.scale_n,1).permute(1,0,3,2).reshape([n_v * self.scale_k, k_v * self.scale_n])
#print("w2 repeat shape is:", expert_data.shape, loaded_weight.shape)
#if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
def weight_loader(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
                      expert_id: int, return_success: bool = True) -> Optional[bool]:
if self.quant_config and self.quant_config.get_name() == "mxfp4":
# (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
param.data[:, :dim1].copy_(loaded_weight)
else:
dim1 = loaded_weight.shape[1]
dim2 = loaded_weight.shape[2]
param.data[:, :dim1, :dim2].copy_(loaded_weight)
return True if return_success else None
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return False if return_success else None
quant_method_name = self.quant_method.__class__.__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
if self.quant_method.__class__.__name__ in (
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod"):
loaded_weight = loaded_weight.t().contiguous()
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
# WEIGHT_SCALE_SUPPORTED = [
# e.value for e in FusedMoeWeightScaleSupported
# ]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
param.data.copy_(loaded_weight)
return True if return_success else None
# Case for BitsAndBytes
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit:
shard_dim = 0
expert_data = param.data[expert_id]
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
# BNB inflight quantization has already sharded the weights
full_load = True
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank,
load_full=full_load,
)
return True if return_success else None
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = int(not shard_dim)
full_load = len(loaded_weight.shape) == 3
if full_load:
shard_dim += 1
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
final_shape = list(loaded_weight.shape)
if shard_id in ["w1", "w3"]:
final_shape[1] *= 2
final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
param.materialize(final_shape, dtype=loaded_weight.dtype)
expert_data = param.data if full_load else param.data[expert_id]
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return True if return_success else None
# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name:
# Determine per-tensor weight scale patterns based on variant
# Use the dedicated method instead of brittle string matching
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
)
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
# weights scales.
# Input scales are always per-tensor.
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
# "weight_scale" for per-tensor scales.
is_per_tensor = ("weight_scale_2" in weight_name
if uses_weight_scale_2 else "weight_scale"
in weight_name) or "input_scale" in weight_name
if is_per_tensor:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
return True if return_success else None
# If the weight is w13_weight_scale and w13_weight_scales are
# combined into single loaded_weight, call
# _load_combined_w13_weight_scale() to load it.
# This is checked by comparing the hidden_out dims of the
# loaded_weight and the param.
if "w13_weight_scale" in weight_name:
loaded_weight_hidden_out = loaded_weight.shape[-2]
param_hidden_out = param.data.shape[-2] * self.tp_size
if loaded_weight_hidden_out == param_hidden_out:
self._load_combined_w13_weight_scale(
shard_dim=shard_dim,
loaded_weight=loaded_weight,
param=param,
tp_rank=self.tp_rank,
)
return True if return_success else None
# For other weights, call _load_model_weight_or_group_weight_scale()
# to load it.
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
# Case weight scales, zero_points and offset
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank,
load_full_w2=getattr(param, "load_full_w2", False),
expert_id=expert_id)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return True if return_success else None
# Case weight_shape
if "weight_shape" in weight_name:
# only required by compressed-tensors
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return True if return_success else None
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
return False if return_success else None
def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full_w2: bool = False,
expert_id: int = 0):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
# In the case where we have actorder/g_idx, we do not partition the
# w2 scales, as indicated by `load_full` argument, for all tp cases
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
load_full=load_full_w2,
expert_id=expert_id)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
expert_id=expert_id)
class UnquantizedFusedMoEMethod:

    @staticmethod
    def forward_vacc(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
hidden_states = x
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
intermediate_size = layer.w2_weight.shape[-1]
gating_output=router_logits
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, global_num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
final_hidden_states = torch.zeros_like(hidden_states)
sel_experts = topk_ids.shape[1]
if hidden_states.shape[0] == 1:
for id in range(sel_experts):
expert_idx = topk_ids[0][id]
expert_w1 = layer.w13_weight[expert_idx].contiguous()
expert_w2 = layer.w2_weight[expert_idx].contiguous()
expert_weights = topk_weights[0][id].to(hidden_states.dtype)
x = hidden_states
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
final_hidden_states += current_hidden_states
else:
for expert_idx in range(global_num_experts):
# topk_ids [tokens, experts] => sample:[10, 8]
# expert_mask [tokens, experts] => sample:[10, 8]
expert_mask = topk_ids == expert_idx
idx = torch.where(expert_mask)[0]
if idx.numel() == 0:
continue
expert_w1 = layer.w13_weight[expert_idx].contiguous()
expert_w2 = layer.w2_weight[expert_idx].contiguous()
# [seq, experts]
expert_weights = (
topk_weights.masked_select(expert_mask)
.unsqueeze(1)
.to(hidden_states.dtype)
)
x = hidden_states[idx]
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
# final_hidden_states[idx] += current_hidden_states
final_hidden_states.index_add_(0, idx, current_hidden_states)
return final_hidden_states.view(orig_shape) # type: ignore
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
custom_forward = self.forward
if x.device.type == "vacc":
custom_forward = UnquantizedFusedMoEMethod.forward_vacc
return custom_forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
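# Illustrative, self-contained sketch of the masked dispatch used in
# forward_vacc above (hypothetical shapes, identity "experts" instead of real
# MLP weights): for each expert, gather the tokens routed to it and accumulate
# the weighted outputs with index_add_.
def _masked_dispatch_demo() -> torch.Tensor:
    num_tokens, hidden, num_experts, top_k = 6, 8, 4, 2
    hidden_states = torch.randn(num_tokens, hidden)
    topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
    topk_weights = torch.rand(num_tokens, top_k)
    out = torch.zeros_like(hidden_states)
    for expert_idx in range(num_experts):
        expert_mask = topk_ids == expert_idx        # [tokens, top_k]
        idx = torch.where(expert_mask)[0]           # token rows routed to this expert
        if idx.numel() == 0:
            continue
        weights = topk_weights.masked_select(expert_mask).unsqueeze(1)
        out.index_add_(0, idx, hidden_states[idx] * weights)
    return out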

View File

@@ -0,0 +1,32 @@
from typing import Optional, Tuple, Union
import torch
def RMSNorm_forward_vacc(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# if residual is not None:
# x = x + residual
# residual = x
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
# x_var=x_var.unsqueeze(0)
# out = torch.vacc.rms_norm(x_var,self.weight,self.variance_epsilon)
# if residual is None:
# return out.squeeze(0)
# else:
# return out.squeeze(0), residual
out = torch.vacc.fused_residual_rmsnorm(x_var, self.weight, residual, self.variance_epsilon, x_var, residual)
return out

View File

@@ -0,0 +1,465 @@
import itertools
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
WEIGHT_LOADER_V2_SUPPORTED,
LinearBase,
RowParallelLinear)
def ReplicatedLinear__init__(self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super(ReplicatedLinear,self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix)
# All the linear layer supports quant method.
assert self.quant_method is not None
    self.scale_k = 1  # quant_block_k (128) is divided by scale_k; e.g. scale_k = 2 gives quant_block_k = 64
self.scale_k_slice = 1
self.scale_n = 1
self.scale_n_slice = 1
if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
gcd_value = quant_config.weight_block_size[1]
import math
if input_size % quant_config.weight_block_size[1]:
gcd_value = math.gcd(input_size % quant_config.weight_block_size[1], quant_config.weight_block_size[1])
self.scale_k =self.scale_k * quant_config.weight_block_size[1] // gcd_value
self.scale_k_slice = input_size // gcd_value
if output_size % quant_config.weight_block_size[0]:
gcd_value = math.gcd(output_size % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value
self.scale_n_slice = output_size // gcd_value
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
scale_k = self.scale_k,
scale_n = self.scale_n,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def ReplicatedLinear_weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    # If the weight on disk does not have a shape, give it one
    # (such as scales for AutoFP8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_k > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,2,0).reshape([loaded_weight.shape[1], -1])[:, :self.scale_k_slice]
#[1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
assert param.size() == loaded_weight.size(), f'{param.size()}, {loaded_weight.size()}'
param.data.copy_(loaded_weight)
def refine_block(block_size: list[int],
                 weight_size: list[int],
                 dim: int = 0,
                 pingpong_size: float = 2.5 * 1024 * 1024,  # bytes
                 core_number: int = 4,
                 data_type: int = 2,  # bytes per element (bfloat16)
                 max_iter_number: int = 2):
    '''
    When the blocks do not divide evenly across cores, every core must hold
    <= 2.5 MB so that ping-pong buffering remains possible. The difference
    between cores is block_size[dim] * weight_size[1-dim]. Shrinking
    block_size narrows that gap and balances the load until the largest core
    no longer exceeds the budget. If an even split already exceeds the budget,
    or does not exceed it at all, no adjustment is needed.
    '''
    if dim < 0:
        dim = 2 + dim
    pingpong_size = pingpong_size / data_type  # number of elements
    block_size_refine = block_size[dim]
    all_block_number = weight_size[dim] // block_size_refine
    if all_block_number % core_number == 0:
        # Even split: no adjustment needed whether or not the budget is exceeded.
        return block_size_refine
    block_number_tiny = all_block_number // core_number
    block_number_big = all_block_number // core_number + 1
    if block_number_tiny * block_size_refine * weight_size[1-dim] >= pingpong_size or \
            block_number_big * block_size_refine * weight_size[1-dim] <= pingpong_size:
        # The small cores already exceed the budget (nothing left to adjust),
        # or the big cores stay within it (no adjustment needed).
        return block_size_refine
    all_block_number_tmp = all_block_number
    block_size_refine_tmp = block_size_refine
    for iter_index in range(max_iter_number):
        all_block_number_tmp = all_block_number_tmp * 2
        block_size_refine_tmp = block_size_refine_tmp // 2
        if all_block_number_tmp % core_number == 0:
            block_number_tiny = all_block_number // core_number
            if block_number_tiny * block_size_refine_tmp * weight_size[1-dim] <= pingpong_size:
                return block_size_refine_tmp
            else:
                # Even split still exceeds the budget: no adjustment.
                return block_size_refine
        else:
            block_number_big = all_block_number_tmp // core_number + 1
            if block_number_big * block_size_refine_tmp * weight_size[1-dim] <= pingpong_size:
                return block_size_refine_tmp
    return block_size_refine
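# Illustrative usage of refine_block (hypothetical sizes, not part of the
# original patch): a [4096, 4096] bf16 weight with 128-wide blocks along dim 0
# yields 32 blocks, which split evenly over the 4 cores, so the block size is
# returned unchanged; uneven splits may instead be halved to fit the ping-pong
# budget.
def _refine_block_demo() -> int:
    block_n = refine_block([128, 128], [4096, 4096], dim=0)
    assert block_n == 128  # 32 blocks % 4 cores == 0 -> no refinement
    return block_n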
def ColumnParallelLinear__init__(self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,):
# Divide the weight matrix along the last dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
super(ColumnParallelLinear,self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
self.scale_n = 1
if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
gcd_value = quant_config.weight_block_size[0]
import math
if hasattr(self, "output_sizes"):
            # For a merged ColumnParallelLinear, the GCD must be computed from
            # the shape of each partition linear.
output_size_no_merge = self.output_partition_sizes
block_values = [o % quant_config.weight_block_size[0] for o in output_size_no_merge]
is_gcd_recompute = sum(block_values)
if is_gcd_recompute:
block_values.append(quant_config.weight_block_size[0])
gcd_value = math.gcd(*block_values)
            # Notice:
            # For unaligned partition weights this path may still need validation.
            # For DeepSeek, the merged column linears (only in the MLP & MoE) all
            # have partition weights of identical shape.
            # For Qwen3, the QKV column linear has partition weights of different
            # shapes, but the current Qwen3 sharding scheme does not affect
            # gcd_value, so no recomputation is needed and this branch is not
            # taken for now.
if hasattr(self, "output_sizes") and len(output_size_no_merge) == 2 and output_size_no_merge[0] == output_size_no_merge[1]:
#only refine mlp w13
gcd_value = refine_block([gcd_value, quant_config.weight_block_size[1]], [output_size_no_merge[0], input_size])
self.scale_n =self.scale_n * quant_config.weight_block_size[0] // gcd_value
else:
            # For a non-merged ColumnParallelLinear, compute the GCD from the
            # current shape only.
output_size_no_merge = self.output_size_per_partition
is_gcd_recompute = output_size_no_merge % quant_config.weight_block_size[0]
if is_gcd_recompute:
gcd_value = math.gcd(output_size_no_merge % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
self.scale_n =self.scale_n * quant_config.weight_block_size[0] // gcd_value
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
scale_n = self.scale_n,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def ColumnParallelLinear_weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[-1]])
            # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
param.load_column_parallel_weight(loaded_weight=loaded_weight)
class MergedColumnParallelLinear(ColumnParallelLinear):
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[-1]])
                # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]:
loaded_weight = loaded_weight.unsqueeze(1) #[k,1,n]
loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
                # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0] // self.scale_n, weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
def RowParallelLinear__init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
super(RowParallelLinear, self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
    # Divide the weight matrix along the last dimension (re-derived here with
    # the same disable_tp handling as above).
    self.tp_rank = (get_tensor_model_parallel_rank()
                    if not disable_tp else 0)
    self.tp_size = (get_tensor_model_parallel_world_size()
                    if not disable_tp else 1)
    self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None
    self.scale_k = 1  # quant_block_k (128) is divided by scale_k; e.g. scale_k = 2 gives quant_block_k = 64
self.scale_n = 1
self.scale_n_slice = 1
if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
gcd_value = quant_config.weight_block_size[1]
import math
if self.input_size_per_partition % quant_config.weight_block_size[1]:
gcd_value = math.gcd(self.input_size_per_partition % quant_config.weight_block_size[1], quant_config.weight_block_size[1])
self.scale_k =self.scale_k * quant_config.weight_block_size[1] // gcd_value
if output_size % quant_config.weight_block_size[0]:
gcd_value = math.gcd(output_size % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value
self.scale_n_slice = output_size // gcd_value
        # Example: N = 576, block = 128. Expanding the scales along n needs two
        # pieces of information: 1. how many copies to make (scale_n);
        # 2. how many entries remain valid after slicing (scale_n_slice).
        # scale = [s0,s1,s2,s3,s4], copied scale_n=2 times:
        # scale = [s0,s0,s1,s1,s2,s2,s3,s3,s4,s4], sliced to scale_n_slice=9
        # entries => [s0,s0,s1,s1,s2,s2,s3,s3,s4]
        # (an illustrative expansion sketch follows this function)
if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
gcd_value = quant_config.group_size
import math
if self.input_size_per_partition % quant_config.group_size:
gcd_value = math.gcd(self.input_size_per_partition % quant_config.group_size, quant_config.group_size)
self.quant_method.scale_k = self.quant_method.scale_k * quant_config.group_size // gcd_value
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
scale_k = self.scale_k,
scale_n = self.scale_n,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if not reduce_results and (bias and not skip_bias_add):
        raise ValueError("When reduce_results=False, adding bias to the "
                         "results can lead to incorrect results")
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def RowParallelLinear_weight_loader_v2_vacc(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_k > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,2,0).reshape([loaded_weight.shape[1], -1])
#[1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
#[1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n,k]
elif self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
# broadcast scale TODO: broadcast zero
if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
loaded_weight = loaded_weight.unsqueeze(1) #[k,1,n]
loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
            # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]
param.load_row_parallel_weight(loaded_weight=loaded_weight)
class UnquantizedLinearMethod():
"""Linear method without quantization."""
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if bias is not None:
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
parallel_embedding_output = None
if memory_recycler is not None:
if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == x.size(0):
parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER
return torch.mm(x.view(-1, x.shape[-1]), layer.weight.transpose(1,0), out=parallel_embedding_output).view(*(x.shape[:-1]), layer.weight.shape[0])
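# Illustrative sketch of the buffer reuse above (hypothetical shapes, plain CPU
# tensors instead of the recycler's EMBEDDING_OUT_BUFFER): torch.mm writes the
# GEMM result directly into a preallocated output buffer when the buffer's
# leading dimension matches the current batch.
def _mm_with_reused_buffer_demo() -> torch.Tensor:
    x = torch.randn(4, 16)
    weight = torch.randn(32, 16)          # [out_features, in_features]
    out_buffer = torch.empty(4, 32)       # reused across steps with the same batch size
    return torch.mm(x, weight.transpose(1, 0), out=out_buffer)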

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that compute logits from hidden_stats."""
from typing import Optional
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.platforms import current_platform
def DumpLogits(logits):
import os
import time
VLLM_VACC_DUMP_LOGITS = os.getenv('VLLM_VACC_DUMP_LOGITS')
if VLLM_VACC_DUMP_LOGITS:
logit_arr = logits.cpu().to(torch.float32).numpy()
timestamp = time.time()
# print("timestamp:", timestamp)
if not os.path.exists(VLLM_VACC_DUMP_LOGITS):
os.makedirs(VLLM_VACC_DUMP_LOGITS)
logit_path = os.path.join(VLLM_VACC_DUMP_LOGITS, f'logit_{timestamp}.bin')
summary_path = os.path.join(VLLM_VACC_DUMP_LOGITS, 'summary.txt')
with open(summary_path, 'a') as f:
f.write(f'{logit_path}\n')
# print("save file:", logit_path)
logit_arr.tofile(logit_path)
class LogitsProcessor(nn.Module):
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
from vllm.distributed.parallel_state import get_tp_group
total_size_all_gather_input = hidden_states.size(0) * lm_head.weight.shape[0] * hidden_states.element_size() * get_tp_group().world_size
        # Fuse matmul and all_gather if the hidden size is <= 7168 and divisible
        # by 32, according to the DSP developers:
        # if hidden_states.size(1) <= 7168 and hidden_states.size(1) % 32 == 0:
        # Fuse matmul and all_gather if the total size of the all_gather inputs
        # does not exceed 4 MB.
if total_size_all_gather_input <= 4194304:
try:
from torch_vacc.vacc.custom_ops import fused_matmul_allgather
# from vllm.distributed.parallel_state import get_tp_group
logits = fused_matmul_allgather(hidden_states, lm_head.weight.T,
get_tp_group().world_size,
get_tp_group().rank_in_group,
get_tp_group().group_id,
get_tp_group().rank_device_infos)
# ensure there is PP last stage at front
if (hasattr(current_platform, 'supports_v1') and current_platform.supports_v1(current_platform)) or get_tp_group().rank_in_group == 0:
seq, hidden_dims = hidden_states.shape
logits = logits.movedim(0, 1)
logits = logits.reshape(seq, -1)
logits = logits[..., :self.org_vocab_size]
if get_tp_group().rank_in_group == 0:
DumpLogits(logits)
else:
logits = None
return logits
except Exception as e:
print("Fused Matmul with AllGather run Fail, now use unfused. " ,e)
# Get the logits for the next tokens.
logits = lm_head.quant_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
#print("quant method:", lm_head, lm_head.quant_method, embedding_bias, logits.shape)
# Gather logits for TP
logits = self._gather_logits(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]
return logits
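# Illustrative sketch of the 4 MB budget check in _get_logits above
# (hypothetical sizes): tokens * vocab_per_rank * element_size * tp_world_size
# must stay at or below 4194304 bytes for the fused matmul+all_gather path.
def _fused_allgather_budget_demo() -> bool:
    num_tokens, vocab_per_rank, element_size, tp_world_size = 8, 32000, 2, 4
    total_bytes = num_tokens * vocab_per_rank * element_size * tp_world_size
    return total_bytes <= 4 * 1024 * 1024  # 2,048,000 bytes -> fused path allowed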

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
import torch
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolerActivation, get_pooling_params
class ClassifierPooler(Pooler):
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]
if pooled_data.dtype != self.head_dtype:
pooled_data = pooled_data.to(self.head_dtype)
if self.classifier is not None:
pooled_data = self.classifier(pooled_data)
# pooled_data shape: [batchsize, num_labels]
if self.logit_bias is not None:
pooled_data -= self.logit_bias
pooling_params = get_pooling_params(pooling_metadata)
flags = [p.activation for p in pooling_params]
if len(set(flags)) == 1:
scores = self.act_fn(pooled_data) if flags[0] else pooled_data
else:
scores = [
self.act_fn(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]
# scores shape: [batchsize, num_labels]
return scores
class PoolerNormalize(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
return torch.vacc.l2_norm(pooled_data, epsilon=1e-12)

View File

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Literal, Type, get_args
QuantizationMethods = Literal[
# "aqlm",
"awq",
"deepspeedfp",
"tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
"modelopt_fp4",
"bitblas",
"gguf",
"gptq_marlin_24",
"gptq_marlin",
"gptq_bitblas",
"awq_marlin",
"gptq",
"compressed-tensors",
"bitsandbytes",
"hqq",
"experts_int8",
"ipex",
"quark",
"moe_wna16",
"torchao",
"auto-round",
"rtn",
"inc",
"mxfp4",
"petit_nvfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

View File

@@ -0,0 +1,615 @@
import functools
import importlib.util
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, expert_weight_is_col_major,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
validate_fp8_block_shape)
# from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
# all_close_1d, apply_fp8_linear, convert_to_channelwise,
# cutlass_block_fp8_supported, cutlass_fp8_supported,
# normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
# requantize_with_max_scale)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, is_layer_skipped)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.utils import has_deep_gemm
logger = init_logger(__name__)
# has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def Fp8LinearMethod__init(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.quant_config.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
if self.block_quant:
self.block_size = self.quant_config.weight_block_size
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.scale_k = 1
self.scale_n = 1
self.scale_n_prefill = 1 # only for fp8 moe
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)
class Fp8LinearMethod(LinearMethodBase):
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
if self.block_quant:
scale_n = extra_weight_attrs.get("scale_n")
scale_k = extra_weight_attrs.get("scale_k")
if scale_n is not None:
self.scale_n = scale_n
if scale_k is not None:
self.scale_k = scale_k
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None
block_n, block_k = (
self.quant_config.weight_block_size[0] // self.scale_n ,
self.quant_config.weight_block_size[1] // self.scale_k ,
)
# Required by row parallel
if (tp_size > 1
and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition
== tp_size) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if not self.block_quant:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
else:
assert self.quant_config.activation_scheme == "dynamic"
scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
# TODO(rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
if current_platform.is_fp8_fnuz():
weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv)
else:
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
if isinstance(layer, QKVParallelLinear):
# NOTE: for QKVParallelLinear the number of weight_scale rows must be
# divisible by 8 (DSP requirement), so repeat the scale rows until it is.
shape = weight_scale_inv.shape[0]
repeat = 1
while shape % 8 != 0:
repeat *= 2
shape = shape * 2
weight_scale_inv = torch.repeat_interleave(weight_scale_inv, repeats=repeat, dim=0)
# weight = self._maybe_pad_weight(weight)
# if self.block_quant:
# maybe_post_process_fp8_weight_block(
# layer, self.cutlass_block_fp8_supported)
# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False)
return
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
# Note: lazy import to avoid triton import error.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
if self.block_quant:
assert self.quant_config.weight_block_size is not None
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=[layer.weight.shape[0] // layer.weight_scale_inv.shape[0], layer.weight.shape[1] // layer.weight_scale_inv.shape[1]],
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)
# return apply_fp8_linear(
# input=x,
# weight=layer.weight,
# weight_scale=layer.weight_scale,
# input_scale=layer.input_scale,
# bias=bias,
# cutlass_fp8_supported=self.cutlass_fp8_supported,
# # Default to using per_token quantization if cutlass is supported
# use_per_token_if_dynamic=self.cutlass_fp8_supported)
def Fp8MoEMethod_init_(self, quant_config: Fp8Config, layer: torch.nn.Module):
self.layer = layer
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
self.flashinfer_moe_backend = None
self.scale_k = 1
self.scale_n = 1
self.scale_n_prefill = 1
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm() or current_platform.is_vacc:
self.use_marlin = False
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm():
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
else:
logger.warning_once(
"DeepGemm not supported on the current platform.")
# Check for CutlassBlockScaledGroupedGemm support.
self.allow_cutlass_block_scaled_grouped_gemm = False
if not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(100)):
logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
)
self.allow_cutlass_block_scaled_grouped_gemm = True
else:
logger.warning_once(
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8MoEMethod(FusedMoEMethodBase):
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.quant_config.weight_block_size is not None
scale_n = extra_weight_attrs.get("scale_n")
scale_n_prefill = extra_weight_attrs.get("scale_n_prefill")
scale_k = extra_weight_attrs.get("scale_k")
if scale_n is not None:
self.scale_n = scale_n
if scale_k is not None:
self.scale_k = scale_k
if scale_n_prefill is not None:
self.scale_n_prefill = scale_n_prefill
if self.quant_config is not None and self.quant_config.weight_block_size is not None:
self.gcd_value = self.quant_config.weight_block_size[0]
output_size_no_merge = intermediate_size_per_partition
# assert isinstance(output_size_no_merge, int), f"merged output size should be an int, value is: {output_size_no_merge}"
if output_size_no_merge % self.quant_config.weight_block_size[0]:
import math
gcd_value = math.gcd(output_size_no_merge % self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[0])
self.scale_n = self.scale_n * self.quant_config.weight_block_size[0] // gcd_value
self.scale_n_prefill = self.scale_n_prefill * self.quant_config.weight_block_size[0] // gcd_value
if hidden_size % self.quant_config.weight_block_size[1]:
import math
gcd_value = math.gcd(hidden_size % self.quant_config.weight_block_size[1], self.quant_config.weight_block_size[1])
self.scale_k = self.scale_k * self.quant_config.weight_block_size[1] // gcd_value
# self.scale_k = self.scale_n
# print('output_size_no_merge', output_size_no_merge)
# Split the output across cores by block_size:
# output_size_no_merge = 384
#   block_size = 128: 384 = 3x128, can only be split as 3 cores x 128
#   block_size = 16:  384 = 24x16 = 8 cores x (3x16), can use all 8 cores
# output_size_no_merge = 512
#   block_size = 128: 512 = 4x128, can only be split as 4 cores x 128
#   block_size = 64:  512 = 8x64, can use all 8 cores x 64
# output_size_no_merge = 768
#   block_size = 128: 768 = 6x128, can only be split as 6 cores x 128
#   block_size = 32:  768 = 8x(3x32), can use all 8 cores x (3x32)
core_num = 8
min_block_size = 4
block_size_tmp = self.quant_config.weight_block_size[0] // self.scale_n
if output_size_no_merge > block_size_tmp and \
output_size_no_merge % block_size_tmp == 0 and \
output_size_no_merge // block_size_tmp < core_num and \
output_size_no_merge % core_num == 0:
core_num_old = output_size_no_merge // block_size_tmp
import math
gcd_value = math.gcd(core_num, core_num_old)
new_scale = core_num // gcd_value
if block_size_tmp // new_scale >= min_block_size:
self.scale_n = new_scale * self.scale_n
#print("moe scale n is:", self.scale_n, self.scale_k, intermediate_size_per_partition)
tp_size = get_tensor_model_parallel_world_size()
if self.scale_n != self.scale_n_prefill:
block_n_prefill = self.quant_config.weight_block_size[0] // self.scale_n_prefill
block_n, block_k = (
self.quant_config.weight_block_size[0] // self.scale_n,
self.quant_config.weight_block_size[1] // self.scale_k,
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}.")
if (tp_size > 1
and hidden_size % block_k != 0):
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if not self.block_quant:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
else:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) //
block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_n - 1) // block_n,
dtype=torch.float32,
),
requires_grad=False,
)
if self.scale_n != self.scale_n_prefill:
w13_weight_scale_prefill = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n_prefill - 1) //
block_n_prefill),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale_prefill = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_n_prefill - 1) // block_n_prefill,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv_prefill", w13_weight_scale_prefill)
layer.register_parameter("w2_weight_scale_inv_prefill", w2_weight_scale_prefill)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.
value} if self.block_quant else
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if self.scale_n != self.scale_n_prefill:
set_weight_attrs(w13_weight_scale_prefill, extra_weight_attrs)
set_weight_attrs(w2_weight_scale_prefill, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def moe_fp8_apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import FusedMoE
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
try:
from torch_vacc.vacc.custom_ops import fused_experts
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
experts_output = None
if memory_recycler is not None:
# remove MOE_EXPERT_OUT_BUFFER
# experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
decode_with_batch=layer.is_decode and x.shape[0] > 1,
output_opt=experts_output
)
except Exception as e:
print(f"vacc fused_expert run fail, now using unfused ops: {e}")
from torch_vacc.vacc.custom_ops_cpu import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
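# Illustrative sketch (not used by the code above): how Fp8LinearMethod.apply
# recovers the effective block size from the weight and weight_scale_inv
# shapes. The shapes below are hypothetical; only the arithmetic matters.
def _block_size_from_shapes_example():
weight = torch.empty(4096, 2048, dtype=torch.float8_e4m3fn)  # [N, K]
weight_scale_inv = torch.empty(32, 16, dtype=torch.float32)  # [N // block_n, K // block_k]
block_n = weight.shape[0] // weight_scale_inv.shape[0]  # 4096 // 32 = 128
block_k = weight.shape[1] // weight_scale_inv.shape[1]  # 2048 // 16 = 128
return [block_n, block_k]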

View File

@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig as GPTQConfigOrig
from vllm.model_executor.layers.quantization.gptq import ExllamaState
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT
import math
def GPTQLinearMethod__init(self, quant_config: GPTQConfigOrig):
self.quant_config = quant_config
self.scale_k = 1
self.split_num = 4
def int32_to_int4(s0, axis = -2):
# Flatten to shape [1, n] first.
# Each int32 is split into 8 int4 values, giving a result of shape [8, n].
# x32 (int32) => 32 bits => 8 x 4-bit fields x4[0..7]
# x32 bits 31-28 => x4[7]
# x32 bits 27-24 => x4[6]
# ...
# x32 bits 3-0 => x4[0]
# x32[index=0] => x4[7,6,5,4,3,2,1,0]
# Converting a 4-bit field to its real value:
# not two's complement, but an offset of 8:
# 1111 => 15 => 15 - 8 = 7
# 0110 => 6 => 6 - 8 = -2
# Example 0x6ACB372B (stored little-endian in memory as bytes 2B 37 CB 6A):
# nibbles 2,11,3,7,12,11,6,10 => (-8) => -6,3, -5,-1, 4,3, -2,2; swapping the
# two nibbles within each byte gives 3, -6, -1, -5, 3, 4, 2, -2
# int4: 3, -6, -1, -5, 3, 4, 2, -2
s = s0.view(torch.uint32)
all = []
for i in range(8):
x = 15 << (i*4)
# s2 = torch.bitwise_and(x,s)
s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
s3 = s2 / (2 ** (i*4))
s4 = s3.to(torch.int32)
# Two's-complement decoding gives the wrong result here:
# s4[s4 > 7] = s4[s4 > 7] - 16
# Subtracting 8 directly is correct; value range: -8 to 7
s4 = s4 - 8
all.append(s4.reshape(1,*s4.shape))
all = torch.concatenate(all, 0)
if axis == -2 or axis == 0:
# 8,K//8,N => K//8,8,N => K,N
all = all.transpose(-2,0).reshape(-1,all.shape[-1]).contiguous()
else:
# 8,N,K//8 => N,K//8,8 => N,K
all = all.permute(1,2,0).reshape(all.shape[-2],-1).contiguous()
return all
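# Illustrative sketch of the nibble convention described above (hypothetical
# value, helper not used elsewhere): each int32 holds eight 4-bit fields, LSB
# first, each stored with an offset of 8.
def _int4_unpack_example():
packed = 0x76543210
nibbles = [(packed >> (4 * i)) & 0xF for i in range(8)]  # [0, 1, 2, 3, 4, 5, 6, 7]
signed = [n - 8 for n in nibbles]  # [-8, -7, -6, -5, -4, -3, -2, -1]
return signed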
def dequant_weight(qw, scales, group_size = 128):
N = qw.shape[1]
int4_to_int32_axis = -2
if TRANSPOSE_GPTQ_WEIGHT:
N = qw.shape[0]
int4_to_int32_axis = -1
qweight = int32_to_int4(qw,int4_to_int32_axis).to(torch.float16) #int32 => 8 int4 +> fp16
if TRANSPOSE_GPTQ_WEIGHT:
scales = scales.T.contiguous()
qweight = qweight.T.contiguous()
scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)  # expand scales by group_size: every group_size values share one scale
# print('qweight', qweight.shape, qweight.dtype)
# print('scale', scales.shape, scales.dtype)
dequant_weight = qweight * scales #dequant
return dequant_weight
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
class GPTQLinearMethod(LinearMethodBase):
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
# if input_size_per_partition % self.quant_config.group_size != 0:
# raise ValueError(
# "The input size is not aligned with the quantized "
# "weight shape. This can be caused by too large "
# "tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
# For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED
else:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
layer.exllama_state = exllama_state
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
# self.quant_config.weight_bits == 4
if TRANSPOSE_GPTQ_WEIGHT:
layer.qzeros = Parameter(layer.qzeros.data.T.contiguous(), requires_grad=False)
layer.qweight = Parameter(layer.qweight.data.T.contiguous(), requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data.T.contiguous(), requires_grad=False)
else:
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], ) # M,N
reshaped_x = x.reshape(-1, x.shape[-1])
# print(f"~~~~ start dequant")
# import time
# start_quant_time = time.time()
# weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(), self.quant_config.group_size // self.scale_k).to(layer.qweight.device)
# end_quant_time = time.time()
# print(f"~~~~ dequant time: {end_quant_time - start_quant_time}")
# if torch.distributed.get_rank() == 0:
# print(f"~~~~ weight shape: {weight.shape}, dtype: {weight.dtype}")
# output = torch.matmul(reshaped_x, weight)
# print("entering GPTQLinearMethod apply, reshaped_x shape:", reshaped_x.shape, "reshaped_x stride", reshaped_x.stride(), "input_tensor", x.shape, "qweight shape:", layer.qweight.shape, "scales shape:", layer.scales.shape)
output = torch.vacc.w4a8_block_int4_matmul(
reshaped_x,
layer.qweight.transpose(-1, -2),
layer.scales.transpose(-1, -2),
[1, self.quant_config.group_size // self.scale_k],
)
# print("exiting GPTQLinearMethod apply, output shape:", output.shape)
# end_gemm_time = time.time()
# if torch.distributed.get_rank() == 0:
# print(f"~~~~ gemm time: {end_gemm_time - end_quant_time}")
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)

View File

@@ -0,0 +1,372 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
# [num_experts, N, K//8], int32 ==> [num_experts, N, K], int4 ==> [num_experts, N//8, K], int32
def repack_quant_moe_weight_old(original_packed_tensor):
num_experts = original_packed_tensor.shape[0]
N = original_packed_tensor.shape[1]
K = original_packed_tensor.shape[2] * 8
if original_packed_tensor.dtype != torch.int32:
raise ValueError("data type of input tensor should be int32")
if N % 8 != 0:
raise ValueError("N of input tensor should be divisible by 8")
# --- 1. Unpack: expand the int32 tensor into a logical int4 tensor ---
# Create a temporary tensor to hold all unpacked int4 values.
# Use torch.uint8 as the temporary int4 storage since PyTorch has no native int4 dtype.
unpacked_int4_tensor = torch.zeros(
(num_experts, N, K),
dtype=torch.uint8,
device=original_packed_tensor.device
)
mask = 0b1111
for i in range(8):
# Extract the value of the current int4 from each int32 block:
# right-shifting by (i * 4) bits moves the i-th 4-bit field to the least
# significant bits, and ANDing with the mask keeps just those 4 bits.
extracted_int4s = (original_packed_tensor >> (i * 4)) & mask
# Place the extracted int4 values at the right positions in unpacked_int4_tensor:
# the slice `i::8` means "start at index `i` and fill every 8th position".
unpacked_int4_tensor[:, :, i::8] = extracted_int4s
# --- 2. Repack: pack the logical int4 tensor into a new int32 tensor ---
new_packed_tensor = torch.zeros(
(num_experts, N//8, K),
dtype=torch.int32,
device=original_packed_tensor.device
)
for i in range(8):
# Take the int4 sequence to pack next from the unpacked tensor, slicing `i::8` along the N dimension.
current_int4_segment = unpacked_int4_tensor[:, i::8, :]
# Cast this int4 sequence to int32 (since we pack into int32), shift it to its
# slot in the new int32 block, and OR it into new_packed_tensor.
new_packed_tensor |= (current_int4_segment.to(torch.int32) << (i * 4))
return new_packed_tensor
def repack_quant_moe_weight(original_packed_tensor):
if original_packed_tensor.dtype != torch.int32:
raise ValueError("data type of input tensor should be int32")
num_experts, N, K_packed = original_packed_tensor.shape
K = K_packed * 8
if N % 8 != 0:
raise ValueError("N of input tensor should be divisible by 8")
new_packed_tensor = torch.zeros((num_experts, N // 8, K),
dtype=torch.int32,
device=original_packed_tensor.device)
for i in range(8):
source_slice = original_packed_tensor[:, i::8, :]
for j in range(8):
unpacked_strip = (source_slice >> (j * 4)) & 0b1111
new_packed_tensor[:, :, j::8] |= (unpacked_strip.to(torch.int32) << (i * 4))
return new_packed_tensor
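# Illustrative consistency check (hypothetical shapes, not called anywhere):
# the vectorized repack above is intended to produce exactly the same packing
# as the reference repack_quant_moe_weight_old.
def _repack_consistency_check():
x = torch.randint(0, 2**31 - 1, (2, 16, 4), dtype=torch.int32)
return torch.equal(repack_quant_moe_weight(x), repack_quant_moe_weight_old(x))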
class MoeWNA16Method(FusedMoEMethodBase):
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
self.moe = layer
layer.quant_config = self.quant_config
bit8_pack_factor = self.quant_config.bit8_pack_factor
bit32_pack_factor = 32 // self.quant_config.weight_bits
group_size = self.quant_config.group_size
group_size_div_factor = 1
group_size_w13 = self.quant_config.group_size
group_size_div_factor_w13 = 1
group_size_w2 = self.quant_config.group_size
group_size_div_factor_w2 = 1
# make intermediate_size and hidden_size divisible by group_size
# we reduce the group size to ensure that
# and we would repeat the loaded_weight later
while intermediate_size_per_partition % group_size or \
hidden_size % group_size:
group_size = group_size // 2
group_size_div_factor *= 2
assert group_size >= 32
layer.group_size = group_size
layer.group_size_div_factor = group_size_div_factor
while intermediate_size_per_partition % group_size_w2:
group_size_w2 = group_size_w2 // 2
group_size_div_factor_w2 *= 2
assert group_size_w2 >= 32
layer.w2_block_size = group_size_w2
layer.group_size_div_factor_w2 = group_size_div_factor_w2
while hidden_size % group_size_w13:
group_size_w13 = group_size_w13 // 2
group_size_div_factor_w13 *= 2
assert group_size_w13 >= 32
layer.w13_block_size = group_size_w13
layer.group_size_div_factor_w13 = group_size_div_factor_w13
strategy = FusedMoeWeightScaleSupported.GROUP.value
extra_weight_attrs.update({
"quant_method": strategy,
"is_transposed": False
})
assert 'weight_loader' in extra_weight_attrs
weight_loader = extra_weight_attrs['weight_loader']
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
layer, weight_loader)
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
# w13_qweight = torch.nn.Parameter(torch.empty(
# num_experts,
# 2 * intermediate_size_per_partition,
# hidden_size // bit8_pack_factor,
# dtype=torch.uint8),
# requires_grad=False)
w13_qweight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // bit32_pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
# w2_qweight = torch.nn.Parameter(torch.empty(
# num_experts,
# hidden_size,
# intermediate_size_per_partition // bit8_pack_factor,
# dtype=torch.uint8),
# requires_grad=False)
w2_qweight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // bit32_pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
w13_scales = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // group_size_w13,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = torch.nn.Parameter(torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // group_size_w2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
if self.quant_config.has_zp:
w13_qzeros = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition // bit8_pack_factor,
hidden_size // group_size,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = torch.nn.Parameter(torch.zeros(
num_experts,
hidden_size // bit8_pack_factor,
intermediate_size_per_partition // group_size,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
if self.quant_config.linear_quant_method == "gptq":
# some param are unused, but we need to init them in order to
# load weights
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
if not self.quant_config.has_zp:
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
for key in invalid_param_keys:
param = torch.nn.Parameter(torch.empty((0, ),
dtype=torch.int32),
requires_grad=False)
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)
@staticmethod
def get_weight_loader(layer, weight_loader):
def convert_awq_tensor(tensor, tensor_type):
# convert awq qweight/qzeros to a standard format (assume int4)
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
# qzeros: (k // group_size, n // pack_factor_bit32) ->
# (n // pack_factor_bit8, k // group_size)
# pack_factor_bit32 = 32 // weight_bits
# pack_factor_bit8 = 8 // weight_bits
# 0. suppose origin shape (a, b), dtype int32
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
size0 = tensor.size(0)
tensor = tensor.view(torch.uint8)
# 2. unpack to uint4 (only when weight_bits == 4)
# shape (a, 4 * b) -> (a, 4 * b, 2)
shifter = torch.tensor([0, 4],
dtype=torch.uint8,
device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
# 3. change order, see
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
# shape -> (a, 4 * b * pack_factor_bit8)
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
tensor = tensor.view(size0, -1)
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
tensor = tensor.T.contiguous()
# 5. repack (only when weight_bits == 4)
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
# qzeros shape -> (4 * b, a)
if tensor_type == "qweight":
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
elif tensor_type == "qzeros":
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
return tensor
def convert_gptq_int4_qzeros(tensor):
tensor = tensor.view(torch.uint8)
shifter = torch.tensor([0, 4],
dtype=torch.uint8,
device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
tensor = tensor + 1
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
return tensor
def moe_wna16_weight_loader(param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
return_success: bool = False):
if "g_idx" in weight_name:
return False if return_success else None
if not layer.quant_config.has_zp and "qzeros" in weight_name:
return False if return_success else None
device = get_tp_group().device
tp_rank = get_tensor_model_parallel_rank()
loaded_weight = loaded_weight.to(device)
shard_size = layer.intermediate_size_per_partition
# convert gptq and awq weight to a standard format
if layer.quant_config.linear_quant_method == "awq":
assert layer.quant_config.weight_bits == 4
if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight,
"qweight")
elif "zeros" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
else:
loaded_weight = loaded_weight.T
elif layer.quant_config.linear_quant_method == "gptq":
assert layer.quant_config.weight_bits in [4, 8]
if "weight" in weight_name:
# loaded_weight = loaded_weight.T.contiguous().view(
# torch.uint8)
loaded_weight = loaded_weight.T.contiguous()
elif "zeros" in weight_name:
# add 1 to gptq qzeros to align with awq
loaded_weight = loaded_weight.view(torch.uint8)
if layer.quant_config.weight_bits == 4:
loaded_weight = convert_gptq_int4_qzeros(
loaded_weight).T
else:
loaded_weight = loaded_weight.T + 1
else:
# loaded_weight = loaded_weight.T
loaded_weight = loaded_weight.T.contiguous()
# repeat the qzeros/scales to fit the new group size
if layer.group_size_div_factor_w13 > 1 and \
("qzeros" in weight_name or "scales" in weight_name) and \
shard_id in ("w1", "w3"):
loaded_weight = loaded_weight.repeat_interleave(
layer.group_size_div_factor_w13, 1)
elif layer.group_size_div_factor_w2 > 1 and \
("qzeros" in weight_name or "scales" in weight_name) and \
shard_id == "w2":
loaded_weight = loaded_weight.repeat_interleave(
layer.group_size_div_factor_w2, 1)
elif layer.group_size_div_factor > 1 and \
("qzeros" in weight_name or "scales" in weight_name):
loaded_weight = loaded_weight.repeat_interleave(
layer.group_size_div_factor, 1)
if "w13_qzeros" in weight_name:
tensor = loaded_weight.view(layer.tp_size, -1,
loaded_weight.size(1))[tp_rank]
if shard_id == "w1":
param.data[expert_id, :shard_size // 2] = tensor
else:
param.data[expert_id, shard_size // 2:] = tensor
return True if return_success else None
elif "w2_qzeros" in weight_name:
param.data[expert_id] = loaded_weight.view(
loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
return True if return_success else None
else:
# Delegate to the original loader, passing return_success
return weight_loader(param,
loaded_weight,
weight_name,
shard_id,
expert_id,
return_success=return_success)
return moe_wna16_weight_loader
def process_weights_after_loading(self, layer):
dev_w2 = layer.w2_qweight.device
# torch.Size([128, 2048, 24]), torch.int32, strides: (49152, 24, 1)
# ======>
# torch.Size([128, 256, 192]), torch.int32, strides: (49152, 1, 256)
layer.w2_qweight = torch.nn.Parameter(repack_quant_moe_weight(layer.w2_qweight.cpu()).transpose(-1, -2).contiguous().transpose(-1, -2).to(device=dev_w2), requires_grad=False)
# torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 3, 1)
# ======>
# torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 1, 2048)
layer.w2_scales = torch.nn.Parameter(layer.w2_scales.transpose(-1, -2).contiguous().transpose(-1, -2), requires_grad=False)
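# Illustrative sketch of the layout trick used above (hypothetical shapes):
# transpose -> contiguous -> transpose keeps the logical shape but leaves the
# data column-major along the last two dims, matching the strides noted in the
# comments of process_weights_after_loading().
def _stride_trick_example():
t = torch.zeros(128, 2048, 3, dtype=torch.float16)  # strides (6144, 3, 1)
u = t.transpose(-1, -2).contiguous().transpose(-1, -2)
return u.shape, u.stride()  # (128, 2048, 3), (6144, 1, 2048)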

View File

@@ -0,0 +1,39 @@
import torch
from typing import List, Optional
def _apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = True,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
assert len(block_size) == 2, "only support dim2 block now"
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
try:
from torch_vacc.vacc.custom_ops import w8a8_block_fp8_linear
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
mla_oproj_output = None
if memory_recycler is not None:
os1, os2 = memory_recycler.MLA_OPROJ_OUT_BUFFER.shape
if os1 == input_2d.size(0) and os2 == weight.size(0):
mla_oproj_output = memory_recycler.MLA_OPROJ_OUT_BUFFER
output = w8a8_block_fp8_linear(input_2d, weight, input_scale, weight_scale, block_size, output=mla_oproj_output)
except Exception as e:
print("vacc fuse fp8 matmul run fail:", e, " , now use unfused ops")
from torch_vacc.vacc.custom_ops_cpu import w8a8_block_fp8_linear
output = w8a8_block_fp8_linear(input_2d, weight, input_scale, weight_scale, block_size)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
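# Rough, unfused reference of what the fused path above computes (a sketch,
# ignoring activation quantization; the helper name is hypothetical and not
# part of this module's API): the [N // block_n, K // block_k] scale grid is
# broadcast back over the fp8 weight before a plain matmul.
def _reference_block_fp8_linear(input_2d, weight, weight_scale, block_size):
block_n, block_k = block_size
scale = weight_scale.repeat_interleave(block_n, dim=0)[:weight.shape[0]]
scale = scale.repeat_interleave(block_k, dim=1)[:, :weight.shape[1]]
w = weight.to(torch.float32) * scale.to(torch.float32)
return input_2d.to(torch.float32) @ w.t()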

View File

@@ -0,0 +1,9 @@
from .rotary_embedding import (
RotaryEmbedding_init_vacc,
RotaryEmbedding_forward_vacc,
ScalingRotaryEmbedding_forward_vacc,
_compute_inv_freq_vacc,
_deepseek_compute_cos_sin_cache_vacc,
_yarn_compute_cos_sin_cache_vacc,
_compute_cos_sin_cache_vacc
)

View File

@@ -0,0 +1,101 @@
import itertools
from typing import Optional, Union
import numpy as np
import torch
from transformers import PretrainedConfig
class MRotaryEmbedding:
@classmethod
def _qwen3vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
for _ in range(t)]
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
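# Illustrative note: for a text-only prompt (no image or video grids) the loop
# above never runs, llm_positions degenerates to arange(len(input_tokens))
# broadcast to 3 rows, and mrope_position_delta is 0.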

View File

@@ -0,0 +1,203 @@
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
# from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb
# from vllm.model_executor.layers.rotary_embedding import _yarn_find_correction_range, _yarn_linear_ramp_mask
from vllm.model_executor.layers.rotary_embedding.common import yarn_find_correction_range as _yarn_find_correction_range
from vllm.model_executor.layers.rotary_embedding.common import yarn_linear_ramp_mask as _yarn_linear_ramp_mask
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from ...ops.mrope_op import get_sin_cos_mrope
def RotaryEmbedding_init_vacc(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super(CustomOp, self).__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
# cache = self._compute_cos_sin_cache()
cos, sin = self._compute_cos_sin_cache()
cos = cos.to(dtype)
sin = sin.to(dtype)
self.register_buffer("cos_cache", cos, persistent=False)
self.register_buffer("sin_cache", sin, persistent=False)
# cache = cache.to(dtype)
# self.cos_sin_cache: torch.Tensor
# self.register_buffer("cos_sin_cache", cache, persistent=False)
def RotaryEmbedding_forward_vacc(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-vacc implementation of forward()."""
if offsets is not None:
positions = positions + offsets
num_tokens = positions.numel()
# positions = positions.flatten()
# num_tokens = positions.shape[0]
# cos_sin = self.cos_sin_cache.index_select(0, positions)
# cos_sin = self.cos_sin_cache[positions]
# cos, sin = cos_sin.chunk(2, dim=-1)
if isinstance(self, MRotaryEmbedding):
# get mrope sin/cos
cos, sin = get_sin_cos_mrope(self, positions)
num_tokens = num_tokens//3
else:
positions = positions.flatten()
cos = self.cos_cache[positions]
sin = self.sin_cache[positions]
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
mode = "neox"
if not self.is_neox_style:
mode = "gptj"
query_rot, key_rot=torch.vacc.RotaryPosEmbedding(query_rot, key_rot, cos, sin, 0, mode)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
return query, key
def ScalingRotaryEmbedding_forward_vacc(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
# self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
# positions.device)
# cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
# if offsets is not None else positions]
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
# cos_sin = self.cos_sin_cache[positions]
# cos, sin = cos_sin.chunk(2, dim=-1)
cos = self.cos_cache[positions]
sin = self.sin_cache[positions]
# TODO: to be removed (require odsp support)
# if self.is_neox_style:
# # NOTE(woosuk): Here we assume that the positions tensor has the
# # shape [batch_size, seq_len].
# cos = cos.repeat(1, 1, 2).unsqueeze(-2)
# sin = sin.repeat(1, 1, 2).unsqueeze(-2)
# else:
# cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
# sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
# rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
mode = "neox" if self.is_neox_style else "gptj"
# query_rot = query_rot * cos + rotate_fn(query_rot) * sin
# key_rot = key_rot * cos + rotate_fn(key_rot) * sin
query_rot, key_rot=torch.vacc.RotaryPosEmbedding(query_rot, key_rot, cos, sin, 0, mode)
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key
def _compute_inv_freq_vacc(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device=current_platform.device_type) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _deepseek_compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device=current_platform.device_type,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
# NOTE: odsp-friendly layout:
# separate cos/sin caches guarantee that cos/sin
# always have a contiguous layout along dim[-1]
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
return cos, sin
# cache = torch.cat((cos, sin), dim=-1)
# return cache
def _yarn_compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device=current_platform.device_type,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
return cos, sin
# cache = torch.cat((cos, sin), dim=-1)
# return cache
def _compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings,
device=current_platform.device_type,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
# NOTE: odsp-friendly layout:
# separate cos/sin caches guarantee that cos/sin
# always have a contiguous layout along dim[-1]
cos = freqs.cos()
sin = freqs.sin()
return cos, sin
# cache = torch.cat((cos, sin), dim=-1)
# return cache
# import vllm.model_executor.layers.rotary_embedding as rotary_embedding
# rotary_embedding.RotaryEmbedding.forward_vacc=RotaryEmbedding_forward_vacc
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_inv_freq=_compute_inv_freq_vacc
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_cos_sin_cache=_compute_cos_sin_cache_vacc
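# Illustrative sketch (hypothetical tensors): the separate cos/sin buffers used
# above are equivalent to the upstream fused cos_sin_cache that concatenates
# them and chunks on lookup; splitting them just keeps dim[-1] contiguous.
def _split_cache_equivalence_example():
cos = torch.randn(1024, 32)
sin = torch.randn(1024, 32)
fused = torch.cat((cos, sin), dim=-1)  # upstream cos_sin_cache layout
positions = torch.tensor([0, 5, 7])
cos_f, sin_f = fused[positions].chunk(2, dim=-1)
return torch.equal(cos_f, cos[positions]) and torch.equal(sin_f, sin[positions])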

View File

@@ -0,0 +1,542 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union
import msgspec
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.model_executor.layers.sampler import (SamplerOutput,
_apply_min_tokens_penalty,
_apply_top_k_top_p,
_apply_min_p,
_sample,
SampleResultArgsType,
get_logprobs,
_build_sampler_output,
SampleReturnType,
SampleResultsDictType,
SampleMetadataType,
MultinomialSamplesType,
_modify_greedy_probs_inplace,
_top_k_top_p_multinomial_with_flashinfer,
_multinomial,
get_pythonized_sample_results,
)
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER as use_ds3_sampler
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER_OP as use_ds3_sampler_op
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
# yapf: enable
else:
flashinfer_top_k_top_p_sampling = None
class SamplerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This data structure implements methods, so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs: List[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional[torch.Tensor] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# Holds either (1) the pythonized sampler result (single-step scheduling)
# or (2) what will be arguments for later deferred pythonization of the
# sampler result (multi-step scheduling)
deferred_sample_results_args: Optional[SampleResultArgsType] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncLLMEngine to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None
# On-device tensor containing the sampled token embeddings (embeddings
# corresponding to the sampled token ids). Used when prompt embeddings are
# specified in lieu of prompt token ids or text.
sampled_token_embeds: Optional[torch.Tensor] = None
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states: Optional[torch.Tensor] = None
# Time taken in the forward pass for this across all workers
model_forward_time: Optional[float] = None
# Time taken in the model execute function. This will include model forward,
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None
def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
return self.outputs[idx]
def __setitem__(self, idx: int, value):
self.outputs[idx] = value
def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
return iter(self.outputs)
def __len__(self):
return len(self.outputs)
def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs
def __repr__(self) -> str:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
else self.sampled_token_probs.shape)
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
self.sampled_token_ids.shape)
return (
f"SamplerOutput(outputs={self.outputs}, "
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr},")
def Sampler_forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
# print(f'Sampler_forward all_greedy={all_greedy}')
# Prepare sampling tensors with pinned memory to avoid blocking.
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)
assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p
is_greedy = (len(sampling_metadata.categorized_sample_indices[SamplingType.GREEDY]) == logits.shape[0])
is_random = (len(sampling_metadata.categorized_sample_indices[SamplingType.RANDOM]) == logits.shape[0])
is_random_seed = (len(sampling_metadata.categorized_sample_indices[SamplingType.RANDOM_SEED]) == logits.shape[0])
max_n_in_batch = sampling_metadata.seq_groups[0].sampling_params.n
generator = sampling_metadata.seq_groups[0].generator
min_tokens = sampling_metadata.seq_groups[0].sampling_params.min_tokens
# print("use_ds3_sampler ", use_ds3_sampler)
if use_ds3_sampler == True and (is_greedy == True or ((is_random == True or is_random_seed == True) \
and do_penalties == False \
and flashinfer_top_k_top_p_sampling is None \
and min_tokens <= 0 \
and do_min_p == False \
and max_n_in_batch == 1 \
# and self._should_modify_greedy_probs_inplace == False
# and self.include_gpu_probs_tensor == False
)):
sampling_type = SamplingType.GREEDY
sample_metadata: SampleMetadataType = {}
multinomial_samples: MultinomialSamplesType = {}
greedy_samples: Optional[torch.Tensor] = None
multinomial_out: Optional[torch.Tensor] = None
vacc_device = logits.device
# Create output tensor for sampled token ids.
if self.include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.full((logits.shape[0], 1),
VLLM_INVALID_TOKEN_ID,
dtype=torch.long,
device=vacc_device)
probs_out = torch.empty_like(logits)
logprobs_out = torch.empty_like(logits)
else:
probs_out = None
logprobs_out = None
sampled_token_ids_tensor = None
if is_greedy:
greedy_samples, _ = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 0)
sampling_type = SamplingType.GREEDY
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor = greedy_samples.unsqueeze(-1).to(torch.long)
if probs_out is not None:
# probs_out = torch.softmax(logits.to(torch.float), dim=-1, dtype=torch.float).to(logits)
probs_out = torch.softmax(logits, dim=-1)
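# When requested, overwrite the greedy rows with a one-hot distribution at the
# sampled token so downstream consumers (e.g. speculative decoding) see a
# deterministic probability for greedy sampling.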
if self._should_modify_greedy_probs_inplace:
sample_indices = (sampling_metadata.categorized_sample_indices[SamplingType.GREEDY]).long()
probs_out[sample_indices, :] = 0
probs_out[sample_indices, greedy_samples] = 1.0
elif is_random and do_top_p_top_k:
if use_ds3_sampler_op:
logits = logits.to(torch.float)
multinomial_out, probs_out = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 2)
multinomial_out = multinomial_out.view(-1, max_n_in_batch)
else:
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))
logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps, sampling_tensors.top_ks)
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_out = probs
# multinomial_out = torch.multinomial(probs, 1)
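# Exponential-race trick: dividing the probabilities by i.i.d. Exponential(1)
# noise and taking the argmax draws from the same categorical distribution as
# torch.multinomial, without invoking the multinomial kernel on device.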
q = torch.empty_like(probs)
q.exponential_()
multinomial_out = probs.div_(q).argmax(dim=1).view(-1, max_n_in_batch)
sampling_type = SamplingType.RANDOM
elif is_random_seed and generator is not None and do_top_p_top_k:
if use_ds3_sampler_op:
# print("is_random_seed ", is_random_seed)
logits = logits.to(torch.float)
multinomial_out, probs_out = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 1, generator)
multinomial_out = multinomial_out.view(-1, max_n_in_batch)
else:
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))
logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps, sampling_tensors.top_ks).to(torch.float)
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_out = probs
# torch.manual_seed(sampling_metadata.seq_groups[0].sampling_params.seed)
# multinomial_out = torch.multinomial(probs, 1)
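# Same exponential-race sampling as above, but the noise is drawn from the
# request's generator so seeded runs stay reproducible.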
q = torch.empty_like(probs)
q.exponential_(generator=generator)
multinomial_out = probs.div_(q).argmax(dim=1).view(-1, max_n_in_batch)
sampling_type = SamplingType.RANDOM_SEED
multinomial_samples[sampling_type] = multinomial_out
if sampled_token_ids_tensor is not None:
if sampling_type != SamplingType.GREEDY:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor = multinomial_samples[sampling_type].to(torch.long)
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
t: []
for t in SamplingType
}
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
sample_results_dict: SampleResultsDictType = {}
maybe_deferred_args = SampleResultArgsType(
sampling_metadata=sampling_metadata,
sample_metadata=sample_metadata,
multinomial_samples=multinomial_samples,
greedy_samples=greedy_samples,
# beam_search_logprobs=None,
sample_results_dict=sample_results_dict)
if not sampling_metadata.skip_sampler_cpu_output:
# GPU<->CPU sync happens here.
# This also converts the sampler output to a Python object.
# Return Pythonized sampler result & sampled token ids
maybe_deferred_sample_results = get_pythonized_sample_results(maybe_deferred_args)
maybe_sampled_tokens_tensor = sampled_token_ids_tensor
else:
# Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = (
maybe_deferred_args,
sampled_token_ids_tensor,
)
if self.include_gpu_probs_tensor:
on_device_tensors = (probs_out, logprobs_out, maybe_sampled_tokens_tensor)
else:
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
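# In this fast path the raw logits are reused as the logprobs tensor passed to
# get_logprobs (no explicit log_softmax is applied here).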
logprobs = logits
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
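# Fallback: the fused ds3 sampler path was not taken, so run the generic
# sampling pipeline below.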
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Apply presence and frequency penalties.
# if do_penalties:
# logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
# sampling_tensors.output_tokens,
# sampling_tensors.presence_penalties.to(logits.device),
# sampling_tensors.frequency_penalties.to(logits.device),
# sampling_tensors.repetition_penalties.to(logits.device))
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps.to(logits.device),
sampling_tensors.top_ks.to(logits.device))
if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
def rejection_forward(
self,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
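# Fused rejection sampling for speculative decoding on the vacc backend: draft
# tokens are accepted or resampled against the target distribution in a single
# kernel. The trailing flag appears to select the unseeded (1) vs seeded (0)
# path, with the per-sequence generator forwarded in the seeded case.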
if seeded_seqs is None:
out, index = torch.vacc.rejection_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, 1)
else:
out, index = torch.vacc.rejection_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, 0, seeded_seqs[0])
return out
class Sampler(nn.Module):
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
_, vocab_size = logits.shape
# Prepare sampling tensors with pinned memory to avoid blocking.
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)
assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Apply presence and frequency penalties.
if do_penalties:
logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
sampling_tensors.output_tokens,
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
# print("tempratures is:", temperatures)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1).to(logits.device))
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
def _apply_top_k_top_p_vacc(
logits: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1).to(probs_sum.device)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = torch.empty_like(logits_sort).scatter_(dim=-1,
index=logits_idx,
src=logits_sort)
return logits
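# Worked example (hypothetical values): with logits [[1.0, 2.0, 3.0, 0.5]],
# p=[0.9] and k=[2], top-k keeps the two largest logits (2.0, 3.0) and top-p
# then keeps only the ascending cumulative tail above 1 - p, so the result is
# [-inf, 2.0, 3.0, -inf]; masked positions get probability 0 after softmax.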

View File

@@ -0,0 +1,69 @@
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.platforms import current_platform
from typing import Tuple
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
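# Hypothetical example: with an org-vocab shard [0, 4), no padding and an
# added-vocab shard [8, 10), token id 9 maps to local row 4 + (9 - 8) = 5,
# while ids outside both ranges map to row 0 and are flagged in the returned
# mask so their embeddings can be zeroed out.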
def VocabParallelEmbedding_forward(self, input_):
try:
if self.tp_size > 1:
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
parallel_embedding_output = None
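# Reuse the recycler's preallocated output buffer (cast to the weight dtype)
# when its leading dimension matches the current batch, avoiding a fresh
# allocation on every forward.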
if memory_recycler is not None:
if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == input_.size(0):
parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER.to(self.weight.dtype)
output_parallel = torch.vacc.parallel_embedding(
input_,
self.weight,
self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index,
output=parallel_embedding_output
)
else:
raise ValueError("non-TP execution is not supported")
except Exception:
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
#TODO: fuse all_reduce
return tensor_model_parallel_all_reduce(output_parallel)