diff --git a/vllm_kunlun/models/deepseek_v2.py b/vllm_kunlun/models/deepseek_v2.py index 7ec23d2..d954f8d 100644 --- a/vllm_kunlun/models/deepseek_v2.py +++ b/vllm_kunlun/models/deepseek_v2.py @@ -37,21 +37,25 @@ from vllm_kunlun.ops.attention.layer import Attention from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ParallelConfig, VllmConfig, - get_current_vllm_config) -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm_kunlun.ops.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm_kunlun.ops.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention @@ -59,9 +63,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -71,23 +79,32 @@ from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend from vllm_kunlun.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from vllm.model_executor.models.interfaces import ( + MixtureOfExperts, + SupportsLoRA, + SupportsPP, +) +from vllm.model_executor.models.utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops as ops - + import xspeedgate_ops + _is_kunlun = True 
logger = init_logger(__name__) -class DeepseekV2MLP(nn.Module): +class DeepseekV2MLP(nn.Module): def __init__( self, hidden_size: int, @@ -105,21 +122,27 @@ class DeepseekV2MLP(nn.Module): # replicated and no collective ops are needed. # Otherwise we use standard TP with an allreduce at the end. self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, disable_tp=is_sequence_parallel, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - disable_tp=is_sequence_parallel, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -130,7 +153,6 @@ class DeepseekV2MLP(nn.Module): class DeepseekV2MoE(nn.Module): - def __init__( self, config: Union[DeepseekV2Config, DeepseekV3Config], @@ -153,17 +175,22 @@ class DeepseekV2MoE(nn.Module): self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=torch.float32)) + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) else: self.gate.e_score_correction_bias = None @@ -173,14 +200,13 @@ class DeepseekV2MoE(nn.Module): self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) if config.n_shared_experts is None: self.experts = FusedMoE( @@ -205,8 +231,7 @@ class DeepseekV2MoE(nn.Module): ) self.shared_experts = None else: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, @@ -253,8 +278,9 @@ class DeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - fused_moe_out = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if self.shared_experts is not None: shared_output, final_hidden_states = fused_moe_out @@ -268,7 +294,7 @@ class DeepseekV2MoE(nn.Module): final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None - shared_output *= (1. 
/ self.routed_scaling_factor) + shared_output *= 1.0 / self.routed_scaling_factor if self.shared_experts is not None: assert shared_output is not None @@ -276,25 +302,26 @@ class DeepseekV2MoE(nn.Module): if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0) + final_hidden_states, 0 + ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: - final_hidden_states = ( - self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) return final_hidden_states.view(num_tokens, hidden_dim) def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: import math + if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 class DeepseekV2Attention(nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -329,60 +356,69 @@ class DeepseekV2Attention(nn.Module): self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - assert topk_indices_buffer is None, "topk_indices_buffer is not \ + assert ( + topk_indices_buffer is None + ), "topk_indices_buffer is not \ supported for DeepseekV2Attention" if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") + prefix=f"{prefix}.kv_b_proj", + ) # O projection. 
- self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) + rope_scaling["rope_type"] = "deepseek_yarn" + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) @@ -390,13 +426,15 @@ class DeepseekV2Attention(nn.Module): mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.attn = Attention(self.num_local_heads, - self.qk_head_dim, - self.scaling, - num_kv_heads=self.num_local_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -406,42 +444,39 @@ class DeepseekV2Attention(nn.Module): if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, - self.qk_head_dim) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: - q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, - self.qk_head_dim) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, _ = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, - self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) - k[..., :self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe # padding value to qk_head_dim for alignment v = torch.nn.functional.pad( - v, [0, self.qk_head_dim - self.v_head_dim], - value=0).view(-1, self.num_local_heads * self.qk_head_dim) + v, [0, self.qk_head_dim - self.v_head_dim], value=0 + ).view(-1, self.num_local_heads * self.qk_head_dim) attn_output = self.attn(q, k, v) - attn_output = attn_output.view( - -1, self.num_local_heads, - self.qk_head_dim)[..., 
:self.v_head_dim].reshape( - -1, self.num_local_heads * self.v_head_dim) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output + @custom_op("vllm::sparse_attn_indexer_vllm_kunlun", mutates_args=()) def sparse_attn_indexer_vllm_kunlun( hidden_states: torch.Tensor, @@ -458,7 +493,6 @@ def sparse_attn_indexer_vllm_kunlun( total_seq_lens: int, topk_indices_buffer: Optional[torch.Tensor], ) -> None: - # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata # assert isinstance(attn_metadata, dict) @@ -486,7 +520,6 @@ def sparse_attn_indexer_vllm_kunlun( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - torch.ops.xspeedgate_ops.indexer_k_quant_and_cache( k, kv_cache, @@ -494,16 +527,16 @@ def sparse_attn_indexer_vllm_kunlun( quant_block_size, scale_fmt, ) - topk_indices_buffer[:hidden_states.shape[0]] = -1 + topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty([chunk.total_seq_lens, head_dim], - device=k.device, - dtype=torch.int8) - k_scale = torch.empty([chunk.total_seq_lens, 4], - device=k.device, - dtype=torch.uint8) + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], device=k.device, dtype=torch.int8 + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 4], device=k.device, dtype=torch.uint8 + ) torch.ops.xspeedgate_ops.cp_gather_indexer_k_quant_cache( kv_cache=kv_cache, dst_k=k_fp8, @@ -512,9 +545,9 @@ def sparse_attn_indexer_vllm_kunlun( cu_seq_lens=chunk.cu_seq_lens, ) logits = int8_mqa_logits( - q_fp8[chunk.token_start:chunk.token_end], + q_fp8[chunk.token_start : chunk.token_end], (k_fp8, k_scale.view(torch.float32)), - weights[chunk.token_start:chunk.token_end], + weights[chunk.token_start : chunk.token_end], chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, context_q_lens_xpu=chunk.context_q_lens, @@ -523,23 +556,26 @@ def sparse_attn_indexer_vllm_kunlun( context_k_lens_cpu=chunk.context_k_lens_cpu, ) del k_fp8, k_scale - + num_rows = logits.shape[0] topk_indices = topk_indices_buffer[ - chunk.token_start:chunk.token_end, :topk_tokens] - + chunk.token_start : chunk.token_end, :topk_tokens + ] + # when seqLens=None and next_n=None, it means that it is used to calculate topk_indices in prefill # refer to top_k_per_row_prefill:https://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L698 - torch.ops.xspeedgate_ops.topk_per_row(logits=logits, - srcIndices=topk_indices, - numRows=num_rows, - stride0=logits.stride(0), - stride1=logits.stride(1), - topK=topk_tokens, - rowStarts=chunk.cu_seqlen_ks, - rowEnds=chunk.cu_seqlen_ke, - seqLens=None, - next_n=None) + torch.ops.xspeedgate_ops.topk_per_row( + logits=logits, + srcIndices=topk_indices, + numRows=num_rows, + stride0=logits.stride(0), + stride1=logits.stride(1), + topK=topk_tokens, + rowStarts=chunk.cu_seqlen_ks, + rowEnds=chunk.cu_seqlen_ke, + seqLens=None, + next_n=None, + ) if has_decode: decode_metadata = attn_metadata.decode @@ -553,10 +589,12 @@ def sparse_attn_indexer_vllm_kunlun( # prefill and decode by decode_threshold # (currently set to 1 + speculative tokens) padded_q_fp8_decode_tokens = pack_seq_triton( - q_fp8[:num_decode_tokens], decode_lens) + q_fp8[:num_decode_tokens], decode_lens + ) else: padded_q_fp8_decode_tokens = 
q_fp8[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_fp8.shape[1:]) + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) # TODO: move and optimize below logic with triton kernels batch_size = padded_q_fp8_decode_tokens.shape[0] next_n = padded_q_fp8_decode_tokens.shape[1] @@ -574,31 +612,33 @@ def sparse_attn_indexer_vllm_kunlun( ) num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] + # when row_starts=None and row_ends=None, it means that it is used to calculate topk_indices in decode # refer to top_k_per_row_decode:https://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L643 - torch.ops.xspeedgate_ops.topk_per_row(logits=logits, - srcIndices=topk_indices, - numRows=num_rows, - stride0=logits.stride(0), - stride1=logits.stride(1), - topK=topk_tokens, - rowStarts=None, - rowEnds=None, - seqLens=decode_metadata.seq_lens, - next_n=next_n) - - + torch.ops.xspeedgate_ops.topk_per_row( + logits=logits, + srcIndices=topk_indices, + numRows=num_rows, + stride0=logits.stride(0), + stride1=logits.stride(1), + topK=topk_tokens, + rowStarts=None, + rowEnds=None, + seqLens=decode_metadata.seq_lens, + next_n=next_n, + ) + if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens topk_indices = unpack_seq_triton( topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), - decode_lens) - topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( - topk_indices + decode_lens, ) + topk_indices_buffer[ + :num_decode_tokens, : topk_indices.shape[-1] + ] = topk_indices def sparse_attn_indexer_vllm_kunlun_fake( @@ -618,19 +658,22 @@ def sparse_attn_indexer_vllm_kunlun_fake( ) -> None: return + sparse_attn_indexer_vllm_kunlun.register_fake(sparse_attn_indexer_vllm_kunlun_fake) -class Indexer(nn.Module): - def __init__(self, - vllm_config: VllmConfig, - config: Union[DeepseekV2Config, DeepseekV3Config], - hidden_size: int, - q_lora_rank: int, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], - topk_indices_buffer: Optional[torch.Tensor], - prefix: str = ""): +class Indexer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], + hidden_size: int, + q_lora_rank: int, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + topk_indices_buffer: Optional[torch.Tensor], + prefix: str = "", + ): super().__init__() self.vllm_config = vllm_config self.config = config @@ -641,22 +684,28 @@ class Indexer(nn.Module): self.rope_dim = config.qk_rope_head_dim # 64 self.q_lora_rank = q_lora_rank # 1536 # no tensor parallel, just replicated - self.wq_b = ReplicatedLinear(self.q_lora_rank, - self.head_dim * self.n_head, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wq_b") - self.wk = ReplicatedLinear(hidden_size, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk") + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.head_dim * self.n_head, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + ) + self.wk = ReplicatedLinear( + hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6) - self.weights_proj = ReplicatedLinear(hidden_size, - self.n_head, - bias=False, - quant_config=None, - 
prefix=f"{prefix}.weights_proj") + self.weights_proj = ReplicatedLinear( + hidden_size, + self.n_head, + bias=False, + quant_config=None, + prefix=f"{prefix}.weights_proj", + ) self.softmax_scale = self.head_dim**-0.5 self.scale_fmt = "ue8m0" self.quant_block_size = 128 # TODO: get from config @@ -666,30 +715,39 @@ class Indexer(nn.Module): # where we store value in fp8 and scale in fp32 # per self.quant_block_size element self.k_cache = DeepseekV32IndexerCache( - head_dim=self.head_dim + - self.head_dim // self.quant_block_size * 4, + head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4, dtype=torch.uint8, prefix=f"{prefix}.k_cache", - cache_config=cache_config) + cache_config=cache_config, + ) self.max_model_len = vllm_config.model_config.max_model_len - if self.max_model_len % cache_config.block_size != 0: #由于I8_paged_mqa_logits输入参数的限制,最大长度必须为block_zise的整数倍 - self.max_model_len = self.max_model_len + cache_config.block_size - (self.max_model_len % cache_config.block_size) + if ( + self.max_model_len % cache_config.block_size != 0 + ): # 由于I8_paged_mqa_logits输入参数的限制,最大长度必须为block_size的整数倍 + self.max_model_len = ( + self.max_model_len + + cache_config.block_size + - (self.max_model_len % cache_config.block_size) + ) self.prefix = prefix - from vllm.v1.attention.backends.mla.indexer import ( - get_max_prefill_buffer_size) + from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size + self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) - def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, - rotary_emb) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb + ) -> torch.Tensor: q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) q_pe, q_nope = torch.split( - q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) k, _ = self.wk(hidden_states) k = self.k_norm(k) k_pe, k_nope = torch.split( - k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) q = torch.cat([q_pe, q_nope], dim=-1) @@ -714,7 +772,7 @@ class Indexer(nn.Module): weights, _ = self.weights_proj(hidden_states) weights = weights * self.n_head**-0.5 weights = weights * q_scale * self.softmax_scale - + torch.ops.vllm.sparse_attn_indexer_vllm_kunlun( hidden_states, self.k_cache.prefix, @@ -737,7 +795,7 @@ class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). 
- + For more info see MLACommonImpl in: vllm/v1/attention/backends/mla/utils.py """ @@ -787,53 +845,60 @@ class DeepseekV2MLAAttention(nn.Module): bias=False, quant_config=quant_config, prefix=f"{prefix}.fused_qkv_a_proj", - disable_tp=True) + disable_tp=True, + ) else: self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) if self.q_lora_rank is not None: - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(self.q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + self.q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + prefix=f"{prefix}.kv_b_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) + rope_scaling["rope_type"] = "deepseek_yarn" + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] @@ -843,9 +908,16 @@ class DeepseekV2MLAAttention(nn.Module): self.is_v32 = hasattr(config, "index_topk") if self.is_v32: - self.indexer = Indexer(vllm_config, config, hidden_size, - q_lora_rank, quant_config, cache_config, - topk_indices_buffer, f"{prefix}.indexer") + self.indexer = Indexer( + vllm_config, + config, + hidden_size, + q_lora_rank, + quant_config, + cache_config, + topk_indices_buffer, + f"{prefix}.indexer", + ) else: self.indexer = None @@ -855,11 +927,12 @@ class DeepseekV2MLAAttention(nn.Module): rotary_emb=self.rotary_emb, o_proj=self.o_proj, fused_qkv_a_proj=self.fused_qkv_a_proj - if self.q_lora_rank is not None else None, + if self.q_lora_rank is not None + else None, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa - if self.q_lora_rank is None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, + 
if self.q_lora_rank is None + else None, + q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else None, indexer=self.indexer, @@ -891,11 +964,12 @@ class DeepseekV2MLAAttention(nn.Module): class DeepseekV2DecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer: Optional[torch.Tensor] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str, + topk_indices_buffer: Optional[torch.Tensor] = None, + ) -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -907,11 +981,25 @@ class DeepseekV2DecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + + # FIXME: Temporary compatibility code for new config format. Remove after vLLM upgrade. + # https://github.com/huggingface/transformers/pull/39847 + if hasattr(config, "rope_parameters"): + rope_params = config.rope_parameters + + rope_theta = rope_params.get("rope_theta", rope_theta) + rope_type = rope_params.get("rope_type", "default") + + if rope_type != "default": + raise NotImplementedError( + f"Unsupported rope_type='{rope_type}' in rope_parameters. " + f"Only rope_type='default' is currently supported." + ) + + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. - layer_idx = int(prefix.split(sep='.')[-1]) + layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx if model_config.use_mla: attn_cls = DeepseekV2MLAAttention @@ -925,8 +1013,7 @@ class DeepseekV2DecoderLayer(nn.Module): qk_nope_head_dim=config.qk_nope_head_dim, qk_rope_head_dim=config.qk_rope_head_dim, v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, + q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, kv_lora_rank=config.kv_lora_rank, rope_theta=rope_theta, rope_scaling=rope_scaling, @@ -937,9 +1024,11 @@ class DeepseekV2DecoderLayer(nn.Module): topk_indices_buffer=topk_indices_buffer, ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): self.mlp = DeepseekV2MoE( config=config, parallel_config=parallel_config, @@ -954,10 +1043,10 @@ class DeepseekV2DecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.routed_scaling_factor = config.routed_scaling_factor def forward( @@ -971,8 +1060,7 @@ class DeepseekV2DecoderLayer(nn.Module): residual = hidden_states.clone() hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, 
residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -982,32 +1070,29 @@ class DeepseekV2DecoderLayer(nn.Module): # Fix FP16 overflow # We scale both hidden_states and residual before # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor if self.layer_idx == 0: # The residual is shared by all layers, we only scale it on # first layer. - residual *= 1. / self.routed_scaling_factor + residual *= 1.0 / self.routed_scaling_factor # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) - if isinstance(self.mlp, - DeepseekV2MLP) and hidden_states.dtype == torch.float16: + if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: # Fix FP16 overflow # Scaling the DeepseekV2MLP output, it is the input of # input_layernorm of next decoder layer. # The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE - hidden_states *= 1. / self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor return hidden_states, residual @support_torch_compile class DeepseekV2Model(nn.Module): - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -1025,7 +1110,8 @@ class DeepseekV2Model(nn.Module): vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, dtype=torch.int32, - device="cuda") + device="cuda", + ) else: topk_indices_buffer = None @@ -1034,23 +1120,26 @@ class DeepseekV2Model(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, - topk_indices_buffer), - prefix=f"{prefix}.layers") + lambda prefix: DeepseekV2DecoderLayer( + vllm_config, prefix, topk_indices_buffer + ), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -1077,17 +1166,15 @@ class DeepseekV2Model(nn.Module): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, - SupportsLoRA): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } @@ -1103,16 +1190,18 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, # initializing DeepseekV2Model, 
as it is passed inplace to # quantization config init and may be used to select the # quant_method for relevant layers during initialization. - self.fuse_qkv_a_proj = hasattr( - config, "q_lora_rank") and config.q_lora_rank is not None + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) if self.fuse_qkv_a_proj: self.packed_modules_mapping["fused_qkv_a_proj"] = [ "q_a_proj", "kv_a_proj_with_mqa", ] - self.model = DeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = DeepseekV2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( config.vocab_size, @@ -1124,12 +1213,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) self.expert_weights = [] # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) + self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups = config.n_group self.moe_layers: list[FusedMoE] = [] @@ -1178,8 +1267,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, DeepseekV2MoE): moe = layer.mlp @@ -1198,8 +1286,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -1209,8 +1298,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -1226,7 +1314,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=self.num_redundant_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -1238,7 +1327,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, if spec_layer is not None: continue # skip spec decode layers for main model - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
if weight_name not in name: continue @@ -1248,15 +1337,16 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled # if go with fusion option, then update name - if ((param_name == "fused_qkv_a_proj") - and name_mapped not in params_dict): + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: continue else: name = name_mapped @@ -1295,14 +1385,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -1327,8 +1420,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -1345,13 +1439,15 @@ class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM): # Compatibility with # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py -def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config, - DeepseekV3Config], - weight_name: str) -> Optional[int]: - if (hasattr(config, "num_nextn_predict_layers") - and config.num_nextn_predict_layers > 0): +def get_spec_layer_idx_from_weight_name( + config: Union[DeepseekV2Config, DeepseekV3Config], weight_name: str +) -> Optional[int]: + if ( + hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): if weight_name.startswith(f"model.layers.{layer_idx+i}."): return layer_idx + i - return None \ No newline at end of file + return None diff --git a/vllm_kunlun/transformer_utils/__init__.py b/vllm_kunlun/transformer_utils/__init__.py index e69de29..798fef6 100644 --- a/vllm_kunlun/transformer_utils/__init__.py +++ b/vllm_kunlun/transformer_utils/__init__.py @@ -0,0 +1,21 @@ +# +# Copyright (c) 2026 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-kunlun project. +# + + +from . import tokenizer + +__all__ = ["tokenizer"] diff --git a/vllm_kunlun/transformer_utils/tokenizer.py b/vllm_kunlun/transformer_utils/tokenizer.py new file mode 100644 index 0000000..ccf14fc --- /dev/null +++ b/vllm_kunlun/transformer_utils/tokenizer.py @@ -0,0 +1,223 @@ +# +# Copyright (c) 2026 Baidu, Inc. All Rights Reserved. + +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import tempfile +import shutil +import os +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, Union + +import huggingface_hub +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from vllm import envs +from vllm.logger import init_logger +from vllm.transformers_utils import tokenizer +from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config +from vllm.transformers_utils.tokenizer import get_cached_tokenizer +from vllm.transformers_utils.tokenizers import MistralTokenizer +from vllm.transformers_utils.utils import check_gguf_file + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer_base import TokenizerBase +else: + TokenizerBase = Any + +logger = init_logger(__name__) + +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, TokenizerBase] + + +def kunlun_get_tokenizer( + tokenizer_name: Union[str, Path], + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + **kwargs, +) -> AnyTokenizer: + """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" + if envs.VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # avoid circuit import + from vllm.model_executor.model_loader.weight_utils import get_lock + + # Only set the tokenizer here, model will be downloaded on the workers. + if not os.path.exists(tokenizer_name): + # Use file lock to prevent multiple processes from + # downloading the same file at the same time. + with get_lock(tokenizer_name, download_dir): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + # Ignore weights - we only need the tokenizer. 
+ ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + tokenizer_name = tokenizer_path + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + if "truncation_side" not in kwargs: + kwargs["truncation_side"] = "left" + + # Separate model folder from file path for GGUF models + is_gguf = check_gguf_file(tokenizer_name) + if is_gguf: + kwargs["gguf_file"] = Path(tokenizer_name).name + tokenizer_name = Path(tokenizer_name).parent + + # if tokenizer is from official mistral org + is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai" + if is_from_mistral_org and tokenizer_mode != "mistral": + warnings.warn( + "It is strongly recommended to run mistral models with " + '`--tokenizer-mode "mistral"` to ensure correct ' + "encoding and decoding.", + FutureWarning, + stacklevel=2, + ) + + tokenizer: AnyTokenizer + if tokenizer_mode == "mistral": + tokenizer = MistralTokenizer.from_pretrained( + str(tokenizer_name), revision=revision + ) + elif tokenizer_mode == "custom": + from vllm.transformers_utils.tokenizer_base import TokenizerRegistry + + tokenizer = TokenizerRegistry.get_tokenizer( + str(tokenizer_name), + *args, + revision=revision, + download_dir=download_dir, + **kwargs, + ) + else: + try: + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, + # suggest using the --trust-remote-code flag. + + if not trust_remote_code and ( + "does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e) + ): + err_msg = ( + "Failed to load the tokenizer. If the tokenizer " + "is a custom tokenizer not yet available in the " + "HuggingFace transformers library, consider " + "setting `trust_remote_code=True` in LLM or using " + "the `--trust-remote-code` flag in the CLI." + ) + raise RuntimeError(err_msg) from e + + # FIXME: Temporary compatibility code for new config format. Remove after vLLM upgrade. + if "TokenizersBackend" in str(e): + logger.warning( + "TokenizerBackend not supported, patching tokenizer_config.json " + "and loading with PreTrainedTokenizerFast." 
+                )
+                tmp_dir = tempfile.mkdtemp(prefix="vllm_tokenizer_patch_")
+                try:
+                    TOKENIZER_FILES = [
+                        "tokenizer.json",
+                        "tokenizer_config.json",
+                        "special_tokens_map.json",
+                        "added_tokens.json",
+                        "chat_template.jinja",
+                        "generation_config.json",
+                    ]
+
+                    for fname in TOKENIZER_FILES:
+                        src = os.path.join(tokenizer_name, fname)
+                        if os.path.exists(src):
+                            shutil.copy(src, tmp_dir)
+
+                    config_path = os.path.join(tmp_dir, "tokenizer_config.json")
+                    with open(config_path, "r", encoding="utf-8") as f:
+                        cfg = json.load(f)
+                    if cfg.get("tokenizer_class") in ("TokenizersBackend",):
+                        cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
+                    if "extra_special_tokens" in cfg:
+                        cfg["additional_special_tokens"] = cfg.pop(
+                            "extra_special_tokens"
+                        )
+
+                    with open(config_path, "w", encoding="utf-8") as f:
+                        json.dump(cfg, f, indent=2)
+
+                    tokenizer = AutoTokenizer.from_pretrained(
+                        tmp_dir,
+                        trust_remote_code=trust_remote_code,
+                        revision=revision,
+                        **kwargs,
+                    )
+                finally:
+                    shutil.rmtree(tmp_dir, ignore_errors=True)
+
+            else:
+                raise e
+
+    # The special_tokens in tokenizer should also be
+    # controlled by do_lower_case in encoder_config
+    encoder_config = get_sentence_transformer_tokenizer_config(
+        tokenizer_name, revision
+    )
+    if isinstance(encoder_config, dict) and encoder_config.get(
+        "do_lower_case", False
+    ):
+        special_tokens_map = {
+            k: v.lower() for k, v in tokenizer.special_tokens_map.items()
+        }
+        tokenizer.add_special_tokens(special_tokens_map)
+
+    if not isinstance(tokenizer, PreTrainedTokenizerFast):
+        logger.warning(
+            "Using a slow tokenizer. This might cause a significant "
+            "slowdown. Consider using a fast tokenizer instead."
+        )
+    tokenizer = get_cached_tokenizer(tokenizer)
+
+    return tokenizer
+
+
+tokenizer.get_tokenizer = kunlun_get_tokenizer
+
+logger.info_once(
+    "[Monkey Patch Applied] >>> vllm.transformers_utils.tokenizer.get_tokenizer "
+    "--> vllm_kunlun.transformer_utils.tokenizer.kunlun_get_tokenizer"
+)
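A minimal usage sketch (not part of the patch; the checkpoint path is a placeholder): importing vllm_kunlun.transformer_utils runs tokenizer.py and applies the monkey patch above, after which vLLM's get_tokenizer entry point resolves to kunlun_get_tokenizer.

    import vllm_kunlun.transformer_utils  # applies the patch on import (see __init__.py above)
    from vllm.transformers_utils import tokenizer as vllm_tokenizer

    # After the patch, the module-level entry point is the Kunlun override.
    assert vllm_tokenizer.get_tokenizer is vllm_kunlun.transformer_utils.tokenizer.kunlun_get_tokenizer

    # Placeholder path; substitute a real local model directory.
    tok = vllm_tokenizer.get_tokenizer("/path/to/checkpoint", trust_remote_code=True)
    print(type(tok).__name__)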