init
0  vllm_vacc/vllm/model_executor/layers/__init__.py  Normal file
Binary files not shown.
8  vllm_vacc/vllm/model_executor/layers/activation.py  Normal file
@@ -0,0 +1,8 @@
import torch


def SiluAndMul_forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
    return torch.vacc.swiglu(x)


def QuickGELU_forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    return x * torch.sigmoid(1.702 * x)
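The helpers above follow the ClassName_forward_vacc naming pattern used throughout vllm_vacc, which suggests they are bound onto vLLM's activation CustomOps elsewhere in the plugin. A minimal sketch of such a binding, assuming a plain attribute assignment and that the package is importable as vllm_vacc.vllm.model_executor.layers.activation (the real registration mechanism is not part of this commit):

# Hypothetical illustration only; the actual hook-up lives outside this diff.
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul

from vllm_vacc.vllm.model_executor.layers.activation import (
    QuickGELU_forward_vacc, SiluAndMul_forward_vacc)

# Route both ops to the VACC implementations by attaching a forward_vacc method.
SiluAndMul.forward_vacc = SiluAndMul_forward_vacc
QuickGELU.forward_vacc = QuickGELU_forward_vacc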
Binary files not shown.
82  vllm_vacc/vllm/model_executor/layers/fused_moe/fused_moe.py  Normal file
@@ -0,0 +1,82 @@
import torch
from typing import Optional, Tuple


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    dtype = hidden_states.dtype

    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, -1)
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)
    return topk_weights, selected_experts


def grouped_topk_with_itype(hidden_states: torch.Tensor,
                            gating_output: torch.Tensor,
                            topk: int,
                            renormalize: bool,
                            num_expert_group: int = 0,
                            topk_group: int = 0,
                            scoring_func: str = "softmax",
                            e_score_correction_bias: Optional[torch.Tensor] = None):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    try:
        from torch_vacc.vacc.custom_ops import fused_moe_preprocess
        return fused_moe_preprocess(gating_output, e_score_correction_bias)
    except Exception as e:
        print(f"fused group topk run fail, now use unfused group topk: {e}")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (scores.view(num_token, num_expert_group,
                                    -1).topk(2, dim=-1)[0].sum(dim=-1))
    else:
        group_scores = scores.view(num_token, num_expert_group,
                                   -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(),
                                    float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores,
                                            k=topk,
                                            dim=-1,
                                            sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int32)
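fused_topk above is plain PyTorch, so it can be exercised on CPU; a small usage sketch with arbitrary shapes:

import torch

# 4 tokens, hidden size 8, 6 experts, pick the top-2 experts per token.
hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)
gating_output = torch.randn(4, 6, dtype=torch.bfloat16)

weights, experts = fused_topk(hidden_states, gating_output, topk=2, renormalize=True)
assert weights.shape == (4, 2) and experts.shape == (4, 2)
# With renormalize=True each row of weights sums to 1 (up to bfloat16 rounding).
assert torch.allclose(weights.float().sum(dim=-1), torch.ones(4), atol=1e-2)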
715  vllm_vacc/vllm/model_executor/layers/fused_moe/layer.py  Normal file
@@ -0,0 +1,715 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op

if current_platform.is_cuda_alike():
    from .fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore
if current_platform.is_tpu():
    # the iterative moe implementation is used until the moe_pallas is fixed
    from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
    fused_moe_pallas = None  # type: ignore
logger = init_logger(__name__)


class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"

def FusedMoE_init_(
    self,
    num_experts: int,  # Global number of experts
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = False,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    tp_size: Optional[int] = None,
    ep_size: Optional[int] = None,
    dp_size: Optional[int] = None,
    prefix: str = "",
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    num_redundant_experts: int = 0,
    has_bias: bool = False,
    is_sequence_parallel=False,
):
    from vllm.model_executor.layers.fused_moe import FusedMoE
    super(FusedMoE, self).__init__()

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.params_dtype = params_dtype

    vllm_config = get_current_vllm_config()

    # FIXME (varun): We should have a better way of inferring the activation
    # datatype. This works for now as the tensor datatype entering the MoE
    # operation is typically unquantized (i.e. float16/bfloat16).
    if vllm_config.model_config is not None:
        moe_in_dtype = vllm_config.model_config.dtype
    else:
        # TODO (bnell): This is a hack to get test_mixtral_moe to work
        # since model_config is not set in the pytest test.
        moe_in_dtype = params_dtype

    tp_size_ = (tp_size if tp_size is not None else
                get_tensor_model_parallel_world_size())
    dp_size_ = (dp_size
                if dp_size is not None else get_dp_group().world_size)

    self.moe_parallel_config: FusedMoEParallelConfig = (
        FusedMoEParallelConfig.make(
            tp_size_=tp_size_,
            dp_size_=dp_size_,
            vllm_parallel_config=vllm_config.parallel_config))

    self.global_num_experts = num_experts + num_redundant_experts

    # For smuggling this layer into the fused moe custom op
    compilation_config = vllm_config.compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError("Duplicate layer name: {}".format(prefix))
    compilation_config.static_forward_context[prefix] = self
    self.layer_name = prefix

    self.enable_eplb = enable_eplb
    self.expert_load_view: Optional[torch.Tensor] = None
    self.logical_to_physical_map: Optional[torch.Tensor] = None
    self.logical_replica_count: Optional[torch.Tensor] = None

    # Determine expert maps
    if self.use_ep:
        if self.enable_eplb:
            assert self.global_num_experts % self.ep_size == 0, \
                "EPLB currently only supports even distribution of " \
                "experts across ranks."
        else:
            assert num_redundant_experts == 0, \
                "Redundant experts are only supported with EPLB."
        self.local_num_experts, self.expert_map = determine_expert_map(
            ep_size=self.ep_size,
            ep_rank=self.ep_rank,
            global_num_experts=self.global_num_experts)
    else:
        self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                   None)

    self.top_k = top_k

    assert intermediate_size % self.tp_size == 0
    self.hidden_size = hidden_size
    self.intermediate_size_per_partition = intermediate_size // self.tp_size
    self.reduce_results = reduce_results
    self.renormalize = renormalize
    self.use_grouped_topk = use_grouped_topk
    if self.use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
    self.num_expert_group = num_expert_group
    self.topk_group = topk_group
    self.custom_routing_function = custom_routing_function
    self.scoring_func = scoring_func
    self.e_score_correction_bias = e_score_correction_bias
    self.apply_router_weight_on_input = apply_router_weight_on_input
    self.activation = activation

    if self.scoring_func != "softmax" and not self.use_grouped_topk:
        raise ValueError("Only softmax scoring function is supported for "
                         "non-grouped topk.")

    moe = FusedMoEConfig(
        num_experts=self.global_num_experts,
        experts_per_token=top_k,
        hidden_dim=hidden_size,
        num_local_experts=self.local_num_experts,
        moe_parallel_config=self.moe_parallel_config,
        in_dtype=moe_in_dtype,
        max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
        has_bias=has_bias,
    )
    self.moe_config = moe
    self.quant_config = quant_config

    # Note: get_quant_method will look at the layer's local_num_experts
    # for heuristic purposes, so it must be initialized first.
    quant_method: Optional[QuantizeMethodBase] = None
    quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
                    else quant_config.get_quant_method(self, prefix))

    assert quant_method is not None
    assert isinstance(quant_method, FusedMoEMethodBase)
    self.quant_method = quant_method

    if self.enable_eplb:
        from vllm.model_executor.layers.quantization.fp8 import (
            Fp8MoEMethod)
        if not isinstance(quant_method, Fp8MoEMethod):
            # TODO: Add support for additional quantization methods.
            # The implementation for other quantization methods does not
            # contain essential differences, but the current quant API
            # design causes duplicated work when extending to new
            # quantization methods, so I'm leaving it for now.
            # If you plan to add support for more quantization methods,
            # please refer to the implementation in `Fp8MoEMethod`.
            raise NotImplementedError("EPLB is only supported for FP8 "
                                      "quantization for now.")

    moe_quant_params = {
        "num_experts": self.local_num_experts,
        "hidden_size": hidden_size,
        "intermediate_size_per_partition":
        self.intermediate_size_per_partition,
        "params_dtype": params_dtype,
        "weight_loader": self.weight_loader,
    }
    # need full intermediate size pre-sharding for WNA16 act order
    if (self.quant_method.__class__.__name__
            in ("GPTQMarlinMoEMethod",
                "CompressedTensorsWNA16MarlinMoEMethod",
                "CompressedTensorsWNA16MoEMethod")):
        moe_quant_params["intermediate_size_full"] = intermediate_size

    self.quant_method.create_weights(layer=self, **moe_quant_params)

    # self.scale_n = self.quant_method.scale_n
    # self.scale_k = self.quant_method.scale_k
    self.scale_n = 1
    self.scale_k = 1
    self.scale_n_prefill = 1
    if hasattr(self.quant_method, "scale_n") and hasattr(self.quant_method, "scale_k"):
        self.scale_n = self.quant_method.scale_n
        self.scale_k = self.quant_method.scale_k
    if hasattr(self.quant_method, "scale_n_prefill"):
        self.scale_n_prefill = self.quant_method.scale_n_prefill

    # Chunked all2all staging tensor
    self.batched_hidden_states: Optional[torch.Tensor] = None
    self.batched_router_logits: Optional[torch.Tensor] = None
    if (self.moe_parallel_config.use_pplx_kernels
            or self.moe_parallel_config.use_deepep_ll_kernels):
        self.batched_hidden_states = torch.zeros(
            (moe.max_num_tokens, self.hidden_size),
            dtype=moe.in_dtype,
            device=torch.cuda.current_device())

        # Note here we use `num_experts` which is logical expert count
        self.batched_router_logits = torch.zeros(
            (moe.max_num_tokens, num_experts),
            dtype=moe.in_dtype,
            device=torch.cuda.current_device())

class FusedMoE(torch.nn.Module):

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int,
                  load_full: bool = False, expert_id=0):
        # print("w13 shape is:", expert_data.shape, loaded_weight.shape)
        if self.scale_n > 1 and len(loaded_weight.shape) == 2 and torch.finfo(loaded_weight.dtype).bits > 8:
            n_v, k_v = loaded_weight.shape
            loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v)  # [1, n, 1, k]
            if self.scale_n_prefill != self.scale_n and hasattr(self, 'w13_weight_scale_inv_prefill'):
                loaded_weight0 = loaded_weight.repeat(self.scale_n_prefill, 1, self.scale_k, 1).permute(1, 0, 3, 2).reshape([n_v * self.scale_n_prefill, k_v * self.scale_k])
                shard_size = self.w13_weight_scale_inv_prefill.data[expert_id].shape[shard_dim] // 2
                loaded_weight0 = loaded_weight0.narrow(shard_dim, shard_size * tp_rank, shard_size)

                if shard_id == "w1":
                    self.w13_weight_scale_inv_prefill.data[expert_id, :loaded_weight0.shape[0], :loaded_weight0.shape[1]] = loaded_weight0
                elif shard_id == "w3":
                    self.w13_weight_scale_inv_prefill.data[expert_id, -loaded_weight0.shape[0]:, -loaded_weight0.shape[1]:] = loaded_weight0
                else:
                    raise ValueError(f'error shard_id: {shard_id}')

            loaded_weight = loaded_weight.repeat(self.scale_n, 1, self.scale_k, 1).permute(1, 0, 3, 2).reshape([n_v * self.scale_n, k_v * self.scale_k])
            # print("w13 repeat shape is:", expert_data.shape, loaded_weight.shape)

        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                                 shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)

        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False,
                 expert_id=0):
        # print("w2 shape is:", expert_data.shape, loaded_weight.shape)
        if self.scale_n > 1 and len(loaded_weight.shape) == 2 and torch.finfo(loaded_weight.dtype).bits > 8:
            n_v, k_v = loaded_weight.shape
            loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v)  # [1, n, 1, k]
            if self.scale_n_prefill != self.scale_n and hasattr(self, 'w2_weight_scale_inv_prefill'):
                loaded_weight0 = loaded_weight.repeat(self.scale_k, 1, self.scale_n_prefill, 1).permute(1, 0, 3, 2).reshape([n_v * self.scale_k, k_v * self.scale_n_prefill])
                shard_size = self.w2_weight_scale_inv_prefill.data[expert_id].shape[shard_dim]
                if not load_full:
                    loaded_weight0 = loaded_weight0.narrow(shard_dim,
                                                           shard_size * tp_rank,
                                                           shard_size)
                self.w2_weight_scale_inv_prefill.data[expert_id] = loaded_weight0

            # print("loaded_weight:", loaded_weight.shape)
            loaded_weight = loaded_weight.repeat(self.scale_k, 1, self.scale_n, 1).permute(1, 0, 3, 2).reshape([n_v * self.scale_k, k_v * self.scale_n])
            # print("w2 repeat shape is:", expert_data.shape, loaded_weight.shape)
        # if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      weight_name: str,
                      shard_id: str,
                      expert_id: int, return_success=True) -> None:

        if self.quant_config and self.quant_config.get_name() == "mxfp4":
            # (FIXME) for gpt-oss all experts are combined
            if "bias" in weight_name:
                dim1 = loaded_weight.shape[1]
                param.data[:, :dim1].copy_(loaded_weight)
            else:
                dim1 = loaded_weight.shape[1]
                dim2 = loaded_weight.shape[2]
                param.data[:, :dim1, :dim2].copy_(loaded_weight)
            return True if return_success else None

        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return False if return_success else None

        quant_method_name = self.quant_method.__class__.__name__
        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        if self.quant_method.__class__.__name__ in (
                "CompressedTensorsWNA16MarlinMoEMethod",
                "CompressedTensorsWNA16MoEMethod"):
            loaded_weight = loaded_weight.t().contiguous()

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        # WEIGHT_SCALE_SUPPORTED = [
        #     e.value for e in FusedMoeWeightScaleSupported
        # ]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size_per_partition is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()
            param.data.copy_(loaded_weight)
            return True if return_success else None

        # Case for BitsAndBytes
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        if use_bitsandbytes_4bit:
            shard_dim = 0

            expert_data = param.data[expert_id]
            if shard_id == "w2":
                expert_data.copy_(loaded_weight)
            elif shard_id in ("w1", "w3"):
                # BNB inflight quantization has already sharded the weights
                full_load = True
                self._load_w13(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full=full_load,
                )
            return True if return_success else None

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size_per_partition is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            shard_dim = int(not shard_dim)

        full_load = len(loaded_weight.shape) == 3
        if full_load:
            shard_dim += 1

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if shard_id in ["w1", "w3"]:
                final_shape[1] *= 2
            final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        expert_data = param.data if full_load else param.data[expert_id]
        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")

            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return True if return_success else None

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(shard_dim=0,
                             shard_id=shard_id,
                             loaded_weight=loaded_weight,
                             expert_data=expert_data,
                             tp_rank=self.tp_rank)
            return True if return_success else None

        # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
        if "ModelOpt" in quant_method_name:
            # Determine per-tensor weight scale patterns based on variant
            # Use the dedicated method instead of brittle string matching
            uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
            )

            # Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
            # weights scales.
            # Input scales are always per-tensor.
            # Weight scales: FP4 uses "weight_scale_2" and FP8 uses
            # "weight_scale" for per-tensor scales.
            is_per_tensor = ("weight_scale_2" in weight_name
                             if uses_weight_scale_2 else "weight_scale"
                             in weight_name) or "input_scale" in weight_name
            if is_per_tensor:
                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
                return True if return_success else None

            # If the weight is w13_weight_scale and w13_weight_scales are
            # combined into single loaded_weight, call
            # _load_combined_w13_weight_scale() to load it.
            # This is checked by comparing the hidden_out dims of the
            # loaded_weight and the param.
            if "w13_weight_scale" in weight_name:
                loaded_weight_hidden_out = loaded_weight.shape[-2]
                param_hidden_out = param.data.shape[-2] * self.tp_size
                if loaded_weight_hidden_out == param_hidden_out:
                    self._load_combined_w13_weight_scale(
                        shard_dim=shard_dim,
                        loaded_weight=loaded_weight,
                        param=param,
                        tp_rank=self.tp_rank,
                    )
                    return True if return_success else None

            # For other weights, call _load_model_weight_or_group_weight_scale()
            # to load it.
            if "weight" in weight_name:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
                return True if return_success else None

        # Case weight scales, zero_points and offset
        if ("scale" in weight_name or "zero" in weight_name
                or "offset" in weight_name):
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
            elif quant_method in [
                    FusedMoeWeightScaleSupported.GROUP.value,
                    FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full_w2=getattr(param, "load_full_w2", False),
                    expert_id=expert_id)
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                WEIGHT_SCALE_SUPPORTED = [
                    e.value for e in FusedMoeWeightScaleSupported
                ]
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return True if return_success else None

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return True if return_success else None

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=self.tp_rank)
            return True if return_success else None

        return False if return_success else None

    def _load_model_weight_or_group_weight_scale(self,
                                                 shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.Tensor,
                                                 tp_rank: int,
                                                 load_full_w2: bool = False,
                                                 expert_id: int = 0):
        """
        Load grouped weight scales for group quantization or model weights
            :param shard_dim: dimension to shard
            :param expert_data: parameter for a particular expert
            :param shard_id: either w1, w2, or w3
            :param loaded_weight: checkpoint weight to load into the param
            :param tp_rank: tensor parallel rank
            :param load_full_w2: whether or not the w2 loaded should be sharded.
        """
        if shard_id == "w2":
            # In the case where we have actorder/g_idx, we do not partition the
            # w2 scales, as indicated by `load_full` argument, for all tp cases
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank,
                          load_full=load_full_w2,
                          expert_id=expert_id)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank,
                           expert_id=expert_id)

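The loaders above combine block-scale replication with the stock vLLM sharding logic. The core w13 sharding arithmetic on its own, as a tensor-only sketch with toy sizes (not the class API):

import torch

# w13 stacks gate (w1) and up (w3) along the output dim; every TP rank keeps
# one slice of each half, and _load_w13 copies a checkpoint shard into it.
intermediate, hidden, tp_size, tp_rank = 8, 4, 2, 1
expert_w13 = torch.empty(2 * intermediate // tp_size, hidden)  # per-rank parameter
loaded_w1 = torch.randn(intermediate, hidden)                  # full "w1" checkpoint shard

shard_dim = 0
shard_size = expert_w13.shape[shard_dim] // 2                  # half of the local w13 rows
local_w1 = loaded_w1.narrow(shard_dim, shard_size * tp_rank, shard_size)
expert_w13.narrow(shard_dim, 0, shard_size).copy_(local_w1)    # w1 goes into the first half
# a "w3" shard would instead target expert_w13.narrow(shard_dim, shard_size, shard_size)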
class UnquantizedFusedMoEMethod():

    def forward_vacc(
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:

        hidden_states = x
        orig_shape = hidden_states.shape
        hidden_size = hidden_states.shape[-1]
        num_tokens = hidden_states.shape[:-1].numel()
        dtype = hidden_states.dtype
        intermediate_size = layer.w2_weight.shape[-1]
        gating_output = router_logits

        hidden_states = hidden_states.view(num_tokens, hidden_size)
        gating_output = gating_output.view(num_tokens, global_num_experts)
        topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(dtype)

        if expert_map is not None:
            topk_ids = expert_map[topk_ids]

        final_hidden_states = torch.zeros_like(hidden_states)
        sel_experts = topk_ids.shape[1]
        if hidden_states.shape[0] == 1:
            for id in range(sel_experts):
                expert_idx = topk_ids[0][id]
                expert_w1 = layer.w13_weight[expert_idx].contiguous()
                expert_w2 = layer.w2_weight[expert_idx].contiguous()

                expert_weights = topk_weights[0][id].to(hidden_states.dtype)

                x = hidden_states
                x = F.linear(x, expert_w1)
                gate = F.silu(x[:, :intermediate_size])
                x = x[:, intermediate_size:] * gate
                x = F.linear(x, expert_w2)

                current_hidden_states = x * expert_weights
                current_hidden_states = current_hidden_states.to(x.dtype)
                final_hidden_states += current_hidden_states
        else:
            for expert_idx in range(global_num_experts):
                # topk_ids [tokens, experts] => sample:[10, 8]
                # expert_mask [tokens, experts] => sample:[10, 8]
                expert_mask = topk_ids == expert_idx

                idx = torch.where(expert_mask)[0]
                if idx.numel() == 0:
                    continue

                expert_w1 = layer.w13_weight[expert_idx].contiguous()
                expert_w2 = layer.w2_weight[expert_idx].contiguous()

                # [seq, experts]
                expert_weights = (
                    topk_weights.masked_select(expert_mask)
                    .unsqueeze(1)
                    .to(hidden_states.dtype)
                )

                x = hidden_states[idx]
                x = F.linear(x, expert_w1)
                gate = F.silu(x[:, :intermediate_size])
                x = x[:, intermediate_size:] * gate
                x = F.linear(x, expert_w2)

                current_hidden_states = x * expert_weights
                current_hidden_states = current_hidden_states.to(x.dtype)
                # final_hidden_states[idx] += current_hidden_states
                final_hidden_states.index_add_(0, idx, current_hidden_states)
        return final_hidden_states.view(orig_shape)  # type: ignore

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        custom_forward = self.forward
        if x.device.type == "vacc":
            custom_forward = UnquantizedFusedMoEMethod.forward_vacc

        return custom_forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input)
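forward_vacc's batched branch is plain PyTorch, so its routing math can be checked on CPU. A toy-sized sketch of the same per-expert accumulation (shapes chosen arbitrarily, not tied to the layer class):

import torch
import torch.nn.functional as F

tokens, hidden, inter, experts, top_k = 5, 4, 6, 3, 2
hs = torch.randn(tokens, hidden)
logits = torch.randn(tokens, experts)
w13 = torch.randn(experts, 2 * inter, hidden)  # gate/up stacked per expert
w2 = torch.randn(experts, hidden, inter)

weights, ids = logits.softmax(-1).topk(top_k, dim=-1)
weights = weights / weights.sum(-1, keepdim=True)

out = torch.zeros_like(hs)
for e in range(experts):
    mask = ids == e                      # which (token, slot) pairs chose expert e
    rows = torch.where(mask)[0]
    if rows.numel() == 0:
        continue
    x = F.linear(hs[rows], w13[e])
    x = F.silu(x[:, :inter]) * x[:, inter:]          # SwiGLU, as in forward_vacc
    x = F.linear(x, w2[e])
    out.index_add_(0, rows, x * weights.masked_select(mask).unsqueeze(1))
assert out.shape == hs.shape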
32  vllm_vacc/vllm/model_executor/layers/layernorm.py  Normal file
@@ -0,0 +1,32 @@
from typing import Optional, Tuple, Union
import torch


def RMSNorm_forward_vacc(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # if residual is not None:
    #     x = x + residual
    #     residual = x
    hidden_size = x.shape[-1]
    if hidden_size != self.hidden_size:
        raise ValueError("Expected hidden_size to be "
                         f"{self.hidden_size}, but found: {hidden_size}")
    if self.variance_size_override is None:
        x_var = x
    else:
        if hidden_size < self.variance_size_override:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{self.variance_size_override}, but found: {hidden_size}")
        x_var = x[:, :, :self.variance_size_override]
    # x_var = x_var.unsqueeze(0)
    # out = torch.vacc.rms_norm(x_var, self.weight, self.variance_epsilon)
    # if residual is None:
    #     return out.squeeze(0)
    # else:
    #     return out.squeeze(0), residual
    out = torch.vacc.fused_residual_rmsnorm(x_var, self.weight, residual, self.variance_epsilon, x_var, residual)
    return out
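torch.vacc.fused_residual_rmsnorm is a VACC custom op; the commented-out path above suggests it fuses the optional residual add with RMSNorm. A plain-PyTorch reference of that presumed behaviour (an assumption, not the op's documented contract):

import torch

def reference_fused_residual_rmsnorm(x, weight, residual, eps):
    # Presumed semantics: add the residual (if any), then RMS-normalize and scale.
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual)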
465  vllm_vacc/vllm/model_executor/layers/linear.py  Normal file
@@ -0,0 +1,465 @@
import itertools
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           BlockQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               WEIGHT_LOADER_V2_SUPPORTED,
                                               LinearBase,
                                               RowParallelLinear)

def ReplicatedLinear__init__(self,
                             input_size: int,
                             output_size: int,
                             bias: bool = True,
                             skip_bias_add: bool = False,
                             params_dtype: Optional[torch.dtype] = None,
                             quant_config: Optional[QuantizationConfig] = None,
                             prefix: str = ""):
    super(ReplicatedLinear, self).__init__(input_size,
                                           output_size,
                                           skip_bias_add,
                                           params_dtype,
                                           quant_config,
                                           prefix=prefix)

    # All the linear layer supports quant method.
    assert self.quant_method is not None

    # quant_block_k (128) is divided by scale_k; e.g. scale_k = 2 means the
    # effective quant_block_k becomes 64.
    self.scale_k = 1
    self.scale_k_slice = 1
    self.scale_n = 1
    self.scale_n_slice = 1
    if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
        gcd_value = quant_config.weight_block_size[1]
        import math
        if input_size % quant_config.weight_block_size[1]:
            gcd_value = math.gcd(input_size % quant_config.weight_block_size[1], quant_config.weight_block_size[1])
            self.scale_k = self.scale_k * quant_config.weight_block_size[1] // gcd_value
            self.scale_k_slice = input_size // gcd_value
        if output_size % quant_config.weight_block_size[0]:
            gcd_value = math.gcd(output_size % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
            self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value
            self.scale_n_slice = output_size // gcd_value

    self.quant_method.create_weights(self,
                                     self.input_size, [self.output_size],
                                     self.input_size,
                                     self.output_size,
                                     self.params_dtype,
                                     scale_k=self.scale_k,
                                     scale_n=self.scale_n,
                                     weight_loader=self.weight_loader)

    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=self.params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)


def ReplicatedLinear_weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    # If the weight on disk does not have a shape, give it one
    # (such as scales for AutoFp8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)

    if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
        if self.scale_k > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1, 2, 0).reshape([loaded_weight.shape[1], -1])[:, :self.scale_k_slice]
            # [1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]

        if self.scale_n > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1, 0, 2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
            # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]

    assert param.size() == loaded_weight.size(), f'{param.size()}, {loaded_weight.size()}'
    param.data.copy_(loaded_weight)

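The expand/permute/reshape pattern above replicates each per-block scale scale_k times along k (and scale_n times along n) so a checkpoint quantized with one block size can be reused with a finer one. A standalone numeric check of the k-direction case, with toy sizes:

import torch

scale = torch.tensor([[1., 2., 3.],
                      [4., 5., 6.]])            # [n=2, k=3] per-block scales
scale_k = 2
expanded = scale.unsqueeze(0)                    # [1, n, k]
expanded = expanded.expand(scale_k, 2, 3).permute(1, 2, 0).reshape(2, -1)
# Every scale now appears scale_k times along k.
assert torch.equal(expanded, torch.tensor([[1., 1., 2., 2., 3., 3.],
                                           [4., 4., 5., 5., 6., 6.]]))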
def refine_block(block_size: list[int],
                 weight_size: list[int],
                 dim: int = 0,
                 pingpong_size: int = 2.5*1024*1024,  # bytes
                 core_number: int = 4,
                 data_type: int = 2,  # bfloat16
                 max_iter_number: int = 2):
    '''
    When the blocks cannot be split evenly across cores, each core must stay
    <= 2.5 MB so that ping-pong buffering is still possible.
    The difference between cores is block_size[dim] * weight_size[1-dim].
    Shrinking block_size reduces that gap and evens out the split, until the
    largest core no longer exceeds the budget.
    If an even split is already over the budget, or no core exceeds it at all,
    there is nothing to adjust.
    '''
    if dim < 0:
        dim = 2 + dim

    pingpong_size = pingpong_size / data_type  # number of elements

    block_size_refine = block_size[dim]
    all_block_number = weight_size[dim] // block_size_refine

    if all_block_number % core_number == 0:
        # Even split: no adjustment needed whether or not the budget is exceeded.
        return block_size_refine

    block_number_tiny = all_block_number // core_number
    block_number_big = all_block_number // core_number + 1
    if block_number_tiny * block_size_refine * weight_size[1-dim] >= pingpong_size or \
            block_number_big * block_size_refine * weight_size[1-dim] <= pingpong_size:
        # The small core already exceeds the budget: nothing more can be done.
        # The big core fits: no adjustment needed.
        return block_size_refine

    all_block_number_tmp = all_block_number
    block_size_refine_tmp = block_size_refine
    for iter_index in range(max_iter_number):
        all_block_number_tmp = all_block_number_tmp * 2
        block_size_refine_tmp = block_size_refine_tmp // 2
        if all_block_number_tmp % core_number == 0:
            block_number_tiny = all_block_number // core_number
            if block_number_tiny * block_size_refine_tmp * weight_size[1-dim] <= pingpong_size:
                return block_size_refine_tmp
            else:
                # Still over budget even with an even split: no adjustment.
                return block_size_refine
        else:
            block_number_big = all_block_number_tmp // core_number + 1
            if block_number_big * block_size_refine_tmp * weight_size[1-dim] <= pingpong_size:
                return block_size_refine_tmp

    return block_size_refine

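A worked call, with numbers chosen here for illustration (they are not from the diff): under the defaults above (4 cores, 2.5 MB ping-pong budget, 2-byte data), a 640 x 7168 weight with 128-wide blocks yields 5 blocks, which cannot be spread evenly over 4 cores, and the 2-block core would exceed the budget; halving the block size twice gives an even 20-block split that fits, so the refined block size is 32.

# Sanity check of the example above (requires refine_block from this file).
assert refine_block([128, 128], [640, 7168]) == 32
# An evenly divisible case is left untouched:
assert refine_block([128, 128], [512, 7168]) == 128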
def ColumnParallelLinear__init__(self,
                                 input_size: int,
                                 output_size: int,
                                 bias: bool = True,
                                 gather_output: bool = False,
                                 skip_bias_add: bool = False,
                                 params_dtype: Optional[torch.dtype] = None,
                                 quant_config: Optional[QuantizationConfig] = None,
                                 output_sizes: Optional[List[int]] = None,
                                 prefix: str = "",
                                 *,
                                 return_bias: bool = True,
                                 disable_tp: bool = False):
    # Divide the weight matrix along the last dimension.
    self.tp_rank = (get_tensor_model_parallel_rank()
                    if not disable_tp else 0)
    self.tp_size = (get_tensor_model_parallel_world_size()
                    if not disable_tp else 1)
    self.input_size_per_partition = input_size
    self.output_size_per_partition = divide(output_size, self.tp_size)
    self.output_partition_sizes = [self.output_size_per_partition]
    # If QKV or MergedColumn, use output size of each partition.
    if hasattr(self, "output_sizes"):
        self.output_partition_sizes = [
            divide(output_size, self.tp_size)
            for output_size in self.output_sizes
        ]
    super(ColumnParallelLinear, self).__init__(input_size,
                                               output_size,
                                               skip_bias_add,
                                               params_dtype,
                                               quant_config,
                                               prefix,
                                               return_bias=return_bias,
                                               disable_tp=disable_tp)

    self.gather_output = gather_output

    if output_sizes is None:
        output_sizes = [output_size]

    self.scale_n = 1
    if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
        gcd_value = quant_config.weight_block_size[0]

        import math
        if hasattr(self, "output_sizes"):
            # For a merged ColumnParallelLinear, the greatest common divisor has to
            # be computed from the shape of every partial linear layer.
            output_size_no_merge = self.output_partition_sizes
            block_values = [o % quant_config.weight_block_size[0] for o in output_size_no_merge]
            is_gcd_recompute = sum(block_values)

            if is_gcd_recompute:
                block_values.append(quant_config.weight_block_size[0])
                gcd_value = math.gcd(*block_values)
            # Notice:
            # For non-aligned partial weights this path may still need validation.
            # For DeepSeek, the MergedColumnLinear layers in MLP & MoE all consist
            # of identically shaped partial weights.
            # For Qwen3 there is a QKV ColumnLinear with differently shaped partial
            # weights, but Qwen3's current split scheme is insensitive to
            # gcd_value, so no recomputation is needed and this branch is not
            # taken for now.
            if hasattr(self, "output_sizes") and len(output_size_no_merge) == 2 and output_size_no_merge[0] == output_size_no_merge[1]:
                # only refine mlp w13
                gcd_value = refine_block([gcd_value, quant_config.weight_block_size[1]], [output_size_no_merge[0], input_size])
            self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value
        else:
            # For a non-merged ColumnParallelLinear, compute the greatest common
            # divisor from the current shape only.
            output_size_no_merge = self.output_size_per_partition
            is_gcd_recompute = output_size_no_merge % quant_config.weight_block_size[0]
            if is_gcd_recompute:
                gcd_value = math.gcd(output_size_no_merge % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
            self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value

    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size,
        output_partition_sizes=self.output_partition_sizes,
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        scale_n=self.scale_n,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size_per_partition,
                        dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)

    self.update_param_tp_status()


def ColumnParallelLinear_weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
    # Special case for loading scales off disk, which often do not
    # have a shape (such as in the case of AutoFP8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)
    if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
        if self.scale_n > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1, 0, 2).reshape([-1, loaded_weight.shape[-1]])
            # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
    param.load_column_parallel_weight(loaded_weight=loaded_weight)

class MergedColumnParallelLinear(ColumnParallelLinear):

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):

        if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
            if self.scale_n > 1 and len(loaded_weight.shape) == 2:
                loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
                loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1, 0, 2).reshape([-1, loaded_weight.shape[-1]])
                # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]

        if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
            if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]:
                loaded_weight = loaded_weight.unsqueeze(1)  # [k, 1, n]
                loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
                # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]

        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0] // self.scale_n, weight_block_size[1]
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

def RowParallelLinear__init__(
    self,
    input_size: int,
    output_size: int,
    bias: bool = True,
    input_is_parallel: bool = True,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = True,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
):

    # Divide the weight matrix along the first dimension.
    self.tp_rank = (get_tensor_model_parallel_rank()
                    if not disable_tp else 0)
    self.tp_size = (get_tensor_model_parallel_world_size()
                    if not disable_tp else 1)
    self.input_size_per_partition = divide(input_size, self.tp_size)
    self.output_size_per_partition = output_size
    self.output_partition_sizes = [output_size]
    super(RowParallelLinear, self).__init__(input_size,
                                            output_size,
                                            skip_bias_add,
                                            params_dtype,
                                            quant_config,
                                            prefix,
                                            return_bias=return_bias,
                                            disable_tp=disable_tp)

    self.input_is_parallel = input_is_parallel
    self.reduce_results = reduce_results

    # Divide the weight matrix along the last dimension.
    self.tp_rank = get_tensor_model_parallel_rank()
    self.tp_size = get_tensor_model_parallel_world_size()
    self.input_size_per_partition = divide(input_size, self.tp_size)
    assert self.quant_method is not None

    # quant_block_k (128) is divided by scale_k; e.g. scale_k = 2 means the
    # effective quant_block_k becomes 64.
    self.scale_k = 1
    self.scale_n = 1
    self.scale_n_slice = 1

    if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
        gcd_value = quant_config.weight_block_size[1]
        import math
        if self.input_size_per_partition % quant_config.weight_block_size[1]:
            gcd_value = math.gcd(self.input_size_per_partition % quant_config.weight_block_size[1], quant_config.weight_block_size[1])
            self.scale_k = self.scale_k * quant_config.weight_block_size[1] // gcd_value
        if output_size % quant_config.weight_block_size[0]:
            gcd_value = math.gcd(output_size % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
            self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value
            self.scale_n_slice = output_size // gcd_value
        # Example with N = 576, block = 128: expanding the scale along n needs two
        # pieces of information: 1. how many copies to make (scale_n);
        # 2. how many entries stay valid after slicing (scale_n_slice).
        # scale = [s0,s1,s2,s3,s4], copied scale_n=2 times:
        # scale = [s0,s0,s1,s1,s2,s2,s3,s3,s4,s4], sliced to scale_n_slice=9
        # entries => [s0,s0,s1,s1,s2,s2,s3,s3,s4]

    if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
        gcd_value = quant_config.group_size
        import math
        if self.input_size_per_partition % quant_config.group_size:
            gcd_value = math.gcd(self.input_size_per_partition % quant_config.group_size, quant_config.group_size)
            self.quant_method.scale_k = self.quant_method.scale_k * quant_config.group_size // gcd_value

    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size_per_partition,
        output_partition_sizes=[self.output_size],
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        scale_k=self.scale_k,
        scale_n=self.scale_n,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if not reduce_results and (bias and not skip_bias_add):
        raise ValueError("When not reduce the results, adding bias to the "
                         "results can lead to incorrect results")

    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)

def RowParallelLinear_weight_loader_v2_vacc(self, param: BasevLLMParameter,
                                            loaded_weight: torch.Tensor):
    # Special case for loading scales off disk, which often do not
    # have a shape (such as in the case of AutoFP8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)
    if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
        if self.scale_k > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1, 2, 0).reshape([loaded_weight.shape[1], -1])
            # [1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]

        if self.scale_n > 1 and len(loaded_weight.shape) == 2:
            loaded_weight = loaded_weight.unsqueeze(0)  # [1, n, k]
            loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1, 0, 2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
            # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]

    elif self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
        # broadcast scale TODO: broadcast zero
        if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
            loaded_weight = loaded_weight.unsqueeze(1)  # [k, 1, n]
            loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
            # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]

    param.load_row_parallel_weight(loaded_weight=loaded_weight)


class UnquantizedLinearMethod():
    """Linear method without quantization."""

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if bias is not None:
            from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
            return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)

        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
        parallel_embedding_output = None
        if memory_recycler is not None:
            if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == x.size(0):
                parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER
        return torch.mm(x.view(-1, x.shape[-1]), layer.weight.transpose(1, 0), out=parallel_embedding_output).view(*(x.shape[:-1]), layer.weight.shape[0])
81
vllm_vacc/vllm/model_executor/layers/logits_processor.py
Normal file
81
vllm_vacc/vllm/model_executor/layers/logits_processor.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A layer that compute logits from hidden_stats."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import tensor_model_parallel_gather
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
def DumpLogits(logits):
|
||||
import os
|
||||
import time
|
||||
VLLM_VACC_DUMP_LOGITS = os.getenv('VLLM_VACC_DUMP_LOGITS')
|
||||
if VLLM_VACC_DUMP_LOGITS:
|
||||
logit_arr = logits.cpu().to(torch.float32).numpy()
|
||||
timestamp = time.time()
|
||||
# print("timestamp:", timestamp)
|
||||
if not os.path.exists(VLLM_VACC_DUMP_LOGITS):
|
||||
os.makedirs(VLLM_VACC_DUMP_LOGITS)
|
||||
logit_path = os.path.join(VLLM_VACC_DUMP_LOGITS, f'logit_{timestamp}.bin')
|
||||
summary_path = os.path.join(VLLM_VACC_DUMP_LOGITS, 'summary.txt')
|
||||
with open(summary_path, 'a') as f:
|
||||
f.write(f'{logit_path}\n')
|
||||
# print("save file:", logit_path)
|
||||
logit_arr.tofile(logit_path)
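        # Illustrative, not part of the original code: a dumped file can be reloaded
        # for offline comparison with
        #   np.fromfile(logit_path, dtype=np.float32).reshape(num_tokens, vocab_size)
        # where num_tokens and vocab_size must be taken from the run being debugged.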
|
||||
|
||||
class LogitsProcessor(nn.Module):
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
total_size_all_gather_input = hidden_states.size(0) * lm_head.weight.shape[0] * hidden_states.element_size() * get_tp_group().world_size
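        # Worked example (illustrative): 16 decode tokens x a 16384-entry vocab shard
        # x 2 bytes (fp16) x tp world size 8 = 4,194,304 bytes, i.e. exactly the 4 MB
        # limit checked below.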
        # The DSP team originally suggested fusing matmul and all_gather when the
        # hidden size is at most 7168 and divisible by 32:
        # if hidden_states.size(1) <= 7168 and hidden_states.size(1) % 32 == 0:
        # Instead, only fuse matmul and all_gather when the total size of the
        # all_gather inputs does not exceed 4 MB.
        if total_size_all_gather_input <= 4194304:
|
||||
try:
|
||||
from torch_vacc.vacc.custom_ops import fused_matmul_allgather
|
||||
# from vllm.distributed.parallel_state import get_tp_group
|
||||
|
||||
logits = fused_matmul_allgather(hidden_states, lm_head.weight.T,
|
||||
get_tp_group().world_size,
|
||||
get_tp_group().rank_in_group,
|
||||
get_tp_group().group_id,
|
||||
get_tp_group().rank_device_infos)
|
||||
# ensure there is PP last stage at front
|
||||
if (hasattr(current_platform, 'supports_v1') and current_platform.supports_v1(current_platform)) or get_tp_group().rank_in_group == 0:
|
||||
seq, hidden_dims = hidden_states.shape
|
||||
logits = logits.movedim(0, 1)
|
||||
logits = logits.reshape(seq, -1)
|
||||
logits = logits[..., :self.org_vocab_size]
|
||||
if get_tp_group().rank_in_group == 0:
|
||||
DumpLogits(logits)
|
||||
else:
|
||||
logits = None
|
||||
return logits
|
||||
except Exception as e:
|
||||
print("Fused Matmul with AllGather run Fail, now use unfused. " ,e)
|
||||
# Get the logits for the next tokens.
|
||||
logits = lm_head.quant_method.apply(lm_head,
|
||||
hidden_states,
|
||||
bias=embedding_bias)
|
||||
|
||||
#print("quant method:", lm_head, lm_head.quant_method, embedding_bias, logits.shape)
|
||||
# Gather logits for TP
|
||||
logits = self._gather_logits(logits)
|
||||
|
||||
# Remove paddings in vocab (if any).
|
||||
if logits is not None:
|
||||
logits = logits[..., :self.org_vocab_size]
|
||||
return logits
|
||||
51
vllm_vacc/vllm/model_executor/layers/pooler.py
Normal file
51
vllm_vacc/vllm/model_executor/layers/pooler.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolerActivation, get_pooling_params
|
||||
|
||||
|
||||
class ClassifierPooler(Pooler):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
# pooled_data shape: [batchsize, hidden_size]
|
||||
|
||||
if pooled_data.dtype != self.head_dtype:
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
|
||||
if self.classifier is not None:
|
||||
pooled_data = self.classifier(pooled_data)
|
||||
# pooled_data shape: [batchsize, num_labels]
|
||||
|
||||
if self.logit_bias is not None:
|
||||
pooled_data -= self.logit_bias
|
||||
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
flags = [p.activation for p in pooling_params]
|
||||
|
||||
if len(set(flags)) == 1:
|
||||
scores = self.act_fn(pooled_data) if flags[0] else pooled_data
|
||||
else:
|
||||
scores = [
|
||||
self.act_fn(vecs) if f else vecs
|
||||
for vecs, f in zip(pooled_data, flags)
|
||||
]
|
||||
|
||||
# scores shape: [batchsize, num_labels]
|
||||
return scores
|
||||
|
||||
|
||||
class PoolerNormalize(PoolerActivation):
|
||||
|
||||
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
||||
return torch.vacc.l2_norm(pooled_data, epsilon=1e-12)
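        # Hedged reference (assumption based on the op name): this should match
        #   torch.nn.functional.normalize(pooled_data, p=2, dim=-1, eps=1e-12)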
|
||||
@@ -0,0 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Literal, Type, get_args
|
||||
|
||||
QuantizationMethods = Literal[
|
||||
# "aqlm",
|
||||
"awq",
|
||||
"deepspeedfp",
|
||||
"tpu_int8",
|
||||
"fp8",
|
||||
"ptpc_fp8",
|
||||
"fbgemm_fp8",
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"bitblas",
|
||||
"gguf",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"gptq_bitblas",
|
||||
"awq_marlin",
|
||||
"gptq",
|
||||
"compressed-tensors",
|
||||
"bitsandbytes",
|
||||
"hqq",
|
||||
"experts_int8",
|
||||
"ipex",
|
||||
"quark",
|
||||
"moe_wna16",
|
||||
"torchao",
|
||||
"auto-round",
|
||||
"rtn",
|
||||
"inc",
|
||||
"mxfp4",
|
||||
"petit_nvfp4",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
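# Illustrative check (not part of the original file): commented-out backends such as
# "aqlm" are excluded, e.g.
#   assert "fp8" in QUANTIZATION_METHODS and "aqlm" not in QUANTIZATION_METHODS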
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
615
vllm_vacc/vllm/model_executor/layers/quantization/fp8.py
Normal file
615
vllm_vacc/vllm/model_executor/layers/quantization/fp8.py
Normal file
@@ -0,0 +1,615 @@
|
||||
|
||||
import functools
|
||||
import importlib.util
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_fp8_block_linear, check_aiter_fp8_linear_support,
|
||||
create_fp8_input_scale, create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter, expert_weight_is_col_major,
|
||||
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
|
||||
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
|
||||
validate_fp8_block_shape)
|
||||
# from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
# all_close_1d, apply_fp8_linear, convert_to_channelwise,
|
||||
# cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||
# normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
||||
# requantize_with_max_scale)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, all_close_1d, convert_to_channelwise,
|
||||
cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
|
||||
per_tensor_dequantize, requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, is_layer_skipped)
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
from vllm.utils import has_deep_gemm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
|
||||
|
||||
def Fp8LinearMethod__init(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
|
||||
if self.block_quant:
|
||||
self.block_size = self.quant_config.weight_block_size
|
||||
if self.block_quant:
|
||||
# Marlin doesn't support block-wise fp8
|
||||
self.use_marlin = False
|
||||
self.scale_k = 1
|
||||
self.scale_n = 1
|
||||
self.scale_n_prefill = 1 # only for fp8 moe
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape)
|
||||
|
||||
class Fp8LinearMethod(LinearMethodBase):
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if self.block_quant:
|
||||
|
||||
scale_n = extra_weight_attrs.get("scale_n")
|
||||
scale_k = extra_weight_attrs.get("scale_k")
|
||||
if scale_n is not None:
|
||||
self.scale_n = scale_n
|
||||
if scale_k is not None:
|
||||
self.scale_k = scale_k
|
||||
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.quant_config.weight_block_size is not None
|
||||
block_n, block_k = (
|
||||
self.quant_config.weight_block_size[0] // self.scale_n ,
|
||||
self.quant_config.weight_block_size[1] // self.scale_k ,
|
||||
)
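            # Illustrative: with weight_block_size = [128, 128], scale_n = 2 and
            # scale_k = 1, the effective blocks become block_n = 64 and block_k = 128,
            # so the weight_scale tensor below gets twice as many rows along the
            # output dimension.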
|
||||
# Required by row parallel
|
||||
if (tp_size > 1
|
||||
and input_size // input_size_per_partition == tp_size
|
||||
and input_size_per_partition % block_k != 0):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}.")
|
||||
# Required by column parallel or enabling merged weights
|
||||
if (tp_size > 1 and output_size // output_size_per_partition
|
||||
== tp_size) or len(output_partition_sizes) > 1:
|
||||
for output_partition_size in output_partition_sizes:
|
||||
if output_partition_size % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_partition_size = "
|
||||
f"{output_partition_size} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}.")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
# WEIGHT
|
||||
weight_dtype = (torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||
params_dtype)
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=weight_dtype),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
if not self.block_quant:
|
||||
scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
else:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
scale = BlockQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
(output_size_per_partition + block_n - 1) // block_n,
|
||||
(input_size_per_partition + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", scale)
|
||||
else:
|
||||
layer.register_parameter("input_scale", None)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# TODO(rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale_inv, _ = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale_inv)
|
||||
else:
|
||||
weight = layer.weight.data
|
||||
weight_scale_inv = layer.weight_scale_inv.data
|
||||
|
||||
if isinstance(layer, QKVParallelLinear):
|
||||
                # NOTE: for QKVParallelLinear, the first dim of weight_scale must be
                # divisible by 8 to satisfy the DSPs, so repeat its rows until it is.
|
||||
shape = weight_scale_inv.shape[0]
|
||||
repeat = 1
|
||||
while shape % 8 != 0:
|
||||
repeat *= 2
|
||||
shape = shape * repeat
|
||||
weight_scale_inv = torch.repeat_interleave(weight_scale_inv, repeats=repeat, dim=0)
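                # Illustrative: a first dim of 12 is repeat_interleaved x2 to 24 rows,
                # which is divisible by 8; the scale values are only duplicated, so the
                # dequantized weights are unchanged.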
|
||||
|
||||
# weight = self._maybe_pad_weight(weight)
|
||||
# if self.block_quant:
|
||||
# maybe_post_process_fp8_weight_block(
|
||||
# layer, self.cutlass_block_fp8_supported)
|
||||
|
||||
# Torch.compile cannot use Parameter subclasses.
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
layer.weight_scale_inv = Parameter(weight_scale_inv,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
# Note: lazy import to avoid triton import error.
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_w8a8_block_fp8_linear)
|
||||
if self.block_quant:
|
||||
assert self.quant_config.weight_block_size is not None
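            # Hedged note: the block size passed below is re-derived from the weight and
            # weight_scale_inv shapes rather than read from the config, which is assumed
            # to keep it consistent after the scale_n / scale_k expansion done at load
            # time.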
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
block_size=[layer.weight.shape[0] // layer.weight_scale_inv.shape[0], layer.weight.shape[1] // layer.weight_scale_inv.shape[1]],
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
# return apply_fp8_linear(
|
||||
# input=x,
|
||||
# weight=layer.weight,
|
||||
# weight_scale=layer.weight_scale,
|
||||
# input_scale=layer.input_scale,
|
||||
# bias=bias,
|
||||
# cutlass_fp8_supported=self.cutlass_fp8_supported,
|
||||
# # Default to using per_token quantization if cutlass is supported
|
||||
# use_per_token_if_dynamic=self.cutlass_fp8_supported)
|
||||
|
||||
def Fp8MoEMethod_init_(self, quant_config: Fp8Config, layer: torch.nn.Module):
|
||||
self.layer = layer
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
self.flashinfer_moe_backend = None
|
||||
|
||||
self.scale_k = 1
|
||||
self.scale_n = 1
|
||||
self.scale_n_prefill = 1
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm() or current_platform.is_vacc:
|
||||
self.use_marlin = False
|
||||
|
||||
# Check for DeepGemm support.
|
||||
self.allow_deep_gemm = False
|
||||
if envs.VLLM_USE_DEEP_GEMM:
|
||||
if not has_deep_gemm():
|
||||
logger.warning_once("Failed to import DeepGemm kernels.")
|
||||
elif not self.block_quant:
|
||||
logger.warning_once("Model is not block quantized. Not using "
|
||||
" DeepGemm kernels")
|
||||
elif (current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(90)):
|
||||
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
|
||||
self.allow_deep_gemm = True
|
||||
else:
|
||||
logger.warning_once(
|
||||
"DeepGemm not supported on the current platform.")
|
||||
|
||||
# Check for CutlassBlockScaledGroupedGemm support.
|
||||
self.allow_cutlass_block_scaled_grouped_gemm = False
|
||||
if not self.block_quant:
|
||||
logger.warning_once("Model is not block quantized. Not using "
|
||||
"CutlassBlockScaledGroupedGemm kernels")
|
||||
elif (current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(100)):
|
||||
logger.info_once(
|
||||
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
|
||||
)
|
||||
self.allow_cutlass_block_scaled_grouped_gemm = True
|
||||
else:
|
||||
logger.warning_once(
|
||||
"CutlassBlockScaledGroupedGemm not supported on the current "
|
||||
"platform.")
|
||||
|
||||
self.topk_indices_dtype = None
|
||||
self.fused_experts = functools.partial( # type: ignore
|
||||
fused_experts,
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
allow_cutlass_block_scaled_grouped_gemm=(
|
||||
self.allow_cutlass_block_scaled_grouped_gemm))
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
if self.block_quant:
|
||||
assert self.quant_config.weight_block_size is not None
|
||||
|
||||
scale_n = extra_weight_attrs.get("scale_n")
|
||||
scale_n_prefill = extra_weight_attrs.get("scale_n_prefill")
|
||||
scale_k = extra_weight_attrs.get("scale_k")
|
||||
if scale_n is not None:
|
||||
self.scale_n = scale_n
|
||||
if scale_k is not None:
|
||||
self.scale_k = scale_k
|
||||
if scale_n_prefill is not None:
|
||||
self.scale_n_prefill = scale_n_prefill
|
||||
|
||||
if self.quant_config is not None and self.quant_config.weight_block_size is not None:
|
||||
self.gcd_value = self.quant_config.weight_block_size[0]
|
||||
|
||||
output_size_no_merge = intermediate_size_per_partition
|
||||
            # assert isinstance(output_size_no_merge, int), f"merged output size should divide to an int, value is: {output_size_no_merge}"
|
||||
|
||||
            if output_size_no_merge % self.quant_config.weight_block_size[0]:
                import math
                gcd_value = math.gcd(output_size_no_merge % self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[0])
                self.scale_n = self.scale_n * self.quant_config.weight_block_size[0] // gcd_value
                self.scale_n_prefill = self.scale_n_prefill * self.quant_config.weight_block_size[0] // gcd_value
            if hidden_size % self.quant_config.weight_block_size[1]:
                import math
                gcd_value = math.gcd(hidden_size % self.quant_config.weight_block_size[1], self.quant_config.weight_block_size[1])
                self.scale_k = self.scale_k * self.quant_config.weight_block_size[1] // gcd_value
|
||||
# self.scale_k = self.scale_n
|
||||
|
||||
# print('output_size_no_merge', output_size_no_merge)
|
||||
            # Split the output across cores by block_size.
            # output_size_no_merge = 384
            #   block_size = 128: 384 = 3x128, can only use 3 cores x 128
            #   block_size = 16:  384 = 24x16 = 8 cores x (3x16), can use all 8 cores
            # output_size_no_merge = 512
            #   block_size = 128: 512 = 4x128, can only use 4 cores x 128
            #   block_size = 64:  512 = 8x64, can use all 8 cores x 64
            # output_size_no_merge = 768
            #   block_size = 128: 768 = 6x128, can only use 6 cores x 128
            #   block_size = 32:  768 = 8x(3x32), can use all 8 cores x (3x32)
|
||||
|
||||
core_num = 8
|
||||
min_block_size = 4
|
||||
block_size_tmp = self.quant_config.weight_block_size[0] // self.scale_n
|
||||
if output_size_no_merge > block_size_tmp and \
|
||||
output_size_no_merge % block_size_tmp == 0 and \
|
||||
output_size_no_merge // block_size_tmp < core_num and \
|
||||
output_size_no_merge % core_num == 0:
|
||||
core_num_old = output_size_no_merge // block_size_tmp
|
||||
import math
|
||||
gcd_value = math.gcd(core_num, core_num_old)
|
||||
new_scale = core_num // gcd_value
|
||||
if block_size_tmp // new_scale >= min_block_size:
|
||||
self.scale_n = new_scale * self.scale_n
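            # Illustrative walk-through: output_size_no_merge = 768 with block_size_tmp
            # = 128 gives core_num_old = 6, gcd(8, 6) = 2, new_scale = 4; since
            # 128 // 4 = 32 >= min_block_size, scale_n is multiplied by 4 and the work
            # spreads over all 8 cores as 8 x (3 x 32).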
|
||||
|
||||
|
||||
#print("moe scale n is:", self.scale_n, self.scale_k, intermediate_size_per_partition)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.scale_n != self.scale_n_prefill:
|
||||
block_n_prefill = self.quant_config.weight_block_size[0] // self.scale_n_prefill
|
||||
|
||||
block_n, block_k = (
|
||||
self.quant_config.weight_block_size[0] // self.scale_n,
|
||||
self.quant_config.weight_block_size[1] // self.scale_k,
|
||||
)
|
||||
# NOTE: To ensure proper alignment of the block-wise quantization
|
||||
# scales, the output_size of the weights for both the gate and up
|
||||
# layers must be divisible by block_n.
|
||||
# Required by column parallel or enabling merged weights
|
||||
if intermediate_size_per_partition % block_n != 0:
|
||||
raise ValueError(
|
||||
f"The output_size of gate's and up's weight = "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}.")
|
||||
if (tp_size > 1
|
||||
and hidden_size % block_k != 0):
|
||||
# Required by row parallel
|
||||
raise ValueError(
|
||||
f"The input_size of down's weight = "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}.")
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
if not self.block_quant:
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, 2, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
else:
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * ((intermediate_size_per_partition + block_n - 1) //
|
||||
block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
(intermediate_size_per_partition + block_n - 1) // block_n,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
if self.scale_n != self.scale_n_prefill:
|
||||
w13_weight_scale_prefill = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * ((intermediate_size_per_partition + block_n_prefill - 1) //
|
||||
block_n_prefill),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale_prefill = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
(intermediate_size_per_partition + block_n_prefill - 1) // block_n_prefill,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale_inv_prefill", w13_weight_scale_prefill)
|
||||
layer.register_parameter("w2_weight_scale_inv_prefill", w2_weight_scale_prefill)
|
||||
|
||||
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.
|
||||
value} if self.block_quant else
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
if self.scale_n != self.scale_n_prefill:
|
||||
set_weight_attrs(w13_weight_scale_prefill, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale_prefill, extra_weight_attrs)
|
||||
# INPUT_SCALES
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8.")
|
||||
|
||||
w13_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def moe_fp8_apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
try:
|
||||
from torch_vacc.vacc.custom_ops import fused_experts
|
||||
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
|
||||
experts_output = None
|
||||
if memory_recycler is not None:
|
||||
# remove MOE_EXPERT_OUT_BUFFER
|
||||
# experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
|
||||
experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w13_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
decode_with_batch=layer.is_decode and x.shape[0] > 1,
|
||||
output_opt=experts_output
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"vacc fused_expert run fail, now using unfused ops: {e}")
|
||||
from torch_vacc.vacc.custom_ops_cpu import fused_experts
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w13_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
)
|
||||
282
vllm_vacc/vllm/model_executor/layers/quantization/gptq.py
Normal file
282
vllm_vacc/vllm/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig as GPTQConfigOrig
|
||||
from vllm.model_executor.layers.quantization.gptq import ExllamaState
|
||||
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT
|
||||
import math
|
||||
|
||||
def GPTQLinearMethod__init(self, quant_config: GPTQConfigOrig):
|
||||
self.quant_config = quant_config
|
||||
self.scale_k = 1
|
||||
self.split_num = 4
|
||||
|
||||
def int32_to_int4(s0, axis=-2):
    # Flatten to shape [1, n] first.
    # Each int32 is split into 8 int4 values (held in 8 int32s), giving [8, n].

    # x32 (int32) => 32 bits => 8 x 4 bits => x4[8]

    # x32 bits 31-28 => x4[7]
    # x32 bits 27-24 => x4[6]
    # ...
    # x32 bits 3-0   => x4[0]

    # x32[index=0] => x4[7,6,5,4,3,2,1,0]

    # Mapping a 4-bit pattern to its real value:
    # not two's complement, but an offset of 8.

    # 1111 => 15 => 7   (15 - 8 = 7)
    # 0110 => 6  => -2  (6 - 8 = -2)

    # 0x6ACB372B (laid out in memory as 2B 37 CB 6A) => B2 73 BC A6 => (subtract 8)
    # => int4: 3, -6, -1, -5, 3, 4, 2, -2

    # The actual in-memory layout is little-endian:
    # int32: 2B 37 CB 6A => 2,11,3,7,12,11,6,10 => (subtract 8) => -6,3, -5,-1, 4,3, -2,2
    # => swapping the two nibbles within each byte gives 3, -6, -1, -5, 3, 4, 2, -2
    # int4: 3, -6, -1, -5, 3, 4, 2, -2

    s = s0.view(torch.uint32)
    parts = []
    for i in range(8):
        x = 15 << (i * 4)
        # s2 = torch.bitwise_and(x, s)
        s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
        s3 = s2 / (2 ** (i * 4))
        s4 = s3.to(torch.int32)
        # Two's complement decoding gives the wrong result here:
        # s4[s4 > 7] = s4[s4 > 7] - 16
        # Subtracting 8 directly is correct; the value range is -8..7.
        s4 = s4 - 8
        parts.append(s4.reshape(1, *s4.shape))
    parts = torch.concatenate(parts, 0)
    if axis == -2 or axis == 0:
        # [8, K//8, N] => [K//8, 8, N] => [K, N]
        parts = parts.transpose(-2, 0).reshape(-1, parts.shape[-1]).contiguous()
    else:
        # [8, N, K//8] => [N, K//8, 8] => [N, K]
        parts = parts.permute(1, 2, 0).reshape(parts.shape[-2], -1).contiguous()
    return parts
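# Minimal sanity sketch (illustrative, not part of the original code): for a single
# packed column,
#   packed = torch.tensor([[0x2B37CB6A]], dtype=torch.int32)
#   int32_to_int4(packed, axis=-2).flatten().tolist()
# is expected to yield [2, -2, 3, 4, -1, -5, 3, -6], i.e. the nibbles from least to
# most significant, each offset by -8.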
|
||||
|
||||
|
||||
def dequant_weight(qw, scales, group_size=128):
    N = qw.shape[1]
    int4_to_int32_axis = -2
    if TRANSPOSE_GPTQ_WEIGHT:
        N = qw.shape[0]
        int4_to_int32_axis = -1
    qweight = int32_to_int4(qw, int4_to_int32_axis).to(torch.float16)  # int32 => 8 x int4 => fp16

    if TRANSPOSE_GPTQ_WEIGHT:
        scales = scales.T.contiguous()
        qweight = qweight.T.contiguous()

    scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)  # expand the scales by group_size: every group_size values share one scale

    # print('qweight', qweight.shape, qweight.dtype)
    # print('scale', scales.shape, scales.dtype)

    dequant_weight = qweight * scales  # dequantize
    return dequant_weight
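# Hedged note: dequant_weight appears to be kept only for debugging; it reconstructs an
# fp16 weight that a plain torch.matmul can consume, as in the commented-out path inside
# GPTQLinearMethod.apply below.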
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
class GPTQLinearMethod(LinearMethodBase):
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
# if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
# raise ValueError(
|
||||
# "The input size is not aligned with the quantized "
|
||||
# "weight shape. This can be caused by too large "
|
||||
# "tensor parallel size.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
exllama_state = ExllamaState.UNINITIALIZED
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1):
|
||||
# For act-order models, we cannot use Exllama for row parallel layer
|
||||
if self.quant_config.desc_act:
|
||||
exllama_state = ExllamaState.UNUSED
|
||||
else:
|
||||
# we need to partition qzeros and scales for exllama kernel
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
layer.exllama_state = exllama_state
|
||||
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# for torch.compile
|
||||
# self.quant_config.weight_bits == 4
|
||||
if TRANSPOSE_GPTQ_WEIGHT:
|
||||
layer.qzeros = Parameter(layer.qzeros.data.T.contiguous(), requires_grad=False)
|
||||
layer.qweight = Parameter(layer.qweight.data.T.contiguous(), requires_grad=False)
|
||||
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
|
||||
layer.scales = Parameter(layer.scales.data.T.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
|
||||
layer.scales = Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
self.quant_config.weight_bits)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=layer.g_idx.device)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], ) # M,N
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
# print(f"~~~~ start dequant")
|
||||
# import time
|
||||
# start_quant_time = time.time()
|
||||
# weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(), self.quant_config.group_size // self.scale_k).to(layer.qweight.device)
|
||||
# end_quant_time = time.time()
|
||||
# print(f"~~~~ dequant time: {end_quant_time - start_quant_time}")
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(f"~~~~ weight shape: {weight.shape}, dtype: {weight.dtype}")
|
||||
# output = torch.matmul(reshaped_x, weight)
|
||||
# print("entering GPTQLinearMethod apply, reshaped_x shape:", reshaped_x.shape, "reshaped_x stride", reshaped_x.stride(), "input_tensor", x.shape, "qweight shape:", layer.qweight.shape, "scales shape:", layer.scales.shape)
|
||||
output = torch.vacc.w4a8_block_int4_matmul(
|
||||
reshaped_x,
|
||||
layer.qweight.transpose(-1, -2),
|
||||
layer.scales.transpose(-1, -2),
|
||||
[1, self.quant_config.group_size // self.scale_k],
|
||||
)
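        # Hedged note (assumption about the custom op): w4a8_block_int4_matmul is
        # expected to consume the packed int4 weights and per-group scales directly,
        # with [1, group_size // scale_k] describing the quantization group along the
        # reduction dim, matching the scale_k expansion applied at load time.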
|
||||
# print("exiting GPTQLinearMethod apply, output shape:", output.shape)
|
||||
# end_gemm_time = time.time()
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(f"~~~~ gemm time: {end_gemm_time - end_quant_time}")
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
372
vllm_vacc/vllm/model_executor/layers/quantization/moe_wna16.py
Normal file
372
vllm_vacc/vllm/model_executor/layers/quantization/moe_wna16.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supports_layer)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
# [num_experts, N, K//8], int32 ==> [num_experts, N, K], int4 ==> [num_experts, N//8, K], int32
|
||||
def repack_quant_moe_weight_old(original_packed_tensor):
|
||||
num_experts = original_packed_tensor.shape[0]
|
||||
N = original_packed_tensor.shape[1]
|
||||
K = original_packed_tensor.shape[2] * 8
|
||||
if original_packed_tensor.dtype != torch.int32:
|
||||
raise ValueError("data type of input tensor should be int32")
|
||||
if N % 8 != 0:
|
||||
raise ValueError("N of input tensor should be divisible by 8")
|
||||
|
||||
    # --- 1. Unpack: expand the int32 tensor into a logical int4 tensor ---
    # Allocate a temporary tensor to hold all unpacked int4 values.
    # torch.uint8 is used as temporary storage since PyTorch has no native int4 dtype.
    unpacked_int4_tensor = torch.zeros(
        (num_experts, N, K),
        dtype=torch.uint8,
        device=original_packed_tensor.device
    )
    mask = 0b1111
    for i in range(8):
        # Extract the i-th 4-bit value from each int32 block: shift right by
        # (i * 4) bits so the i-th nibble moves to the least significant position,
        # then mask out those 4 bits.
        extracted_int4s = (original_packed_tensor >> (i * 4)) & mask
        # Place the extracted int4 values at their positions in unpacked_int4_tensor.
        # The slice `i::8` means: start at index `i` and fill every 8th position.
        unpacked_int4_tensor[:, :, i::8] = extracted_int4s

    # --- 2. Repack: pack the logical int4 tensor back into a new int32 tensor ---
    new_packed_tensor = torch.zeros(
        (num_experts, N // 8, K),
        dtype=torch.int32,
        device=original_packed_tensor.device
    )
    for i in range(8):
        # Take the int4 sequence to pack next, slicing with `i::8` along the N axis.
        current_int4_segment = unpacked_int4_tensor[:, i::8, :]
        # Cast to int32 (the packing target), shift into the correct position in the
        # new int32 block, and merge it into new_packed_tensor with bitwise OR.
        new_packed_tensor |= (current_int4_segment.to(torch.int32) << (i * 4))
|
||||
|
||||
return new_packed_tensor
|
||||
|
||||
|
||||
def repack_quant_moe_weight(original_packed_tensor):
|
||||
if original_packed_tensor.dtype != torch.int32:
|
||||
raise ValueError("data type of input tensor should be int32")
|
||||
|
||||
num_experts, N, K_packed = original_packed_tensor.shape
|
||||
K = K_packed * 8
|
||||
|
||||
if N % 8 != 0:
|
||||
raise ValueError("N of input tensor should be divisible by 8")
|
||||
|
||||
new_packed_tensor = torch.zeros((num_experts, N // 8, K),
|
||||
dtype=torch.int32,
|
||||
device=original_packed_tensor.device)
|
||||
for i in range(8):
|
||||
source_slice = original_packed_tensor[:, i::8, :]
|
||||
for j in range(8):
|
||||
unpacked_strip = (source_slice >> (j * 4)) & 0b1111
|
||||
new_packed_tensor[:, :, j::8] |= (unpacked_strip.to(torch.int32) << (i * 4))
|
||||
|
||||
return new_packed_tensor
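# Hedged equivalence sketch (assumption): repack_quant_moe_weight fuses the two loops of
# repack_quant_moe_weight_old, so for any int32 tensor of shape [num_experts, N, K // 8]
# with N % 8 == 0 both should return identical [num_experts, N // 8, K] tensors, e.g.
#   t = torch.randint(0, 2**31 - 1, (2, 16, 4), dtype=torch.int32)
#   assert torch.equal(repack_quant_moe_weight(t), repack_quant_moe_weight_old(t))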
|
||||
|
||||
|
||||
class MoeWNA16Method(FusedMoEMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
self.moe = layer
|
||||
layer.quant_config = self.quant_config
|
||||
bit8_pack_factor = self.quant_config.bit8_pack_factor
|
||||
bit32_pack_factor = 32 // self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
group_size_div_factor = 1
|
||||
group_size_w13 = self.quant_config.group_size
|
||||
group_size_div_factor_w13 = 1
|
||||
group_size_w2 = self.quant_config.group_size
|
||||
group_size_div_factor_w2 = 1
|
||||
|
||||
        # Make intermediate_size and hidden_size divisible by group_size:
        # we reduce the group size to ensure that, and repeat the
        # loaded_weight accordingly later.
|
||||
while intermediate_size_per_partition % group_size or \
|
||||
hidden_size % group_size:
|
||||
group_size = group_size // 2
|
||||
group_size_div_factor *= 2
|
||||
assert group_size >= 32
|
||||
layer.group_size = group_size
|
||||
layer.group_size_div_factor = group_size_div_factor
|
||||
|
||||
while intermediate_size_per_partition % group_size_w2:
|
||||
group_size_w2 = group_size_w2 // 2
|
||||
group_size_div_factor_w2 *= 2
|
||||
assert group_size_w2 >= 32
|
||||
layer.w2_block_size = group_size_w2
|
||||
layer.group_size_div_factor_w2 = group_size_div_factor_w2
|
||||
|
||||
while hidden_size % group_size_w13:
|
||||
group_size_w13 = group_size_w13 // 2
|
||||
group_size_div_factor_w13 *= 2
|
||||
assert group_size_w13 >= 32
|
||||
layer.w13_block_size = group_size_w13
|
||||
layer.group_size_div_factor_w13 = group_size_div_factor_w13
|
||||
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
extra_weight_attrs.update({
|
||||
"quant_method": strategy,
|
||||
"is_transposed": False
|
||||
})
|
||||
|
||||
assert 'weight_loader' in extra_weight_attrs
|
||||
weight_loader = extra_weight_attrs['weight_loader']
|
||||
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
|
||||
layer, weight_loader)
|
||||
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
# w13_qweight = torch.nn.Parameter(torch.empty(
|
||||
# num_experts,
|
||||
# 2 * intermediate_size_per_partition,
|
||||
# hidden_size // bit8_pack_factor,
|
||||
# dtype=torch.uint8),
|
||||
# requires_grad=False)
|
||||
w13_qweight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // bit32_pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
# w2_qweight = torch.nn.Parameter(torch.empty(
|
||||
# num_experts,
|
||||
# hidden_size,
|
||||
# intermediate_size_per_partition // bit8_pack_factor,
|
||||
# dtype=torch.uint8),
|
||||
# requires_grad=False)
|
||||
w2_qweight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // bit32_pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
w13_scales = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // group_size_w13,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
|
||||
w2_scales = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // group_size_w2,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.has_zp:
|
||||
w13_qzeros = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition // bit8_pack_factor,
|
||||
hidden_size // group_size,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
|
||||
w2_qzeros = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size // bit8_pack_factor,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.linear_quant_method == "gptq":
|
||||
# some param are unused, but we need to init them in order to
|
||||
# load weights
|
||||
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
|
||||
if not self.quant_config.has_zp:
|
||||
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
|
||||
for key in invalid_param_keys:
|
||||
param = torch.nn.Parameter(torch.empty((0, ),
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter(key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
@staticmethod
|
||||
def get_weight_loader(layer, weight_loader):
|
||||
|
||||
def convert_awq_tensor(tensor, tensor_type):
|
||||
# convert awq qweight/qzeros to a standard format (assume int4)
|
||||
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
|
||||
# qzeros: (k // group_size, n // pack_factor_bit32) ->
|
||||
# (n // pack_factor_bit8, k // group_size)
|
||||
# pack_factor_bit32 = 32 // weight_bits
|
||||
# pack_factor_bit8 = 8 // weight_bits
|
||||
|
||||
# 0. suppose origin shape (a, b), dtype int32
|
||||
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
|
||||
size0 = tensor.size(0)
|
||||
tensor = tensor.view(torch.uint8)
|
||||
|
||||
# 2. unpack to uint4 (only when weight_bits == 4)
|
||||
# shape (a, 4 * b) -> (a, 4 * b, 2)
|
||||
shifter = torch.tensor([0, 4],
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
|
||||
# 3. change order, see
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
|
||||
# shape -> (a, 4 * b * pack_factor_bit8)
|
||||
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
|
||||
tensor = tensor.view(size0, -1)
|
||||
|
||||
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
|
||||
tensor = tensor.T.contiguous()
|
||||
|
||||
# 5. repack (only when weight_bits == 4)
|
||||
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
|
||||
# qzeros shape -> (4 * b, a)
|
||||
|
||||
if tensor_type == "qweight":
|
||||
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
|
||||
elif tensor_type == "qzeros":
|
||||
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
|
||||
return tensor
|
||||
|
||||
def convert_gptq_int4_qzeros(tensor):
|
||||
tensor = tensor.view(torch.uint8)
|
||||
shifter = torch.tensor([0, 4],
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
tensor = tensor + 1
|
||||
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
|
||||
return tensor
|
||||
|
||||
# NOTE: `layer` and `weight_loader` below are free variables captured from the
# enclosing weight-loader factory (the factory itself is not shown in this hunk,
# which is why the function object is returned at the end).
def moe_wna16_weight_loader(param: torch.nn.Parameter,
                            loaded_weight: torch.Tensor,
                            weight_name: str,
                            shard_id: str,
                            expert_id: int,
                            return_success: bool = False):
    if "g_idx" in weight_name:
        return False if return_success else None
    if not layer.quant_config.has_zp and "qzeros" in weight_name:
        return False if return_success else None

    device = get_tp_group().device
    tp_rank = get_tensor_model_parallel_rank()
    loaded_weight = loaded_weight.to(device)
    shard_size = layer.intermediate_size_per_partition

    # convert gptq and awq weight to a standard format
    if layer.quant_config.linear_quant_method == "awq":
        assert layer.quant_config.weight_bits == 4
        if "weight" in weight_name:
            loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
        elif "zeros" in weight_name:
            loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
        else:
            loaded_weight = loaded_weight.T
    elif layer.quant_config.linear_quant_method == "gptq":
        assert layer.quant_config.weight_bits in [4, 8]
        if "weight" in weight_name:
            # loaded_weight = loaded_weight.T.contiguous().view(
            #     torch.uint8)
            loaded_weight = loaded_weight.T.contiguous()
        elif "zeros" in weight_name:
            # add 1 to gptq qzeros to align with awq
            loaded_weight = loaded_weight.view(torch.uint8)
            if layer.quant_config.weight_bits == 4:
                loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
            else:
                loaded_weight = loaded_weight.T + 1
        else:
            # loaded_weight = loaded_weight.T
            loaded_weight = loaded_weight.T.contiguous()

    # repeat the qzeros/scales to fit the new group size
    if layer.group_size_div_factor_w13 > 1 and \
            ("qzeros" in weight_name or "scales" in weight_name) and \
            shard_id in ("w1", "w3"):
        loaded_weight = loaded_weight.repeat_interleave(
            layer.group_size_div_factor_w13, 1)
    elif layer.group_size_div_factor_w2 > 1 and \
            ("qzeros" in weight_name or "scales" in weight_name) and \
            shard_id == "w2":
        loaded_weight = loaded_weight.repeat_interleave(
            layer.group_size_div_factor_w2, 1)
    elif layer.group_size_div_factor > 1 and \
            ("qzeros" in weight_name or "scales" in weight_name):
        loaded_weight = loaded_weight.repeat_interleave(
            layer.group_size_div_factor, 1)

    if "w13_qzeros" in weight_name:
        tensor = loaded_weight.view(layer.tp_size, -1,
                                    loaded_weight.size(1))[tp_rank]
        if shard_id == "w1":
            param.data[expert_id, :shard_size // 2] = tensor
        else:
            param.data[expert_id, shard_size // 2:] = tensor
        return True if return_success else None
    elif "w2_qzeros" in weight_name:
        param.data[expert_id] = loaded_weight.view(
            loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
        return True if return_success else None
    else:
        # Delegate to the original loader, passing return_success
        return weight_loader(param,
                             loaded_weight,
                             weight_name,
                             shard_id,
                             expert_id,
                             return_success=return_success)

return moe_wna16_weight_loader

def process_weights_after_loading(self, layer):
    dev_w2 = layer.w2_qweight.device
    # torch.Size([128, 2048, 24]), torch.int32, strides: (49152, 24, 1)
    # ======>
    # torch.Size([128, 256, 192]), torch.int32, strides: (49152, 1, 256)
    layer.w2_qweight = torch.nn.Parameter(
        repack_quant_moe_weight(layer.w2_qweight.cpu())
        .transpose(-1, -2).contiguous().transpose(-1, -2)
        .to(device=dev_w2),
        requires_grad=False)
    # torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 3, 1)
    # ======>
    # torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 1, 2048)
    layer.w2_scales = torch.nn.Parameter(
        layer.w2_scales.transpose(-1, -2).contiguous().transpose(-1, -2),
        requires_grad=False)
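# Illustrative sketch of the layout trick used in process_weights_after_loading
# above: transpose -> contiguous -> transpose keeps the logical shape but makes
# the second-to-last dim the fastest-varying one in memory, which is what the
# stride annotations in the comments refer to.
import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
print(x.shape, x.stride())   # torch.Size([2, 3, 4]) (12, 4, 1)
y = x.transpose(-1, -2).contiguous().transpose(-1, -2)
print(y.shape, y.stride())   # torch.Size([2, 3, 4]) (12, 1, 3)
assert torch.equal(x, y)     # same values, different memory layout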
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,39 @@
import torch
from typing import List, Optional


def _apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: list[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_block_fp8_supported: bool = True,
    use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
    assert input_scale is None
    assert len(block_size) == 2, "only support dim2 block now"
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    try:
        from torch_vacc.vacc.custom_ops import w8a8_block_fp8_linear
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler

        mla_oproj_output = None
        if memory_recycler is not None:
            os1, os2 = memory_recycler.MLA_OPROJ_OUT_BUFFER.shape
            if os1 == input_2d.size(0) and os2 == weight.size(0):
                mla_oproj_output = memory_recycler.MLA_OPROJ_OUT_BUFFER

        output = w8a8_block_fp8_linear(input_2d, weight, input_scale,
                                       weight_scale, block_size,
                                       output=mla_oproj_output)
    except Exception as e:
        print("vacc fuse fp8 matmul run fail:", e, " , now use unfused ops")
        from torch_vacc.vacc.custom_ops_cpu import w8a8_block_fp8_linear
        output = w8a8_block_fp8_linear(input_2d, weight, input_scale,
                                       weight_scale, block_size)

    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)
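# Reference sketch (an assumption, not the torch_vacc implementation) of what a
# block-quantized W8A8 linear computes: the weight carries one scale per
# (block_n, block_k) tile and the output is input @ dequant(weight).T. Here the
# "quantized" weight is kept in float purely for clarity.
import torch

def reference_block_dequant_linear(input_2d, weight_q, weight_scale, block_size):
    block_n, block_k = block_size
    # Expand each per-block scale over its (block_n, block_k) tile.
    scale_full = weight_scale.repeat_interleave(block_n, dim=0)
    scale_full = scale_full.repeat_interleave(block_k, dim=1)
    scale_full = scale_full[:weight_q.shape[0], :weight_q.shape[1]]
    weight_fp = weight_q.float() * scale_full
    return input_2d.float() @ weight_fp.T

x = torch.randn(4, 8)
w_q = torch.randn(6, 8)        # stands in for the fp8-encoded weight, shape (N, K)
w_scale = torch.rand(3, 4)     # block_size = [2, 2] -> (6 / 2, 8 / 2) scales
out = reference_block_dequant_linear(x, w_q, w_scale, [2, 2])
print(out.shape)               # torch.Size([4, 6])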
@@ -0,0 +1,9 @@
from .rotary_embedding import (
    RotaryEmbedding_init_vacc,
    RotaryEmbedding_forward_vacc,
    ScalingRotaryEmbedding_forward_vacc,
    _compute_inv_freq_vacc,
    _deepseek_compute_cos_sin_cache_vacc,
    _yarn_compute_cos_sin_cache_vacc,
    _compute_cos_sin_cache_vacc
)
Binary file not shown.
Binary file not shown.
Binary file not shown.
101
vllm_vacc/vllm/model_executor/layers/rotary_embedding/mrope.py
Normal file
101
vllm_vacc/vllm/model_executor/layers/rotary_embedding/mrope.py
Normal file
@@ -0,0 +1,101 @@
import itertools
from typing import Optional, Union

import numpy as np
import torch
from transformers import PretrainedConfig


class MRotaryEmbedding:

    @classmethod
    def _qwen3vl_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value."""

        video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
                          for _ in range(t)]

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]
        return llm_positions, mrope_position_delta
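# Worked toy example of the 3-row (t, h, w) mrope positions produced above
# (illustrative, values computed by hand): prompt = 2 text tokens, then one
# image whose grid is (t=1, h=4, w=4) with spatial_merge_size=2 (so the LLM
# sees a 1x2x2 = 4-token image), then 1 trailing text token.
import torch

expected = torch.tensor([
    [0, 1, 2, 2, 2, 2, 4],   # temporal index
    [0, 1, 2, 2, 3, 3, 4],   # height index (+ text offset 2)
    [0, 1, 2, 3, 2, 3, 4],   # width index (+ text offset 2)
])
# Text tokens advance all three rows together; image tokens keep t fixed and
# enumerate the merged 2x2 grid; the trailing text restarts at max(prev) + 1 = 4.
# mrope_position_delta = expected.max() + 1 - num_tokens = 5 - 7 = -2.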
@@ -0,0 +1,203 @@
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
# from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb
# from vllm.model_executor.layers.rotary_embedding import _yarn_find_correction_range, _yarn_linear_ramp_mask
from vllm.model_executor.layers.rotary_embedding.common import yarn_find_correction_range as _yarn_find_correction_range
from vllm.model_executor.layers.rotary_embedding.common import yarn_linear_ramp_mask as _yarn_linear_ramp_mask
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from ...ops.mrope_op import get_sin_cos_mrope


def RotaryEmbedding_init_vacc(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
) -> None:
    super(CustomOp, self).__init__()
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype

    # cache = self._compute_cos_sin_cache()
    cos, sin = self._compute_cos_sin_cache()
    cos = cos.to(dtype)
    sin = sin.to(dtype)

    self.register_buffer("cos_cache", cos, persistent=False)
    self.register_buffer("sin_cache", sin, persistent=False)

    # cache = cache.to(dtype)
    # self.cos_sin_cache: torch.Tensor
    # self.register_buffer("cos_sin_cache", cache, persistent=False)

def RotaryEmbedding_forward_vacc(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """A PyTorch-vacc implementation of forward()."""
    if offsets is not None:
        positions = positions + offsets
    num_tokens = positions.numel()
    # positions = positions.flatten()
    # num_tokens = positions.shape[0]
    # cos_sin = self.cos_sin_cache.index_select(0, positions)
    # cos_sin = self.cos_sin_cache[positions]
    # cos, sin = cos_sin.chunk(2, dim=-1)

    if isinstance(self, MRotaryEmbedding):
        # get mrope sin/cos
        cos, sin = get_sin_cos_mrope(self, positions)
        num_tokens = num_tokens // 3
    else:
        positions = positions.flatten()
        cos = self.cos_cache[positions]
        sin = self.sin_cache[positions]

    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    query_pass = query[..., self.rotary_dim:]
    key_shape = key.shape
    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., :self.rotary_dim]
    key_pass = key[..., self.rotary_dim:]
    mode = "neox"
    if not self.is_neox_style:
        mode = "gptj"
    query_rot, key_rot = torch.vacc.RotaryPosEmbedding(query_rot, key_rot,
                                                       cos, sin, 0, mode)
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
    return query, key

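# Reference sketch of the rotation that the fused torch.vacc.RotaryPosEmbedding
# call above is assumed to perform for mode="neox" (half-split layout); cos/sin
# are taken to have shape (num_tokens, rotary_dim // 2), as produced by the
# separate cos/sin caches registered in RotaryEmbedding_init_vacc.
import torch

def rotate_neox_reference(x, cos, sin):
    # x: (num_tokens, num_heads, rotary_dim)
    x1, x2 = x.chunk(2, dim=-1)
    cos = cos.unsqueeze(-2)   # broadcast over the head dimension
    sin = sin.unsqueeze(-2)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

q = torch.randn(5, 8, 64)
cos = torch.randn(5, 32)
sin = torch.randn(5, 32)
print(rotate_neox_reference(q, cos, sin).shape)   # torch.Size([5, 8, 64])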
def ScalingRotaryEmbedding_forward_vacc(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """A PyTorch-vacc implementation equivalent to forward()."""
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]

    # self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
    #     positions.device)
    # cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
    #                              if offsets is not None else positions]
    if offsets is not None:
        positions = positions + offsets
    positions = positions.flatten()

    # cos_sin = self.cos_sin_cache[positions]
    # cos, sin = cos_sin.chunk(2, dim=-1)
    cos = self.cos_cache[positions]
    sin = self.sin_cache[positions]

    # TODO: to be removed (require odsp support)
    # if self.is_neox_style:
    #     # NOTE(woosuk): Here we assume that the positions tensor has the
    #     # shape [batch_size, seq_len].
    #     cos = cos.repeat(1, 1, 2).unsqueeze(-2)
    #     sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    # else:
    #     cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
    #     sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    # rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
    mode = "neox" if self.is_neox_style else "gptj"
    # query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    # key_rot = key_rot * cos + rotate_fn(key_rot) * sin
    query_rot, key_rot = torch.vacc.RotaryPosEmbedding(query_rot, key_rot,
                                                       cos, sin, 0, mode)

    if self.rotary_dim < self.head_size:
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
    else:
        query = query_rot
        key = key_rot
    return query, key

def _compute_inv_freq_vacc(self, scaling_factor: float) -> torch.Tensor:
    pos_freqs = self.base**(torch.arange(
        0, self.rotary_dim, 2, dtype=torch.float,
        device=current_platform.device_type) / self.rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                            self.rotary_dim, self.base,
                                            self.max_position_embeddings)
    # Get n-d rotational scaling corrected for extrapolation
    inv_freq_mask = (1 - _yarn_linear_ramp_mask(
        low, high, self.rotary_dim // 2,
        dtype=torch.float)) * self.extrapolation_factor
    inv_freq = inv_freq_interpolation * (
        1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
    return inv_freq

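# Small numeric sketch of the YaRN blend computed above: early (high-frequency)
# bands keep the extrapolated frequencies, late (low-frequency) bands use the
# interpolated ones (divided by the scaling factor), and a linear ramp
# (mirroring yarn_linear_ramp_mask) blends the two between `low` and `high`.
import torch

base, rotary_dim, scaling_factor = 10000.0, 8, 4.0
pos_freqs = base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

low, high = 1.0, 3.0   # toy correction range
ramp = (torch.arange(rotary_dim // 2, dtype=torch.float) - low) / (high - low)
inv_freq_mask = 1 - ramp.clamp(0, 1)   # 1 -> extrapolate, 0 -> interpolate
inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask)
            + inv_freq_extrapolation * inv_freq_mask)
print(inv_freq)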
def _deepseek_compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                     device=current_platform.device_type,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    # NOTE: odsp friendly -- separate cos/sin caches guarantee that cos/sin
    # always have a contiguous layout along dim[-1]
    cos = (freqs.cos() * self.mscale)
    sin = (freqs.sin() * self.mscale)
    return cos, sin
    # cache = torch.cat((cos, sin), dim=-1)
    # return cache


def _yarn_compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                     device=current_platform.device_type,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = (freqs.cos() * self.mscale)
    sin = (freqs.sin() * self.mscale)
    return cos, sin
    # cache = torch.cat((cos, sin), dim=-1)
    # return cache


def _compute_cos_sin_cache_vacc(self) -> Tuple[torch.Tensor, torch.Tensor]:
    inv_freq = self._compute_inv_freq(self.base)
    t = torch.arange(self.max_position_embeddings,
                     device=current_platform.device_type,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    # NOTE: odsp friendly -- separate cos/sin caches guarantee that cos/sin
    # always have a contiguous layout along dim[-1]
    cos = freqs.cos()
    sin = freqs.sin()
    return cos, sin
    # cache = torch.cat((cos, sin), dim=-1)
    # return cache


# import vllm.model_executor.layers.rotary_embedding as rotary_embedding
# rotary_embedding.RotaryEmbedding.forward_vacc=RotaryEmbedding_forward_vacc
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_inv_freq=_compute_inv_freq_vacc
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_cos_sin_cache=_compute_cos_sin_cache_vacc
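# Tiny numeric sketch of the cache built above: cos and sin are kept as two
# separate (max_position, rotary_dim // 2) tensors, rather than one concatenated
# cos_sin_cache, so the last dim stays contiguous for the fused rope op.
import torch

inv_freq = torch.tensor([1.0, 0.1])              # toy inverse frequencies
t = torch.arange(4, dtype=torch.float32)         # positions 0..3
freqs = torch.einsum("i,j -> ij", t, inv_freq)   # freqs[p, d] = p * inv_freq[d]
cos, sin = freqs.cos(), freqs.sin()
print(cos.shape, sin.shape)                      # torch.Size([4, 2]) torch.Size([4, 2])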
542
vllm_vacc/vllm/model_executor/layers/sampler.py
Normal file
542
vllm_vacc/vllm/model_executor/layers/sampler.py
Normal file
@@ -0,0 +1,542 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union

import msgspec
import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, Logprob,
                           PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.model_executor.layers.sampler import (SamplerOutput,
                                                _apply_min_tokens_penalty,
                                                _apply_top_k_top_p,
                                                _apply_min_p,
                                                _sample,
                                                SampleResultArgsType,
                                                get_logprobs,
                                                _build_sampler_output,
                                                SampleReturnType,
                                                SampleResultsDictType,
                                                SampleMetadataType,
                                                MultinomialSamplesType,
                                                _modify_greedy_probs_inplace,
                                                _top_k_top_p_multinomial_with_flashinfer,
                                                _multinomial,
                                                get_pythonized_sample_results)
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER as use_ds3_sampler
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER_OP as use_ds3_sampler_op

if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None

class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For each sequence group, we generate a list of SequenceOutput objects,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # On-device tensor containing the sampled token embeddings (embeddings
    # corresponding to the sampled token ids). Used when prompt embeddings are
    # specified in lieu of prompt token ids or text.
    sampled_token_embeds: Optional[torch.Tensor] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model forward,
    # block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
        return iter(self.outputs)

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
                                  self.sampled_token_ids.shape)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr},")


def Sampler_forward(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """
    Single-step scheduling:
    * Perform GPU-side sampling computation & compute
      GPU-side logprobs tensor
    * Pythonize sampling result & logprobs tensor

    Multi-step scheduling:
    * Perform GPU-side sampling computation & compute
      GPU-side logprobs tensor
    * Defer Pythonization of sampling result & logprobs
      tensor
    * Encapsulate arguments required for deferred Pythonization
      in the :class:`SamplerOutput` structure

    Args:
        logits: (num_tokens, vocab_size).
        sampling_metadata: Metadata for sampling.
    """

    assert logits is not None
    # print(f'Sampler_forward all_greedy={all_greedy}')
    # Prepare sampling tensors with pinned memory to avoid blocking.
    if not sampling_metadata.reuse_sampling_tensors:
        self._init_sampling_tensors(logits, sampling_metadata)
    elif self._do_penalties:
        # In this case, the sampling tensors logic depends on
        # "output_tokens" of a sequence. As a result, we cannot
        # reuse sampling tensors, since "output_tokens" changes
        # between decode runs.
        self._init_sampling_tensors(logits, sampling_metadata)

    assert self._sampling_tensors is not None
    sampling_tensors = self._sampling_tensors
    do_penalties = self._do_penalties
    do_top_p_top_k = self._do_top_p_top_k
    do_min_p = self._do_min_p

    is_greedy = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.GREEDY]) == logits.shape[0])
    is_random = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.RANDOM]) == logits.shape[0])
    is_random_seed = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.RANDOM_SEED]) == logits.shape[0])

    max_n_in_batch = sampling_metadata.seq_groups[0].sampling_params.n
    generator = sampling_metadata.seq_groups[0].generator
    min_tokens = sampling_metadata.seq_groups[0].sampling_params.min_tokens
    # print("use_ds3_sampler ", use_ds3_sampler)
    if use_ds3_sampler and (is_greedy or (
            (is_random or is_random_seed)
            and not do_penalties
            and flashinfer_top_k_top_p_sampling is None
            and min_tokens <= 0
            and not do_min_p
            and max_n_in_batch == 1
            # and not self._should_modify_greedy_probs_inplace
            # and not self.include_gpu_probs_tensor
            )):
        sampling_type = SamplingType.GREEDY
        sample_metadata: SampleMetadataType = {}
        multinomial_samples: MultinomialSamplesType = {}
        greedy_samples: Optional[torch.Tensor] = None
        multinomial_out: Optional[torch.Tensor] = None
        vacc_device = logits.device
        # Create output tensor for sampled token ids.
        if self.include_gpu_probs_tensor:
            sampled_token_ids_tensor = torch.full((logits.shape[0], 1),
                                                  VLLM_INVALID_TOKEN_ID,
                                                  dtype=torch.long,
                                                  device=vacc_device)
            probs_out = torch.empty_like(logits)
            logprobs_out = torch.empty_like(logits)
        else:
            probs_out = None
            logprobs_out = None
            sampled_token_ids_tensor = None
        if is_greedy:
            greedy_samples, _ = torch.vacc.ds3_sampler(
                logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                sampling_tensors.temperatures, 0)
            sampling_type = SamplingType.GREEDY
            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor = greedy_samples.unsqueeze(-1).to(
                    torch.long)
            if probs_out is not None:
                # probs_out = torch.softmax(logits.to(torch.float), dim=-1, dtype=torch.float).to(logits)
                probs_out = torch.softmax(logits, dim=-1)
                if self._should_modify_greedy_probs_inplace:
                    sample_indices = (sampling_metadata.
                                      categorized_sample_indices[
                                          SamplingType.GREEDY]).long()
                    probs_out[sample_indices, :] = 0
                    probs_out[sample_indices, greedy_samples] = 1.0
        elif is_random and do_top_p_top_k:
            if use_ds3_sampler_op:
                logits = logits.to(torch.float)
                multinomial_out, probs_out = torch.vacc.ds3_sampler(
                    logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                    sampling_tensors.temperatures, 2)
                multinomial_out = multinomial_out.view(-1, max_n_in_batch)
            else:
                logits = logits.to(torch.float)
                logits.div_(sampling_tensors.temperatures.to(
                    logits.device).to(logits.dtype).unsqueeze(dim=1))
                logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps,
                                              sampling_tensors.top_ks)
                probs = torch.softmax(logits, dim=-1, dtype=torch.float)
                probs_out = probs
                # multinomial_out = torch.multinomial(probs, 1)
                q = torch.empty_like(probs)
                q.exponential_()
                multinomial_out = probs.div_(q).argmax(dim=1).view(
                    -1, max_n_in_batch)
            sampling_type = SamplingType.RANDOM
        elif is_random_seed and generator is not None and do_top_p_top_k:
            if use_ds3_sampler_op:
                # print("is_random_seed ", is_random_seed)
                logits = logits.to(torch.float)
                multinomial_out, probs_out = torch.vacc.ds3_sampler(
                    logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                    sampling_tensors.temperatures, 1, generator)
                multinomial_out = multinomial_out.view(-1, max_n_in_batch)
            else:
                logits = logits.to(torch.float)
                logits.div_(sampling_tensors.temperatures.to(
                    logits.device).to(logits.dtype).unsqueeze(dim=1))
                logits = torch.vacc.topk_topp(
                    logits, sampling_tensors.top_ps,
                    sampling_tensors.top_ks).to(torch.float)
                probs = torch.softmax(logits, dim=-1, dtype=torch.float)
                probs_out = probs
                # torch.manual_seed(sampling_metadata.seq_groups[0].sampling_params.seed)
                # multinomial_out = torch.multinomial(probs, 1)
                q = torch.empty_like(probs)
                q.exponential_(generator=generator)
                multinomial_out = probs.div_(q).argmax(dim=1).view(
                    -1, max_n_in_batch)
            sampling_type = SamplingType.RANDOM_SEED

        multinomial_samples[sampling_type] = multinomial_out

        if sampled_token_ids_tensor is not None:
            if sampling_type != SamplingType.GREEDY:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor = multinomial_samples[
                    sampling_type].to(torch.long)

        categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
            t: []
            for t in SamplingType
        }
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            sampling_params = seq_group.sampling_params
            sampling_type = sampling_params.sampling_type
            categorized_seq_group_ids[sampling_type].append(i)

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        sample_results_dict: SampleResultsDictType = {}

        maybe_deferred_args = SampleResultArgsType(
            sampling_metadata=sampling_metadata,
            sample_metadata=sample_metadata,
            multinomial_samples=multinomial_samples,
            greedy_samples=greedy_samples,
            # beam_search_logprobs=None,
            sample_results_dict=sample_results_dict)

        if not sampling_metadata.skip_sampler_cpu_output:
            # GPU<->CPU sync happens here.
            # This also converts the sampler output to a Python object.
            # Return Pythonized sampler result & sampled token ids
            maybe_deferred_sample_results, maybe_sampled_tokens_tensor = get_pythonized_sample_results(
                maybe_deferred_args), sampled_token_ids_tensor
        else:
            # Defer sampler result Pythonization; return deferred
            # Pythonization args & sampled token ids
            maybe_deferred_sample_results, maybe_sampled_tokens_tensor = (
                maybe_deferred_args,
                sampled_token_ids_tensor,
            )

        if self.include_gpu_probs_tensor:
            on_device_tensors = (probs_out, logprobs_out,
                                 maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None
        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            logprobs = logits
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    logits = _apply_min_tokens_penalty(logits, sampling_metadata)

    # Apply presence and frequency penalties.
    # if do_penalties:
    #     logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
    #                              sampling_tensors.output_tokens,
    #                              sampling_tensors.presence_penalties.to(logits.device),
    #                              sampling_tensors.frequency_penalties.to(logits.device),
    #                              sampling_tensors.repetition_penalties.to(logits.device))

    # Use float32 to apply temperature scaling.
    # Use in-place division to avoid creating a new tensor.
    logits = logits.to(torch.float)
    logits.div_(sampling_tensors.temperatures.to(logits.device).to(
        logits.dtype).unsqueeze(dim=1))

    if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
        logits = _apply_top_k_top_p(logits,
                                    sampling_tensors.top_ps.to(logits.device),
                                    sampling_tensors.top_ks.to(logits.device))

    if do_min_p:
        logits = _apply_min_p(logits, sampling_tensors.min_ps)

    # We use float32 for probabilities and log probabilities.
    # Compute the probabilities.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    # Compute the log probabilities.
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    # Sample the next tokens.
    maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=self.include_gpu_probs_tensor,
        modify_greedy_probs=self._should_modify_greedy_probs_inplace,
    )

    if self.include_gpu_probs_tensor:
        # Since we will defer sampler result Pythonization,
        # preserve GPU-side tensors in support of later
        # deferred pythonization of logprobs
        assert maybe_sampled_tokens_tensor is not None
        on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
    else:
        # Since Pythonization has already happened, don't preserve
        # GPU-side tensors.
        on_device_tensors = None

    # Get the logprobs query results.
    prompt_logprobs = None
    sample_logprobs = None
    if not sampling_metadata.skip_sampler_cpu_output:
        # Pythonize logprobs now (GPU -> CPU); do not defer.
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        prompt_logprobs, sample_logprobs = get_logprobs(
            logprobs, sampling_metadata, maybe_deferred_sample_results)

    return _build_sampler_output(
        maybe_deferred_sample_results,
        sampling_metadata,
        prompt_logprobs,
        sample_logprobs,
        on_device_tensors=on_device_tensors,
        skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

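# Sketch of the sampling trick used in the unfused branch of Sampler_forward
# above: drawing E_i ~ Exp(1) and taking argmax_i p_i / E_i yields a categorical
# sample with probabilities p_i (the "exponential race", equivalent to
# Gumbel-max). A quick empirical check on a single row:
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(20000):
    q = torch.empty_like(probs).exponential_()
    counts[(probs / q).argmax()] += 1
print(counts / counts.sum())   # close to tensor([0.1, 0.2, 0.7])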
def rejection_forward(
    self,
    target_with_bonus_probs: torch.Tensor,
    bonus_token_ids: torch.Tensor,
    draft_probs: torch.Tensor,
    draft_token_ids: torch.Tensor,
    seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
    if seeded_seqs is None:
        out, index = torch.vacc.rejection_sampler(target_with_bonus_probs,
                                                  bonus_token_ids, draft_probs,
                                                  draft_token_ids, 1)
    else:
        out, index = torch.vacc.rejection_sampler(target_with_bonus_probs,
                                                  bonus_token_ids, draft_probs,
                                                  draft_token_ids, 0,
                                                  seeded_seqs[0])
    return out

class Sampler(nn.Module):

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        # print("temperatures is:", sampling_tensors.temperatures)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1).to(
            logits.device))

        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

def _apply_top_k_top_p_vacc(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1).to(probs_sum.device)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = torch.empty_like(logits_sort).scatter_(dim=-1,
                                                    index=logits_idx,
                                                    src=logits_sort)
    return logits
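# Toy usage of _apply_top_k_top_p_vacc (illustrative): with k=3 the weakest
# logit is dropped, and with p=0.9 the low-probability head of what remains is
# dropped too; surviving entries keep their original values, the rest become -inf.
if __name__ == "__main__":
    example_logits = torch.tensor([[2.0, 0.5, 1.0, 3.0]])
    filtered = _apply_top_k_top_p_vacc(example_logits.clone(),
                                       p=torch.tensor([0.9]),
                                       k=torch.tensor([3]))
    print(filtered)   # tensor([[2., -inf, -inf, 3.]])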
@@ -0,0 +1,69 @@

import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.platforms import current_platform
from typing import Tuple


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # torch.compile will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid_offset = (org_vocab_start_index *
                    org_vocab_mask) + (added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    input_ = vocab_mask * (input_ - valid_offset)
    return input_, ~vocab_mask


def VocabParallelEmbedding_forward(self, input_):
    try:
        if self.tp_size > 1:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
            parallel_embedding_output = None
            if memory_recycler is not None:
                if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == input_.size(0):
                    parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER.to(
                        self.weight.dtype)

            output_parallel = torch.vacc.parallel_embedding(
                input_,
                self.weight,
                self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index,
                output=parallel_embedding_output)
        else:
            raise ValueError("not support non-tp")
    except Exception:
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.quant_method.embedding(self,
                                                      masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)

    # TODO: fuse all_reduce
    return tensor_model_parallel_all_reduce(output_parallel)
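# Toy example of the shard-local index mapping above (illustrative, using the
# get_masked_input_and_mask defined in this file): suppose this tensor-parallel
# rank owns original-vocab ids [100, 200) and added-vocab ids [1000, 1008), with
# no padding. In-shard ids map to local offsets; everything else maps to 0 and
# is reported in the mask so its embedding row can be zeroed before all-reduce.
import torch

input_ids = torch.tensor([99, 100, 150, 999, 1000, 1007])
local, oov_mask = get_masked_input_and_mask(
    input_ids,
    org_vocab_start_index=100, org_vocab_end_index=200,
    num_org_vocab_padding=0,
    added_vocab_start_index=1000, added_vocab_end_index=1008)
print(local)      # tensor([  0,   0,  50,   0, 100, 107])
print(oov_mask)   # tensor([ True, False, False,  True, False, False])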