# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Literal

import torch
import torch.nn.functional as F

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
from vllm.utils.flashinfer import (
    flashinfer_quant_nvfp4_8x4_sf_layout,
)
from vllm.utils.math_utils import cdiv

import ixformer.inference.functions as ops
from ixformer.core import config
from ixformer.distributed import _distributed as cdist

logger = init_logger(__name__)

current_platform.import_kernels()

if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        # Older torch releases only provide the deprecated alias.
        from torch.library import impl_abstract as register_fake


# activation ops


def swiglustep_and_mul_torch(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
) -> None:
    """Pure-PyTorch clamped SwiGLU-step, written into ``output``.

    ``input`` is split in half along its last dimension into ``gate`` (first
    half) and ``up`` (second half), and
    ``output = clamp(silu(gate), max=limit) * clamp(up, -limit, limit)``.
    Any number of leading dimensions is accepted.
    """
    d = input.shape[-1] // 2
    gate = input[..., :d]
    up = input[..., d:]
    # Write the result directly into the caller-provided ``output`` buffer.
    torch.mul(
        torch.clamp(F.silu(gate), max=limit),
        torch.clamp(up, -limit, limit),
        out=output,
    )


def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """SiLU-and-mul of ``x`` into ``out`` via ``ops.silu_and_mul``."""
    # ixformer takes (input, output): the reverse of this wrapper's order.
    ops.silu_and_mul(x, out)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """GELU-and-mul of ``x`` into ``out`` via ``ops.gelu_and_mul``."""
    ops.gelu_and_mul(x, out)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Tanh-GELU-and-mul of ``x`` into ``out`` via ``ops.gelu_tanh_and_mul``."""
    ops.gelu_tanh_and_mul(x, out)


def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0) -> None:
    """FATReLU-and-mul in pure PyTorch, written into ``out``.

    ``x`` is split in half along its last dimension. Elements of the first
    half that are not greater than ``threshold`` are replaced by 0, and the
    result is multiplied element-wise with the second half. (This previously
    raised ``NotImplementedError``.)
    """
    d = x.shape[-1] // 2
    torch.mul(F.threshold(x[..., :d], threshold, 0.0), x[..., d:], out=out)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Tanh-approximated GELU of ``x`` copied into ``out``; returns ``out``."""
    out.copy_(F.gelu(x, approximate="tanh"))
    return out


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Tanh-approximated GELU of ``x`` copied into ``out``; returns ``out``."""
    out.copy_(F.gelu(x, approximate="tanh"))
    return out


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """QuickGELU, ``x * sigmoid(1.702 * x)``, copied into ``out``; returns ``out``.

    The previous implementation reused tanh-approximated GELU here. That is a
    different activation from QuickGELU.
    """
    out.copy_(x * torch.sigmoid(1.702 * x))
    return out


def swigluoai_and_mul(
    out: torch.Tensor, x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
) -> None:
    """Clamped SwiGLU of ``x`` into ``out`` via ``ops.swigluoai_and_mul``.

    Presumably computes the same math as ``swigluoai_and_mul_torch``.
    TODO(review): confirm against the ixformer kernel.
    """
    ops.swigluoai_and_mul(x, out, alpha, limit)


def swigluoai_and_mul_torch(
    out: torch.Tensor, x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
) -> None:
    """Pure-PyTorch clamped SwiGLU with interleaved gate/up channels.

    Even channels of the last dimension of ``x`` are ``gate``; odd channels
    are ``up``. ``gate`` is clamped from above at ``limit`` and ``up`` is
    clamped to ``[-limit, limit]``. ``out`` receives
    ``(up + 1) * gate * sigmoid(alpha * gate)``.
    """
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    gated_output = (up + 1) * glu
    out.copy_(gated_output)


def rms_norm_qk(
    output_q: torch.Tensor,
    output_k: torch.Tensor,
    input_q: torch.Tensor,
    input_k: torch.Tensor,
    weight_q: torch.Tensor,
    weight_k: torch.Tensor,
    epsilon: float,
) -> None:
    """Thin wrapper around ``torch.ops.ixf_ops.rms_norm_qk``.

    All arguments are forwarded unchanged. Results are expected in
    ``output_q`` / ``output_k``.
    """
    torch.ops.ixf_ops.rms_norm_qk(
        output_q, output_k, input_q, input_k, weight_q, weight_k, epsilon
    )


def advance_step_flashattn(
    num_seqs: int,
    num_queries: int,
    block_size: int,
    input_tokens: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    input_positions: torch.Tensor,
    seq_lens: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_tables: torch.Tensor,
) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner"""
    return ops.advance_step_flashattn(
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
    )


def quant_kv(kv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Symmetric per-row int8 quantization over the last dimension.

    Returns ``(q, scales)``:
    - ``q`` is ``round(kv / scale)`` clamped to [-127, 127], cast to int8.
    - ``scales`` is the float32 ``amax / 127`` of each row, with the last
      dimension dropped (shape ``kv.shape[:-1]``).

    All-zero rows get scale 0 and quantize to 0.
    """
    amax, _ = torch.max(torch.abs(kv), dim=-1, keepdim=True)
    f_scale = amax.float() / 127.0
    scales = f_scale.view(kv.shape[:-1])
    # An all-zero row has a zero scale. Dividing by it would give 0/0 = NaN,
    # whose int8 cast is undefined, so such rows are divided by 1 instead.
    safe_scale = torch.where(f_scale == 0, torch.ones_like(f_scale), f_scale)
    q = torch.clamp(torch.round(kv / safe_scale), -127, 127).to(torch.int8)
    return q, scales


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Thin wrapper around ``torch.ops._C.paged_attention_v1``.

    All arguments are forwarded positionally in signature order.
    """
    torch.ops._C.paged_attention_v1(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Thin wrapper around ``torch.ops._C.paged_attention_v2``.

    All arguments are forwarded positionally in signature order.
    """
    torch.ops._C.paged_attention_v2(
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    query_start_loc: torch.Tensor | None,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    fp8_out_scale: torch.Tensor | None = None,
    # NOTE: this default is evaluated once, at import time.
    mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16",
) -> None:
    """Thin wrapper around ``torch.ops._rocm_C.paged_attention``.

    All arguments are forwarded positionally in signature order.
    """
    torch.ops._rocm_C.paged_attention(
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        query_start_loc,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        fp8_out_scale,
        mfma_type,
    )


def mla_decode_kvcache_cpu(
    out: torch.Tensor,
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
) -> None:
    """Thin wrapper around ``torch.ops._C_cpu.mla_decode_kvcache``."""
    torch.ops._C_cpu.mla_decode_kvcache(
        out, query, kv_cache, scale, block_tables, seq_lens
    )


# merge attn states ops
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    """Thin wrapper around ``torch.ops._C.merge_attn_states``.

    Note that the kernel takes ``output_lse`` as its second argument.
    """
    torch.ops._C.merge_attn_states(
        output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse
    )


def convert_vertical_slash_indexes(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate zero-initialised index buffers and fill them with
    ``torch.ops._C.convert_vertical_slash_indexes``.

    Returns ``(block_count, block_offset, column_count, column_index)``.
    With ``num_rows = ceil(context_size / block_size_M)``, their shapes are
    ``[B, H, num_rows]``, ``[B, H, num_rows, NNZ_S]``, ``[B, H, num_rows]``
    and ``[B, H, num_rows, NNZ_V]``. All four use ``q_seqlens``'s dtype and
    device.
    """
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.zeros(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    block_offset = torch.zeros(
        batch_size,
        num_heads,
        num_rows,
        nnz_slash,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )
    column_count = torch.zeros(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    column_index = torch.zeros(
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )

    torch.ops._C.convert_vertical_slash_indexes(
        block_count,
        block_offset,
        column_count,
        column_index,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return block_count, block_offset, column_count, column_index


def convert_vertical_slash_indexes_mergehead(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    # [N_HEADS] : different head use different number of indices
    vertical_indices_count: torch.Tensor,
    slash_indices_count: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Per-head-count variant of ``convert_vertical_slash_indexes``.

    Outputs have the same shapes, but are allocated with ``torch.empty``
    (uninitialised). NOTE(review): this assumes the kernel writes every
    element; confirm.
    """
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.empty(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    block_offset = torch.empty(
        batch_size,
        num_heads,
        num_rows,
        nnz_slash,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )
    column_count = torch.empty(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    column_index = torch.empty(
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )

    torch.ops._C.convert_vertical_slash_indexes_mergehead(
        block_count,
        block_offset,
        column_count,
        column_index,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        vertical_indices_count,
        slash_indices_count,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return block_count, block_offset, column_count, column_index


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Rotary embedding via ixformer's ``ops.vllm_rotary_embedding``.

    Upstream calls ``torch.ops._C.rotary_embedding`` with the same arguments.
    """
    ops.vllm_rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)


def batched_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    rot_dim: int,
    cos_sin_cache_offsets: torch.Tensor,
) -> None:
    """Batched rotary embedding via ``ops.vllm_batched_rotary_embedding``."""
    ops.vllm_batched_rotary_embedding(
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox,
        rot_dim,
        cos_sin_cache_offsets,
    )


# layer norm ops
def rms_norm(
    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
    """RMS norm of ``input`` into ``out`` via ixformer's ``ops.rms_norm``.

    ``input`` is made contiguous first. NOTE(review): presumably the
    ixformer kernel requires contiguous input; confirm.
    """
    input_contiguous = input.contiguous()
    ops.rms_norm(input_contiguous, weight, epsilon, out)


def fused_add_rms_norm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    residual_alpha: float | None = 1.0,
) -> None:
    """Residual-add + RMS norm via ``ops.residual_rms_norm``, updating in place.

    The normalized result is copied back into ``input``, and the updated
    residual is copied back into ``residual``.
    """
    output, residual_output = ops.residual_rms_norm(
        input=input,
        weight=weight,
        residual=residual,
        eps=epsilon,
        residual_alpha=residual_alpha,
    )
    input[:] = output
    residual[:] = residual_output


def fused_qk_norm_rope(
    qkv: torch.Tensor,
    num_heads_q: int,
    num_heads_k: int,
    num_heads_v: int,
    head_dim: int,
    eps: float,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    position_ids: torch.Tensor,
) -> None:
    """Thin wrapper around ``torch.ops._C.fused_qk_norm_rope``."""
    torch.ops._C.fused_qk_norm_rope(
        qkv,
        num_heads_q,
        num_heads_k,
        num_heads_v,
        head_dim,
        eps,
        q_weight,
        k_weight,
        cos_sin_cache,
        is_neox,
        position_ids,
    )


def apply_repetition_penalties_torch(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """Pure-PyTorch, in-place repetition penalty on ``logits``.

    Positive logits of penalized tokens are divided by the penalty; the
    others are multiplied by it.
    """
    # Broadcast the per-sequence penalty [num_seqs, 1] across the vocab
    # instead of materializing a [num_seqs, vocab] copy.
    # If a token appears in the prompt or output, apply the penalty;
    # otherwise use 1.0 (no-op).
    penalties = torch.where(
        prompt_mask | output_mask, repetition_penalties.unsqueeze(dim=1), 1.0
    )
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling


def apply_repetition_penalties_cuda(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """Device path for repetition penalties.

    On this platform it uses the PyTorch implementation instead of
    upstream's ``torch.ops._C.apply_repetition_penalties_``.
    """
    apply_repetition_penalties_torch(
        logits, prompt_mask, output_mask, repetition_penalties
    )


def apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """Apply repetition penalties to logits in-place.

    Args:
        logits: The logits tensor of shape [num_seqs, vocab_size].
        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
        output_mask: A boolean tensor indicating which tokens appear in the output.
        repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    if logits.is_cuda and logits.is_contiguous():
        apply_repetition_penalties_cuda(
            logits, prompt_mask, output_mask, repetition_penalties
        )
    else:
        apply_repetition_penalties_torch(
            logits, prompt_mask, output_mask, repetition_penalties
        )


# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    scale_ub: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMS norm + dynamic per-token quantization via ``torch.ops._C``.

    Returns ``(output, scales)``:
    - ``output`` has ``input``'s shape and dtype ``quant_dtype``.
    - ``scales`` is float32 with shape ``(num_tokens, 1)``.
    """
    output = torch.empty_like(input, dtype=quant_dtype)
    scales = torch.empty(
        (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
    )

    torch.ops._C.rms_norm_dynamic_per_token_quant(
        output, input, weight, scales, epsilon, scale_ub, residual
    )
    return output, scales


# fused quant layer norm ops blocked
def rms_norm_per_block_quant(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_size: list[int],
    scale_ub: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
    is_scale_transposed: bool = False,
    tma_alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMS norm + per-block quantization via ``torch.ops._C``.

    The float32 ``scales`` tensor holds one entry per ``group_size[1]``
    columns. It is allocated either row-major, or column-major when
    ``is_scale_transposed`` is set. A non-zero ``tma_alignment`` pads the row
    stride up to a multiple of it. Returns ``(output, scales)``.
    """
    assert len(group_size) == 2
    # Validate before the value is used to compute strides below.
    assert tma_alignment in [0, 4], "Expected TMA alignment 0 or 4, but got " + str(
        tma_alignment
    )
    output = torch.empty_like(input, dtype=quant_dtype)
    if is_scale_transposed:
        if tma_alignment == 0:
            scales = torch.empty(
                (input.shape[-1] // group_size[1], input.numel() // input.shape[-1]),
                device=input.device,
                dtype=torch.float32,
            ).transpose(0, 1)
        else:
            m = input.shape[-2]
            sf_k = input.shape[-1] // group_size[1]
            # Round m up to the next multiple of tma_alignment.
            tma_aligned_m = (m + tma_alignment - 1) // tma_alignment * tma_alignment
            shape = input.shape[:-2] + (m, sf_k)
            stride = (
                (1, tma_aligned_m)
                if input.dim() == 2
                else (tma_aligned_m * sf_k, 1, tma_aligned_m)
            )
            scales = torch.empty_strided(
                shape, stride, device=input.device, dtype=torch.float32
            )
    else:
        scales = torch.empty(
            (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]),
            device=input.device,
            dtype=torch.float32,
        )

    torch.ops._C.rms_norm_per_block_quant(
        output,
        input,
        weight,
        scales,
        epsilon,
        scale_ub,
        residual,
        group_size[1],
        is_scale_transposed,
    )
    return output, scales


# quantization ops
# awq
def awq_dequantize(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
    split_k_iters: int,
    thx: int,
    thy: int,
) -> torch.Tensor:
    """AWQ dequantization.

    Uses the Triton implementation when ``envs.VLLM_USE_TRITON_AWQ`` is set;
    otherwise uses ``torch.ops._C.awq_dequantize``.
    """
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_dequantize_triton,
        )

        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)


if hasattr(torch.ops._C, "awq_dequantize"):

    @register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(
        qweight: torch.Tensor,
        scales: torch.Tensor,
        zeros: torch.Tensor,
        split_k_iters: torch.SymInt,
        thx: int,
        thy: int,
    ) -> torch.Tensor:
        """Fake (meta) impl: output is ``[in_c, qweight.size(1) * 8]`` in ``scales.dtype``."""
        in_c = qweight.size(0)
        qout_c = qweight.size(1)
        # Each packed int32 column expands to 8 output columns.
        out_c = qout_c * 8
        return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device)
torch.Tensor, # qzeros: torch.Tensor, # split_k_iters: int, # ) -> torch.Tensor: # if envs.VLLM_USE_TRITON_AWQ: # from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton # return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) # return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters) def awq_gemm( input: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, pack_factor, group_size: int = 128, ) -> torch.Tensor: return ops.wui4a16(input, qweight, scales, qzeros, None, group_size, "NN") if hasattr(torch.ops._C, "awq_gemm"): @register_fake("_C::awq_gemm") def _awq_gemm_fake( input: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, split_k_iters: torch.SymInt, ) -> torch.Tensor: num_in_feats = input.size(0) return torch.empty( (split_k_iters, num_in_feats, qweight.size(1) * 8), dtype=input.dtype, device=input.device, ).sum(0) # gptq def gptq_gemm( a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, use_v2_format: bool, bit: int, ) -> torch.Tensor: # return torch.ops._C.gptq_gemm( # a, # b_q_weight, # b_gptq_qzeros, # b_gptq_scales, # b_g_idx, # use_exllama, # use_v2_format, # bit, # ) return ops.gptq_gemm( a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit ) if hasattr(torch.ops._C, "gptq_gemm"): @register_fake("_C::gptq_gemm") def _gptq_gemm_fake( a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, use_v2_format: bool, bit: int, ) -> torch.Tensor: return torch.empty( (a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device ) def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) ops.vllm_gptq_shuffle(q_weight, q_perm, bit) if hasattr(torch.ops._C, "allspark_w8a16_gemm"): 
@register_fake("_C::allspark_w8a16_gemm") def _allspark_w8a16_gemm_fake( a: torch.Tensor, b_qweight: torch.Tensor, b_scales: torch.Tensor, b_qzeros: torch.Tensor | None, n: torch.SymInt, group_size: torch.SymInt, sm_count: torch.SymInt, sm_version: torch.SymInt, CUBLAS_M_THRESHOLD: torch.SymInt, has_zp: bool, n32k16_reorder: bool, ) -> torch.Tensor: m = a.size(0) return torch.empty((m, n), device=a.device, dtype=a.dtype) if hasattr(torch.ops._C, "ggml_dequantize"): @register_fake("_C::ggml_dequantize") def _ggml_dequantize_fake( W: torch.Tensor, quant_type: int, m: torch.SymInt, n: torch.SymInt, dtype: torch.dtype | None = None, ) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) @register_fake("_C::ggml_mul_mat_vec_a8") def _ggml_mul_mat_vec_a8_fake( W: torch.Tensor, X: torch.Tensor, quant_type: int, row: torch.SymInt, ) -> torch.Tensor: return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device) @register_fake("_C::ggml_mul_mat_a8") def _ggml_mul_mat_a8_fake( W: torch.Tensor, X: torch.Tensor, quant_type: int, row: torch.SymInt, ) -> torch.Tensor: batch = X.size(0) return torch.empty((batch, row), dtype=X.dtype, device=W.device) @register_fake("_C::ggml_moe_a8") def _ggml_moe_a8_fake( X: torch.Tensor, W: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, quant_type: int, row: torch.SymInt, top_k: torch.SymInt, tokens: torch.SymInt, ) -> torch.Tensor: tokens = X.size(0) return torch.empty((tokens * top_k, row), dtype=torch.float16, device=W.device) if hasattr(torch.ops._C, "ggml_moe_a8_vec"): @register_fake("_C::ggml_moe_a8_vec") def _ggml_moe_a8_vec_fake( X: torch.Tensor, W: torch.Tensor, topk_ids: torch.Tensor, top_k: int, quant_type: int, row: torch.SymInt, tokens: torch.SymInt, ) -> torch.Tensor: tokens = X.size(0) return torch.empty((tokens * top_k, row), dtype=X.dtype, device=W.device) # cutlass def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) 
-> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) def cutlass_scaled_fp4_mm( a: torch.Tensor, b: torch.Tensor, block_scale_a: torch.Tensor, block_scale_b: torch.Tensor, alpha: torch.Tensor, out_dtype: torch.dtype, ) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 m, n = a.shape[0], b.shape[0] out = torch.empty((m, n), dtype=out_dtype, device=a.device) torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, alpha) return out def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: # return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) return False def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: # return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) return False def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, bias: torch.Tensor | None = None, format: str | None = "TN", ) -> torch.Tensor: """ `cutlass_scaled_mm` implements a fused version of `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` where scale_a * a and scale_b * b are implemented using numpy-style broadcasting. In order to support blockwise scaling like found in DeepSeek V3 we also support extended "group" broadcast rules. 
We extend the numpy-style broadcasting rules with the following rule: "if the extent of a dimension in the source shape is between 1 and corresponding extent in the target shape we repeat each element along that dimension src_shape[dim] // target_shape[dim] times consecutively" example if we have: a = [[1, 2], and target_shape = (2, 4) [3, 4]] then we would expand a to: a = [[1, 1, 2, 2], [3, 3, 4, 4]] currently we only support the case: scale_a.shape * [1, 128] == a.shape scale_b.shape * [128, 128] == b.shape """ target_shape = (*a.shape[:-1], b.shape[1]) assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype # a is x, b is weight m = a.shape[:-1] n = b.shape[1] * 2 if envs.VLLM_W8A8_LINEAR_USE_W4A8 else b.shape[1] if format == "TN": b = b.t() out = torch.empty(m + (n,), dtype=out_dtype, device=a.device) if envs.VLLM_W8A8_LINEAR_USE_W4A8: assert format == "NN" ops.w4a8( a, b, scale_a, scale_b, bias=bias, format=0, output=out.view(-1, n), output_dtype=out_dtype, ) else: ops.w8a8( a, b, scale_a, scale_b, bias, format=format, output=out.view(-1, n), out_dtype=out_dtype, ) return out def cutlass_scaled_mm_azp( a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, azp_adj: torch.Tensor, azp: torch.Tensor | None = None, bias: torch.Tensor | None = None, ) -> torch.Tensor: """ :param azp_adj: In the per-tensor case, this should include the azp. Always per-channel. :param azp: Only set in the per-token case. Per-token if set. 
""" assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype # Massage the input to be 2D target_shape = (*a.shape[:-1], b.shape[1]) a = a.view(-1, a.shape[-1]) assert azp is None or azp.numel() == a.shape[0] out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) return out.view(*target_shape) def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: # return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability) return False def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: if cuda_device_capability < 90 or cuda_device_capability >= 110: return False try: return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) except AttributeError: # Return False on non-CUDA platforms where it is not available return False def cutlass_sparse_compress(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Compresses a sparse matrix for use with Cutlass sparse operations. This function takes a dense tensor and compresses it into two components: non-zero elements and metadata. The compressed representation is compatible with Cutlass sparse kernels. Args: a (torch.Tensor): The input tensor to be compressed. Must have one of the following data types: - `torch.int8` - `torch.float8_e4m3fn` - `torch.bfloat16` - `torch.float16` Returns: tuple[torch.Tensor, torch.Tensor]: A tuple containing: - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`. - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation. Raises: ValueError: If the compression operation fails. Notes: - The `a_meta` tensor has a data type of `torch.uint8`. - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`). 
- The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor. - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`. """ assert a.dtype in [torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16] assert a.is_contiguous() # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4 elemsPerMetaElem = 4 assert a.shape[1] % (2 * elemsPerMetaElem) == 0 return torch.ops._C.cutlass_sparse_compress(a) def cutlass_scaled_sparse_mm( a: torch.Tensor, bt_nzs: torch.Tensor, bt_meta: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, bias: torch.Tensor | None = None, ) -> torch.Tensor: """ Performs a scaled sparse matrix multiplication using Cutlass. Steps: 1. Create a dense matrix `a` of shape (m, k) on the CUDA device: `a = torch.randn((m, k), device='cuda')`. 2. Create a dense matrix `b` of shape (k, n) on the CUDA device: `b = torch.randn((k, n), device='cuda')`. 3. Prune matrix `b` to 2:4 sparsity along the specified dimension: `b = prune_to_2_4(b, dim=0)`. 4. Compress the transposed sparse matrix `b.t()`: `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`. 5. Perform sparse matrix multiplication using the compressed matrix, applying scaling factors for `a` and `b`, and the output data type: `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`. Returns: - The result of the scaled sparse matrix multiplication. 
""" assert bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0 assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 assert bias is None or bias.shape[0] == bt_nzs.shape[0] and bias.dtype == out_dtype m = a.shape[0] n = bt_nzs.shape[0] out = torch.empty((m, n), dtype=out_dtype, device=a.device) torch.ops._C.cutlass_scaled_sparse_mm( out, a, bt_nzs, bt_meta, scale_a, scale_b, bias ) return out def get_cutlass_moe_mm_data( topk_ids: torch.Tensor, expert_offsets: torch.Tensor, problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, input_permutation: torch.Tensor, output_permutation: torch.Tensor, num_experts: int, n: int, k: int, blockscale_offsets: torch.Tensor | None = None, ): """ Prepare data necessary to perform CUTLASS grouped matrix multiplications used in CUTLASS-based fused MoE. The function takes in topk_ids (token-expert mapping) and uses it to compute: - expert_offsets: Indices that mark at which token index each expert begins its computation after the input is sorted with input_permutation. The number of tokens computed with expert E is expert_offsets[E + 1] - expert_offsets[E] - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. - input_permutation: Permutation that must be used to shuffle the input before executing the MMs. - output_permutation: Permutation that must be used to shuffle the output after executing the MMs. - blockscale_offsets: Optional argument passed for fp4 moe. Indices that mark at which block scale index each expert begins its computation. 
The number of block scale rows computed with expert E is blockscale_offsets[E + 1] - blockscale_offsets[E] """ return torch.ops._C.get_cutlass_moe_mm_data( topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k, blockscale_offsets, ) def get_cutlass_moe_mm_problem_sizes_from_expert_offsets( expert_first_token_offset: torch.Tensor, problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, n: int, k: int, swap_ab: bool, ): """Compute per-expert (M, N, K) problem sizes from expert_first_token_offset""" return torch.ops._C.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab, ) def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): """ Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. This is used in MoE to permute the input tensor before performing grouped matrix multiplications. """ num_tokens_permuted = dst2src_map.shape[0] output_tensor = torch.empty( (num_tokens_permuted, input_tensor.shape[1]), device=input_tensor.device, dtype=input_tensor.dtype, ) torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor) return output_tensor def get_cutlass_pplx_moe_mm_data( expert_offsets: torch.Tensor, problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, expert_num_tokens: torch.Tensor, num_local_experts: int, padded_m: int, n: int, k: int, ): """ Prepare data necessary to perform CUTLASS grouped matrix multiplications used in CUTLASS-based fused MoE. The function takes in expert_num_tokens (token count per expert) and non_zero_expert_idxs (consecutive indices of experts with non-zero token counts) and uses them to compute: - expert_offsets: Indices that mark at which token index each expert begins its computation. - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. 
""" return torch.ops._C.get_cutlass_pplx_moe_mm_data( expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, num_local_experts, padded_m, n, k, ) def cutlass_moe_mm( out_tensors: torch.Tensor, a_tensors: torch.Tensor, b_tensors: torch.Tensor, a_scales: torch.Tensor, b_scales: torch.Tensor, expert_offsets: torch.Tensor, problem_sizes: torch.Tensor, a_strides: torch.Tensor, b_strides: torch.Tensor, c_strides: torch.Tensor, per_act_token: bool, per_out_ch: bool, ): """ A single grouped matrix multiplication used in CUTLASS-based fused MoE. The function executes fp8-quantized OUT = AB matrix multiplication. - expert_offsets: Indices that mark at which token index each expert begins its computation. The number of tokens computed with expert E is expert_offsets[E + 1] - expert_offsets[E] - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. - a/b/c_strides: The data strides passed to grouped matrix multiplication. """ return torch.ops._C.cutlass_moe_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, per_out_ch, ) def cutlass_fp4_moe_mm( out_tensors: torch.Tensor, a_tensors: torch.Tensor, b_tensors: torch.Tensor, a_scales: torch.Tensor, b_scales: torch.Tensor, alphas: torch.Tensor, problem_sizes: torch.Tensor, expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, ): """ An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs the gemms for each combination based on the specified problem sizes. This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized input and expert weights. - a_/b_scales: The blockscales in FP8-E4M3 precision - expert_offsets/sf_offsets: Indices that mark at which token index each expert begins its computation. 
The number of tokens computed with expert E is
      expert_offsets[E + 1] - expert_offsets[E], and the sf_size per expert
      is sf_offset[E+1] - sf_offset[E].
    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
      MMs used in the fused MoE operation.
    """
    return torch.ops._C.cutlass_fp4_group_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        alphas,
        problem_sizes,
        expert_offsets,
        sf_offsets,
    )


# gptq_marlin
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """Repack a GPTQ-quantized weight via `torch.ops._C.gptq_marlin_repack`."""
    return torch.ops._C.gptq_marlin_repack(
        b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit
    )


if hasattr(torch.ops._C, "gptq_marlin_repack"):

    @register_fake("_C::gptq_marlin_repack")
    def _gptq_marlin_repack_fake(
        b_q_weight: torch.Tensor,
        perm: torch.Tensor,
        size_k: torch.SymInt,
        size_n: torch.SymInt,
        num_bits: int,
        is_a_8bit: bool = False,
    ) -> torch.Tensor:
        # Fake (shape-only) implementation: size_k // 16 rows, each holding
        # size_n * 16 values divided by pack_factor = 32 // num_bits.
        pack_factor = 32 // num_bits
        marlin_tile_size = 16
        return torch.empty(
            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
            dtype=b_q_weight.dtype,
            device=b_q_weight.device,
        )


# awq_marlin
def awq_marlin_repack(
    b_q_weight: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """Repack an AWQ-quantized weight via `torch.ops._C.awq_marlin_repack`."""
    return torch.ops._C.awq_marlin_repack(
        b_q_weight, size_k, size_n, num_bits, is_a_8bit
    )


if hasattr(torch.ops._C, "awq_marlin_repack"):

    @register_fake("_C::awq_marlin_repack")
    def _awq_marlin_repack_fake(
        b_q_weight: torch.Tensor,
        size_k: torch.SymInt,
        size_n: torch.SymInt,
        num_bits: int,
        is_a_8bit: bool = False,
    ) -> torch.Tensor:
        # Same output shape rule as the GPTQ fake above.
        pack_factor = 32 // num_bits
        marlin_tile_size = 16
        return torch.empty(
            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
            dtype=b_q_weight.dtype,
            device=b_q_weight.device,
        )


def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """Repack each expert's GPTQ weight with `_C.gptq_marlin_repack`.

    `b_q_weight` and `perm` are indexed per expert along dim 0; the per-expert
    results are stacked into a single preallocated tensor.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    # Per-expert shape (size_k // 16, size_n * (num_bits // 2)); for
    # num_bits in (4, 8) this equals the fake's size_n * 16 // (32 // num_bits).
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = torch.ops._C.gptq_marlin_repack(
            b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit
        )
    return output


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """Repack each expert's AWQ weight with `_C.awq_marlin_repack`.

    `perm` is accepted but not used by this function.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = torch.ops._C.awq_marlin_repack(
            b_q_weight[e], size_k, size_n, num_bits, is_a_8bit
        )
    return output


def marlin_int4_fp8_preprocess(
    qweight: torch.Tensor,
    qzeros_or_none: torch.Tensor | None = None,
    inplace: bool = False,
):
    return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)


def marlin_gemm(
    a: torch.Tensor,
    c: torch.Tensor | None,
    b_q_weight: torch.Tensor,
    b_bias: torch.Tensor | None,
    b_scales: torch.Tensor,
    a_scales: torch.Tensor | None,
    global_scale: torch.Tensor | None,
    b_zeros: torch.Tensor | None,
    g_idx: torch.Tensor | None,
    perm: torch.Tensor | None,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool = True,
    use_atomic_add: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    # The op receives the ScalarType's integer id, not the object itself.
    return torch.ops._C.marlin_gemm(
        a,
        c,
        b_q_weight,
        b_bias,
        b_scales,
        a_scales,
        global_scale,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
    )


if hasattr(torch.ops._C, "marlin_gemm"):

    @register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(
        a: torch.Tensor,
        c: torch.Tensor | None,
        b_q_weight: torch.Tensor,
        b_bias: torch.Tensor | None,
        b_scales: torch.Tensor,
        a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None,
        b_zeros: torch.Tensor | None,
        g_idx: torch.Tensor | None,
        perm: torch.Tensor | None,
        workspace: torch.Tensor,
        b_q_type_id: int,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool = True,
        use_atomic_add: bool = False,
        use_fp32_reduce: bool = False,
        is_zp_float: bool = False,
    ) -> torch.Tensor:
        # Output dtype follows `a` when it is fp16/bf16; otherwise it falls
        # back to the dtype of `b_scales`.
        dtype = a.dtype
        if dtype not in [torch.half, torch.bfloat16]:
            dtype = b_scales.dtype
        return torch.empty((size_m, size_n), device=a.device, dtype=dtype)


# machete
def machete_supported_schedules(
    a_type: torch.dtype,
    b_type: ScalarType,
    group_scales_type: torch.dtype | None,
    group_zeros_type: torch.dtype | None = None,
    channel_scales_type: torch.dtype | None = None,
    token_scales_type: torch.dtype | None = None,
    out_type: torch.dtype | None = None,
) -> list[str]:
    """Return the schedule names reported by `_C.machete_supported_schedules`."""
    return torch.ops._C.machete_supported_schedules(
        a_type,
        b_type.id,
        group_scales_type,
        group_zeros_type,
        channel_scales_type,
        token_scales_type,
        out_type,
    )


def machete_mm(
    a: torch.Tensor,
    # b_q Should be the tensor returned by machete_prepack_B
    b_q: torch.Tensor,
    b_type: ScalarType,
    out_type: torch.dtype | None = None,
    b_group_scales: torch.Tensor | None = None,
    b_group_zeros: torch.Tensor | None = None,
    b_group_size: int | None = None,
    b_channel_scales: torch.Tensor | None = None,
    a_token_scales: torch.Tensor | None = None,
    schedule: str | None = None,
) -> torch.Tensor:
    """Run the Machete mixed-precision GEMM via `torch.ops._C.machete_mm`."""
    return torch.ops._C.machete_mm(
        a,
        b_q,
        b_type.id,
        out_type,
        b_group_scales,
        b_group_zeros,
        b_group_size,
        b_channel_scales,
        a_token_scales,
        schedule,
    )


if hasattr(torch.ops._C, "machete_mm"):

    @register_fake("_C::machete_mm")
    def machete_mm_fake(
        a: torch.Tensor,
        # b_q Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        out_type: torch.dtype | None = None,
        b_group_scales: torch.Tensor | None = None,
        b_group_zeros: torch.Tensor | None = None,
        b_group_size: int | None = None,
        b_channel_scales: torch.Tensor | None = None,
        a_token_scales: torch.Tensor
| None = None,
        schedule: str | None = None,
    ) -> torch.Tensor:
        # NOTE(review): the output dtype always follows `a`, even when
        # `out_type` is given -- confirm this matches the real kernel.
        m = a.size(0)
        n = b_q.size(1)
        return torch.empty((m, n), device=a.device, dtype=a.dtype)


def machete_prepack_B(
    b_q_weight: torch.Tensor,
    a_type: torch.dtype,
    b_type: ScalarType,
    group_scales_type: torch.dtype | None,
) -> torch.Tensor:
    """Prepack a quantized weight for `machete_mm` via `_C.machete_prepack_B`."""
    return torch.ops._C.machete_prepack_B(
        b_q_weight, a_type, b_type.id, group_scales_type
    )


if hasattr(torch.ops._C, "machete_prepack_B"):

    @register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(
        b_q_weight: torch.Tensor,
        a_type: torch.dtype,
        b_type: ScalarType,
        group_scales_type: torch.dtype | None,
    ) -> torch.Tensor:
        # Fake: same shape/dtype as the input weight, contiguous.
        return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format)


# CUTLASS W4A8
def cutlass_w4a8_mm(
    a: torch.Tensor,
    # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
    b_q: torch.Tensor,
    b_group_scales: torch.Tensor,
    b_group_size: int,
    b_channel_scales: torch.Tensor,
    a_token_scales: torch.Tensor,
    out_type: torch.dtype | None = None,
    maybe_schedule: str | None = None,
) -> torch.Tensor:
    """Run the CUTLASS W4A8 GEMM via `torch.ops._C.cutlass_w4a8_mm`."""
    return torch.ops._C.cutlass_w4a8_mm(
        a,
        b_q,
        b_group_scales,
        b_group_size,
        b_channel_scales,
        a_token_scales,
        out_type,
        maybe_schedule,
    )


if hasattr(torch.ops._C, "cutlass_w4a8_mm"):

    @register_fake("_C::cutlass_w4a8_mm")
    def cutlass_w4a8_mm_fake(
        a: torch.Tensor,
        # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
        b_q: torch.Tensor,
        b_group_scales: torch.Tensor,
        b_group_size: int,
        b_channel_scales: torch.Tensor,
        a_token_scales: torch.Tensor,
        out_type: torch.dtype | None = None,
        maybe_schedule: str | None = None,
    ) -> torch.Tensor:
        m = a.size(0)
        n = b_q.size(1)
        # Output dtype defaults to bf16 when `out_type` is not given.
        out_dtype = out_type if out_type is not None else torch.bfloat16
        return torch.empty((m, n), device=a.device, dtype=out_dtype)


def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor:
    return torch.ops._C.cutlass_pack_scale_fp8(scales)


if hasattr(torch.ops._C, "cutlass_pack_scale_fp8"):

    @register_fake("_C::cutlass_pack_scale_fp8")
    def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(scales, memory_format=torch.contiguous_format)


def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
    return torch.ops._C.cutlass_encode_and_reorder_int4b(b)


if hasattr(torch.ops._C, "cutlass_encode_and_reorder_int4b"):

    @register_fake("_C::cutlass_encode_and_reorder_int4b")
    def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(b, memory_format=torch.contiguous_format)


def cutlass_w4a8_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    b_group_scales: torch.Tensor,
    b_group_size: int,
    expert_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    c_strides: torch.Tensor,
    group_scale_strides: torch.Tensor,
    maybe_schedule: str | None = None,
):
    """
    Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
    W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
    and both per-channel + per-token scaling in the epilogue.

    Args:
        out_tensors:
            Output buffer for all experts (updated in-place).
        a_tensors:
            FP8 (E4M3FN) activations for all experts.
        b_tensors:
            INT4-packed weight matrix for all experts, packed to INT32
        a_scales:
            Per-token FP8 activation scales, applied in the epilogue.
        b_scales:
            Per-channel FP8 weight scales for each expert, applied in the epilogue.
        b_group_scales:
            FP8 scale values for group-wise INT4 weight blocks.
        b_group_size:
            Number of elements grouped under each entry of b_group_scales.
        expert_offsets:
            Cumulative token offsets
        problem_sizes:
            Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
        a/b/c/group_scale_strides:
            Strides describing the memory layout of the input tensors.
        maybe_schedule:
            Optional override to choose a specific kernel or epilogue schedule.
Returns:
        out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result.
    """
    return torch.ops._C.cutlass_w4a8_moe_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        b_group_scales,
        b_group_size,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        group_scale_strides,
        maybe_schedule,
    )


def cutlass_encode_and_reorder_int4b_grouped(
    b_tensors: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors)


if hasattr(torch.ops._C, "cutlass_encode_and_reorder_int4b_grouped"):

    @register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
    def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor:
        # NOTE(review): the wrapper above is annotated as returning a 2-tuple,
        # but this fake returns a single tensor -- confirm against the op schema.
        return torch.empty_like(b, memory_format=torch.contiguous_format)


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Permute the columns of `a` according to `perm` via `_C.permute_cols`."""
    return torch.ops._C.permute_cols(a, perm)


if hasattr(torch.ops._C, "permute_cols"):

    @register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(a)


# fp4
def scaled_fp4_quant(
    input: torch.Tensor,
    input_global_scale: torch.Tensor,
    is_sf_swizzled_layout: bool = True,
    backend: str = "none",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and, when `is_sf_swizzled_layout` is True, is stored in a swizzled layout
    (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4
        input_global_scale: A scalar scaling factor for the entire tensor.
is_sf_swizzled_layout: If True, the scaling factors are padded to a
            128x4 tile and packed into int32 (swizzled layout); otherwise
            they are stored as a plain (m, n // 16) array.
        backend: Backend name. When it contains "trtllm" and the flattened
            input has at most 32 rows, FlashInfer's 8x4 scale-factor layout
            is used instead of the kernel below.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
            two values are packed into a uint8, and float8_e4m3 scaling factors
            in the layout selected above.
    """
    assert not current_platform.is_rocm()
    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
    # Flatten all leading dims so the input is 2D (m, n); a 1-D input becomes
    # a single row.
    other_dims = 1 if input.ndim == 1 else -1
    input = input.reshape(other_dims, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
    assert input.dtype in (torch.float16, torch.bfloat16), (
        f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
    )

    use_8x4_sf_layout = True if "trtllm" in backend and m <= 32 else False  # noqa: SIM210

    if use_8x4_sf_layout:
        output, output_scale = flashinfer_quant_nvfp4_8x4_sf_layout(
            input, input_global_scale
        )
    else:
        # Two fp4 values will be packed into an uint8.
        output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
        if is_sf_swizzled_layout:
            # We use the rounded values to store the swizzled values. Due to the
            # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
            # So, we first pad the scales to multiples of 128 and 4. Then, the scales
            # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
            round_up = lambda x, y: (x + y - 1) // y * y
            rounded_m = round_up(m, 128)
            scale_n = n // block_size
            rounded_n = round_up(scale_n, 4)
            output_scale = torch.empty(
                (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
            )
        else:
            output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)

        torch.ops._C.scaled_fp4_quant(
            output, input, output_scale, input_global_scale, is_sf_swizzled_layout
        )

    # Reinterpret the raw scale storage as fp8 (e4m3) values.
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale


def scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to NVFP4 and return quantized tensor and scale, for
    packed MoE Inputs.
    Args:
        input_tensor: The input tensor to be quantized to NVFP4
        input_global_scale: A scalar scaling factor for the entire tensor.
        expert_offsets: The expert offsets tensor
        blockscale_offsets: The blockscale offsets tensor
        topk: Number of top-k experts selected; together with
            VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE it bounds the accepted row
            count and sizes the scale buffer.
    Outputs:
        output: The quantized tensor in NVFP4
        output_scales: The blockscale tensor in FP8-E4M3
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2, (
        f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
    )

    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k = input_tensor.shape

    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
)
    # One fp8 scale per 16 input values; four fp8 scales are packed per int32,
    # so each scale row holds ceil((k // 16) / 4) int32 entries.
    scales_k = k // 16
    padded_k = (scales_k + (4 - 1)) // 4

    # output is uint8 and packed fp4 values
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    # The scale buffer is sized for the maximum supported token count rather
    # than for m_numtopk.
    output_scales = torch.empty(
        MAX_TOKENS_PER_EXPERT * topk,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )
    torch.ops._C.scaled_fp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales


def silu_and_mul_scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused SiLU+Mul+NVFP4 quantization for MoE intermediate activations.

    Args:
        input_tensor: The input tensor with gate || up layout [m_topk, k*2]
        input_global_scale: A per-expert scaling factor [n_experts]
        expert_offsets: The expert offsets tensor [n_experts+1]
        blockscale_offsets: The blockscale offsets tensor [n_experts+1]
        topk: Number of top-k experts selected
    Outputs:
        output: The quantized tensor in NVFP4 [m_topk, k/2]
        output_scales: The blockscale tensor in FP8-E4M3
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2, (
        f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
    )

    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k_times_2 = input_tensor.shape
    assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)"
    # k is the width after SiLU(gate) * up, i.e. half the input width.
    k = k_times_2 // 2

    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )
    # Same scale packing as scaled_fp4_experts_quant: 4 fp8 scales per int32.
    scales_k = k // 16
    padded_k = (scales_k + (4 - 1)) // 4

    # output is uint8 and packed fp4 values
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    output_scales = torch.empty(
        MAX_TOKENS_PER_EXPERT * topk,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )
    torch.ops._C.silu_and_mul_scaled_fp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: torch.Tensor | None = None,
    num_token_padding: int | None = None,
    scale_ub: torch.Tensor | None = None,
    use_per_token_if_dynamic: bool = False,
    output: torch.Tensor | None = None,
    group_shape: tuple[int, int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
        scale: Optional scaling factor for the FP8 quantization. Supports:
            - 0D or [1]: per-tensor scaling
            - 1D: requires explicit group_shape to disambiguate per-channel
              vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
            - 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
              DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
group_shape: Optional tuple (group_m, group_n) specifying the group
            shape for static quantization. Use -1 for "full extent"
            (e.g., (-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
            Required for 1D scales; optional for 2D scales.
        output: Optional preallocated output tensor. It must have the
            platform fp8 dtype and cannot be combined with num_token_padding.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert input.ndim == 2
    shape: tuple[int, int] | torch.Size = input.shape
    # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
    out_dtype: torch.dtype = current_platform.fp8_dtype()
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    if output is None:
        output = torch.empty(shape, device=input.device, dtype=out_dtype)
    else:
        assert num_token_padding is None, "padding not supported if output passed in"
        assert output.dtype == out_dtype

    if scale is None:
        # Dynamic quantization: the kernel computes the scale(s).
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub
            )
        else:
            scale = torch.empty(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static quantization with caller-provided scale(s).
        torch.ops._C.static_scaled_fp8_quant(output, input, scale, group_shape)

    return output, scale


# gptq allspark
def allspark_repack_weight(
    qweight: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor | None = None,
    has_zp: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
    for Ampere W8A16 Fused Gemm kernel

    Args:
        qweight: uint8 weight tensor, original k x n format.
        scale: fp16/bf16 weight scale tensor, 1 x n format.
        zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
            Must be provided for asymmetric quantization.
        has_zp: if use symmetric quantization, has_zp = False.
            if use asymmetric quantization, has_zp = True.
Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : rearranged
            weight, scale, and optionally zero_point.
    """
    K = qweight.shape[0]
    N = qweight.shape[1]
    # Round N up to the next multiple of 32.
    N_32align = (N + 32 - 1) // 32 * 32

    # Destination buffers: weight is (N_32align, K); scale / zero point are
    # (1, N_32align).
    qweight_reorder = torch.empty(
        (N_32align, K), device=qweight.device, dtype=qweight.dtype
    )
    scale_reorder = torch.empty((1, N_32align), device=scale.device, dtype=scale.dtype)
    zero_point_reorder = None
    if has_zp:
        assert zero_point is not None, (
            "zero_point must be provided for asymmetric quantization."
        )
        zero_point_reorder = torch.empty(
            (1, N_32align), device=zero_point.device, dtype=zero_point.dtype
        )

    torch.ops._C.rearrange_kn_weight_as_n32k16_order(
        qweight,
        scale,
        zero_point,
        has_zp,
        qweight_reorder,
        scale_reorder,
        zero_point_reorder,
        K,
        N,
        N_32align,
    )

    return qweight_reorder, scale_reorder, zero_point_reorder


def allspark_w8a16_gemm(
    a: torch.Tensor,
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
    b_qzeros: torch.Tensor | None,
    n: int,
    group_size: int,
    sm_count: int,
    sm_version: int,
    CUBLAS_M_THRESHOLD: int,
    has_zp: bool,
    n32k16_reorder: bool,
) -> torch.Tensor:
    """Run the AllSpark W8A16 GEMM via `torch.ops._C.allspark_w8a16_gemm`."""
    return torch.ops._C.allspark_w8a16_gemm(
        a,
        b_qweight,
        b_scales,
        b_qzeros,
        n,
        group_size,
        sm_count,
        sm_version,
        CUBLAS_M_THRESHOLD,
        has_zp,
        n32k16_reorder,
    )


# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: torch.Tensor | None = None,
    azp: torch.Tensor | None = None,
    symmetric: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == (azp is None), ( "azp must only be provided for asymmetric quantization." ) torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. input_scales = torch.empty( (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 ) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) # torch.ops._C.dynamic_scaled_int8_quant( # output, input.contiguous(), input_scales, input_azp # ) ops.dynamic_scaled_int8_quant(output, input, input_scales) return output, input_scales, input_azp # gguf def ggml_dequantize( W: torch.Tensor, quant_type: int, m: int, n: int, dtype: torch.dtype | None ) -> torch.Tensor: return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype) def ggml_mul_mat_vec_a8( W: torch.Tensor, X: torch.Tensor, quant_type: int, row: int, ) -> torch.Tensor: return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) def ggml_mul_mat_a8( W: torch.Tensor, X: torch.Tensor, quant_type: int, row: int, ) -> torch.Tensor: return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) def ggml_moe_a8( X: torch.Tensor, W: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, quant_type: int, row: int, top_k: int, tokens: int, ) -> torch.Tensor: return torch.ops._C.ggml_moe_a8( X, W, sorted_token_ids, expert_ids, num_tokens_post_padded, quant_type, row, top_k, tokens, ) def ggml_moe_a8_vec( X: torch.Tensor, W: torch.Tensor, topk_ids: torch.Tensor, top_k: int, quant_type: int, row: torch.SymInt, tokens: torch.SymInt, ) -> torch.Tensor: return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, tokens) def ggml_moe_get_block_size(quant_type: int) -> 
int: return torch.ops._C.ggml_moe_get_block_size(quant_type) # mamba def selective_scan_fwd( u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D_: torch.Tensor | None, z_: torch.Tensor | None, delta_bias_: torch.Tensor | None, delta_softplus: bool, query_start_loc: torch.Tensor | None, cache_indices: torch.Tensor | None, has_initial_state: torch.Tensor | None, ssm_states: torch.Tensor, pad_slot_id: int, block_size: int = 1024, block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None, ): torch.ops._C.selective_scan_fwd( u, delta, A, B, C, D_, z_, delta_bias_, delta_softplus, query_start_loc, cache_indices, has_initial_state, ssm_states, pad_slot_id, block_size, block_idx_first_scheduled_token, block_idx_last_scheduled_token, initial_state_idx, ) # ROCm skinny gemms def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor: return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) def wvSplitK( a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None ) -> torch.Tensor: return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) def wvSplitKrc( a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None ) -> torch.Tensor: return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count) def wvSplitKQ( a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, cu_count: int, bias: torch.Tensor = None, ) -> torch.Tensor: out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) return out # moe def moe_sum(input: torch.Tensor, output: torch.Tensor): torch.ops._moe_C.moe_sum(input, output) def moe_align_block_size( topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: 
torch.Tensor,
    expert_map: torch.Tensor | None = None,
) -> None:
    # NOTE(review): `expert_map` is accepted for API compatibility but is not
    # forwarded to the ixformer kernel below, unlike the commented-out
    # upstream call -- confirm no caller relies on it on this platform.
    # torch.ops._moe_C.moe_align_block_size(
    #     topk_ids,
    #     num_experts,
    #     block_size,
    #     sorted_token_ids,
    #     experts_ids,
    #     num_tokens_post_pad,
    #     expert_map,
    # )
    ops.vllm_moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )


def batched_moe_align_block_size(
    max_tokens_per_batch: int,
    block_size: int,
    expert_num_tokens: torch.Tensor,
    sorted_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    torch.ops._moe_C.batched_moe_align_block_size(
        max_tokens_per_batch,
        block_size,
        expert_num_tokens,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )


def moe_lora_align_block_size(
    topk_ids: torch.Tensor,
    token_lora_mapping: torch.Tensor,
    num_experts: int,
    block_size: int,
    max_loras: int,
    max_num_tokens_padded: int,
    max_num_m_blocks: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
    adapter_enabled: torch.Tensor,
    lora_ids: torch.Tensor,
    expert_map: torch.Tensor | None = None,
) -> None:
    torch.ops._moe_C.moe_lora_align_block_size(
        topk_ids,
        token_lora_mapping,
        num_experts,
        block_size,
        max_loras,
        max_num_tokens_padded,
        max_num_m_blocks,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        adapter_enabled,
        lora_ids,
        expert_map,
    )


def moe_wna16_gemm(
    input: torch.Tensor,
    output: torch.Tensor,
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
    b_qzeros: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
    top_k: int,
    BLOCK_SIZE_M: int,
    BLOCK_SIZE_N: int,
    BLOCK_SIZE_K: int,
    bit: int,
) -> None:
    # The op call below is not returned, so this wrapper returns None.
    if not current_platform.is_cuda():
        raise NotImplementedError(
            "The optimized moe_wna16_gemm kernel is only available on CUDA platforms"
        )
    torch.ops._moe_C.moe_wna16_gemm(
        input,
        output,
        b_qweight,
        b_scales,
        b_qzeros,
        topk_weights,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        top_k,
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, ) def dsv3_router_gemm( hidden_states: torch.Tensor, router_weight: torch.Tensor, output_dtype: torch.dtype, ) -> torch.Tensor: output = torch.empty( hidden_states.shape[0], router_weight.shape[0], device=hidden_states.device, dtype=output_dtype, ) torch.ops._moe_C.dsv3_router_gemm(output, hidden_states, router_weight) return output def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool = False, e_score_correction_bias: torch.Tensor | None = None, ) -> None: # torch.ops._moe_C.topk_softmax( # topk_weights, # topk_ids, # token_expert_indices, # gating_output, # renormalize, # e_score_correction_bias, # ) ops.vllm_moe_topk_softmax( topk_weights, topk_ids, token_expert_indices, gating_output ) if renormalize: topk_weights[:] = topk_weights / topk_weights.sum(dim=-1, keepdim=True) def topk_sigmoid_torch( topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool = False, e_score_correction_bias: torch.Tensor | None = None, ): batch, num_experts = gating_output.shape k = topk_weights.shape[1] # Sigmoid + bias probs = torch.sigmoid(gating_output) if e_score_correction_bias is not None: probs = probs + e_score_correction_bias.unsqueeze(0) # Top-K topk_vals, topk_idx = torch.topk(probs, k=k, dim=1) # 写入结果 topk_weights[:] = topk_vals topk_ids[:] = topk_idx token_expert_indices[:] = (torch.arange(k, device=gating_output.device).unsqueeze(0) * batch + torch.arange(batch, device=gating_output.device).unsqueeze(1)) # renormalize if renormalize: denom = topk_weights.sum(dim=1, keepdim=True) denom = torch.where(denom > 0, denom, torch.ones_like(denom)) topk_weights[:] = topk_weights / denom def topk_sigmoid( topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool = False, 
e_score_correction_bias: torch.Tensor | None = None,
) -> None:
    """Sigmoid top-k routing. Delegates to the PyTorch implementation
    `topk_sigmoid_torch`, which writes the three output tensors in place."""
    # torch.ops._moe_C.topk_sigmoid(
    #     topk_weights,
    #     topk_ids,
    #     token_expert_indices,
    #     gating_output,
    #     renormalize,
    #     e_score_correction_bias,
    # )
    topk_sigmoid_torch(
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )


def grouped_topk(
    scores: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """
    Perform grouped top-k routing for mixture of experts.

    Args:
        scores: Raw inputs (logits if scoring_func=1, scores if scoring_func=0)
        num_expert_group: Number of expert groups
        topk_group: Number of groups to select
        topk: Number of experts to select per token
        renormalize: Whether to renormalize the output weights
        routed_scaling_factor: Scaling factor for routing weights
        bias: Bias tensor (e_score_correction_bias). Always fused in kernel.
        scoring_func: 0=none (no activation), 1=sigmoid

    Raises:
        NotImplementedError: If the current platform is not CUDA.
    """
    if not current_platform.is_cuda():
        raise NotImplementedError(
            "The fused grouped_topk kernel is only available on CUDA platforms"
        )
    return torch.ops._moe_C.grouped_topk(
        scores,
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        routed_scaling_factor,
        bias,
        scoring_func,
    )


def moe_wna16_marlin_gemm(
    input: torch.Tensor,
    output: torch.Tensor | None,
    b_qweight: torch.Tensor,
    b_bias: torch.Tensor | None,
    b_scales: torch.Tensor,
    a_scales: torch.Tensor | None,
    global_scale: torch.Tensor | None,
    b_qzeros: torch.Tensor | None,
    g_idx: torch.Tensor | None,
    perm: torch.Tensor | None,
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_past_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
    thread_k: int = -1,
    thread_n: int = -1,
    blocks_per_sm: int =
-1,
) -> torch.Tensor:
    # The op receives the ScalarType's integer id (b_q_type.id).
    return torch.ops._moe_C.moe_wna16_marlin_gemm(
        input,
        output,
        b_qweight,
        b_bias,
        b_scales,
        a_scales,
        global_scale,
        b_qzeros,
        g_idx,
        perm,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_past_padded,
        topk_weights,
        moe_block_size,
        top_k,
        mul_topk_weights,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
        thread_k,
        thread_n,
        blocks_per_sm,
    )


if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

    @register_fake("_moe_C::marlin_gemm_moe")
    def marlin_gemm_moe_fake(
        a: torch.Tensor,
        b_q_weights: torch.Tensor,
        sorted_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        b_scales: torch.Tensor,
        b_zero_points: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        num_experts: int,
        topk: int,
        moe_block_size: int,
        replicate_input: bool,
        apply_weights: bool,
    ) -> torch.Tensor:
        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)

    # NOTE(review): this fake's signature omits the thread_k / thread_n /
    # blocks_per_sm arguments that `moe_wna16_marlin_gemm` passes to the op --
    # confirm it matches the registered op schema.
    @register_fake("_moe_C::moe_wna16_marlin_gemm")
    def moe_wna16_marlin_gemm_fake(
        input: torch.Tensor,
        output: torch.Tensor | None,
        b_qweight: torch.Tensor,
        b_bias: torch.Tensor | None,
        b_scales: torch.Tensor,
        a_scales: torch.Tensor | None,
        global_scale: torch.Tensor | None,
        b_qzeros: torch.Tensor | None,
        g_idx: torch.Tensor | None,
        perm: torch.Tensor | None,
        workspace: torch.Tensor,
        sorted_token_ids: torch.Tensor,
        expert_ids: torch.Tensor,
        num_tokens_past_padded: torch.Tensor,
        topk_weights: torch.Tensor,
        moe_block_size: int,
        top_k: int,
        mul_topk_weights: bool,
        b_q_type: ScalarType,
        size_m: int,
        size_n: int,
        size_k: int,
        is_k_full: bool,
        use_atomic_add: bool,
        use_fp32_reduce: bool,
        is_zp_float: bool,
    ):
        # (size_m * top_k, size_n): presumably one row per (token, expert) pair.
        return torch.empty(
            (size_m * top_k, size_n), dtype=input.dtype, device=input.device
        )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache:
torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> None: torch.ops._C_cache_ops.reshape_and_cache( key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, v_scale, ) def reshape_and_cache_flash( key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> None: # torch.ops._C_cache_ops.reshape_and_cache_flash( # key, # value, # key_cache, # value_cache, # slot_mapping, # kv_cache_dtype, # k_scale, # v_scale, # ) ops.reshape_and_cache_flash( key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, 1.0, 1.0, ) def concat_and_cache_mla( kv_c: torch.Tensor, k_pe: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, scale: torch.Tensor, ) -> None: # torch.ops._C_cache_ops.concat_and_cache_mla( # kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale # ) ops.vllm_concat_and_cache_mla( kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale ) def concat_and_cache_mla_int8( kv_c_int8: torch.Tensor, kv_c_scale: torch.Tensor, k_pe_int8: torch.Tensor, k_pe_scale: torch.Tensor, kv_cache: torch.Tensor, kv_cache_scale: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, scale: torch.Tensor, ) -> None: ops.vllm_concat_and_cache_mla_int8( kv_c_int8, kv_c_scale, k_pe_int8, k_pe_scale, kv_cache, kv_cache_scale, slot_mapping, kv_cache_dtype, scale, ) def concat_and_cache_mla_rope_fused( positions: torch.Tensor, q_pe: torch.Tensor, k_pe: torch.Tensor, kv_c: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, slot_mapping: torch.Tensor, kv_cache: torch.Tensor, kv_cache_dtype: str, kv_cache_scale: torch.Tensor, ) -> None: torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused( positions, q_pe, k_pe, kv_c, cos_sin_cache, is_neox, slot_mapping, kv_cache, kv_cache_dtype, kv_cache_scale, ) def swap_blocks( src: 
torch.Tensor, dst: torch.Tensor, block_size_in_bytes: int, block_mapping: torch.Tensor, ) -> None: """ Copy specific blocks from one tensor to another. This method assumes each of the two input tensors is composed of consecutive contiguous blocks, of size block_size_in_bytes. i.e. the memory layout for each tensor is: [block0] [block1] ... [block N] block_mapping determines the subset of blocks to copy of the source tensor, and their matching destination block number on the destination tensor. block_mapping is expected to be a tensor of shape (num_blocks_to_copy, 2) where each block_mapping[i] represents a single copy operation, copying block #block_mapping[i][0] from the source tensor to block #block_mapping[i][1] on the destination tensor. block_mapping should have dtype int64. The source and the destination tensors can be either on cpu or gpu, but not both on cpu. the block mapping tensor must on cpu. """ # torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping) ops.vllm_swap_blocks(src, dst, block_mapping) def convert_fp8( output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" ) -> None: torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) def gather_and_maybe_dequant_cache( src_cache: torch.Tensor, dst: torch.Tensor, block_table: torch.Tensor, cu_seq_lens: torch.Tensor, token_to_seq: torch.Tensor, num_tokens: int, kv_cache_dtype: str, scale: torch.Tensor, seq_starts: torch.Tensor | None = None, ) -> None: torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( src_cache, dst, block_table, cu_seq_lens, token_to_seq, num_tokens, kv_cache_dtype, scale, seq_starts, ) def cp_gather_cache( src_cache: torch.Tensor, dst: torch.Tensor, block_table: torch.Tensor, cu_seq_lens: torch.Tensor, batch_size: int, seq_starts: torch.Tensor | None = None, ) -> None: # torch.ops._C_cache_ops.cp_gather_cache( # src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts # ) ops.vllm_cp_gather_cache( src_cache, 
dst, block_table, cu_seq_lens, batch_size, seq_starts ) def cp_gather_and_upconvert_fp8_kv_cache( src_cache: torch.Tensor, dst: torch.Tensor, block_table: torch.Tensor, seq_lens: torch.Tensor, workspace_starts: torch.Tensor, batch_size: int, ) -> None: """Gather and upconvert FP8 KV cache to BF16 workspace. Args: src_cache: FP8 KV cache [num_blocks, block_size, 656] dst: BF16 output workspace [total_tokens, 576] block_table: Block indices [num_reqs, max_blocks] seq_lens: Sequence lengths [num_reqs] workspace_starts: Workspace start offsets [num_reqs] batch_size: Number of requests """ torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache( src_cache, dst, block_table, seq_lens, workspace_starts, batch_size ) def indexer_k_quant_and_cache( k: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, quant_block_size: int, kv_cache_dtype: str, ) -> None: torch.ops._C_cache_ops.indexer_k_quant_and_cache( k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype ) def cp_gather_indexer_k_quant_cache( kv_cache: torch.Tensor, dst_k: torch.Tensor, dst_scale: torch.Tensor, block_table: torch.Tensor, cu_seq_lens: torch.Tensor, ) -> None: torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache( kv_cache, dst_k, dst_scale, block_table, cu_seq_lens ) def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) def get_max_shared_memory_per_block_device_attribute(device: int) -> int: # ruff: noqa: E501 # return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( # device # ) return 32 * 1024 # custom ar def init_custom_ar( ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor, rank: int, fully_connected: bool, ) -> int: return torch.ops._C_custom_ar.init_custom_ar( ipc_tensors, rank_data, rank, fully_connected ) def all_reduce( fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, reg_buffer_sz_bytes: int, ) -> None: torch.ops._C_custom_ar.all_reduce(fa, inp, 
out, reg_buffer, reg_buffer_sz_bytes) def dispose(fa: int) -> None: torch.ops._C_custom_ar.dispose(fa) def meta_size() -> int: return torch.ops._C_custom_ar.meta_size() def register_buffer(fa: int, ipc_tensors: list[int]) -> None: return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]: return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) def register_graph_buffers( fa: int, handles: list[list[int]], offsets: list[list[int]] ) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]: return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size) def open_mem_handle(mem_handle: torch.Tensor): return torch.ops._C_custom_ar.open_mem_handle(mem_handle) def free_shared_buffer(ptr: int) -> None: torch.ops._C_custom_ar.free_shared_buffer(ptr) # quick all reduce def init_custom_qr(rank: int, world_size: int, qr_max_size: int | None = None) -> int: return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size) def qr_destroy(fa: int) -> None: torch.ops._C_custom_ar.qr_destroy(fa) def qr_all_reduce( fa: int, inp: torch.Tensor, out: torch.Tensor, quant_level: int, cast_bf2half: bool = False, ) -> None: torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half) def qr_get_handle(fa: int) -> torch.Tensor: return torch.ops._C_custom_ar.qr_get_handle(fa) def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: return torch.ops._C_custom_ar.qr_open_handles(fa, handles) def qr_max_size() -> int: return torch.ops._C_custom_ar.qr_max_size() def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, num_heads_k: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: cache_seqlens: (batch_size), dtype torch.int32. num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. num_heads_k: num_heads_k. 
Return: tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ return torch.ops._C.get_flash_mla_metadata( cache_seqlens, num_heads_per_head_k, num_heads_k ) def flash_mla_with_kvcache( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor, head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, softmax_scale: float | None = None, causal: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: q: (batch_size, seq_len_q, num_heads_q, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). block_table: (batch_size, max_num_blocks_per_seq), torch.int32. cache_seqlens: (batch_size), torch.int32. head_dim_v: Head_dim of v. tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 
""" if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale, causal, tile_scheduler_metadata, num_splits, ) return out, softmax_lse def sm100_cutlass_mla_decode( out: torch.Tensor, lse: torch.Tensor, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, scale: float, num_kv_splits: int, ) -> torch.Tensor: torch.ops._C.sm100_cutlass_mla_decode( out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, scale, num_kv_splits, ) return out def sm100_cutlass_mla_get_workspace_size( max_seq_len: int, num_batches: int, sm_count: int, num_kv_splits: int ) -> int: return torch.ops._C.sm100_cutlass_mla_get_workspace_size( max_seq_len, num_batches, sm_count, num_kv_splits ) def dsv3_fused_a_gemm( output: torch.Tensor, mat_a: torch.Tensor, mat_b: torch.Tensor, ) -> None: """DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). Computes output = mat_a @ mat_b.T where: mat_a: [num_tokens, 7168] row-major bf16 (hidden states) mat_b: [7168, 2112] column-major bf16 (weight transposed) output: [num_tokens, 2112] row-major bf16 Optimized for the DeepSeek V2/V3 QKV A-projection at small batch sizes. Requires SM 9.0+ (Hopper). 
""" torch.ops._C.dsv3_fused_a_gemm(output, mat_a, mat_b) if hasattr(torch.ops._C, "weight_packed_linear"): @register_fake("_C::weight_packed_linear") def weight_packed_linear_fake( mat1: torch.Tensor, mat2: torch.Tensor, bias: torch.Tensor | None, is_vnni: bool, ) -> torch.Tensor: return torch.empty( (mat1.size(0), mat2.size(0)), dtype=mat1.dtype, device=mat2.device ) if hasattr(torch.ops._C, "fused_experts_cpu"): @register_fake("_C::fused_experts_cpu") def fused_experts_cpu_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool, use_int8_w8a8: bool, use_fp8_w8a16: bool, w1_scale: torch.Tensor | None, w2_scale: torch.Tensor | None, block_size: list[int] | None, a1_scale: torch.Tensor | None, a2_scale: torch.Tensor | None, is_vnni: bool, ) -> torch.Tensor: return torch.empty_like(hidden_states) if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): @register_fake("_C::int8_scaled_mm_with_quant") def int8_scaled_mm_with_quant_fake( mat1: torch.Tensor, mat2: torch.Tensor, scales2: torch.Tensor, bias: torch.Tensor | None, out_dtype: torch.dtype, is_vnni: bool, ) -> torch.Tensor: M = mat1.size(0) N = mat2.size(0) return torch.empty((M, N), dtype=out_dtype) # Add our new features here.. 
# moe def invoke_fused_moe_kernel( A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A_scale: torch.Tensor | None, B_scale: torch.Tensor | None, topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, config: dict[str, "any"], compute_type, use_fp8_w8a8: bool, use_int8_w8a16: bool, block_shape: list[int] | None = None, bias: torch.Tensor | None = None, ) -> None: ops.vllm_invoke_fused_moe_kernel( A, B, C, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, mul_routed_weight, top_k, config["BLOCK_SIZE_M"], bias=bias ) # broadcast class Async_helper: # For now, the comm and the other kernels are in the same stream, so we can remove the stream wait.. def wait( self, ): return True def broadcast(tensor, src=0, group=None, async_op=False): cdist.broadcast(tensor, src, group, async_op=True) if async_op: return Async_helper() else: pass # w8a16 def linear_w8a16( x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, group_size: int = -1, format: str = "TN", ) -> torch.Tensor: return ops.w8a16(x, qweight, scales, format="TN", group_size=group_size) ## lora sgmv / bgmv def sbgmv_expand( x: torch.Tensor, w_t_all: torch.Tensor, y: torch.Tensor, b_seq_start_loc: torch.Tensor = None, seq_len_tensor: torch.Tensor = None, lora_indices_tensor: torch.Tensor = None, batches: int = -1, max_seq_length: int = -1, token_nums: int = -1, add_input=True, ): """ x: inputs w_t_all: lora weight y: output y += x@wt_t_all """ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32] assert w_t_all.dtype in [ torch.float16, torch.bfloat16, ] assert x.is_contiguous() # assert y.is_contiguous() if x.dtype == torch.float: x = x.to(w_t_all.dtype) if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) assert w_t_all.size(1) == 1 w_t_all = w_t_all.squeeze(dim=1) else: assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) assert 
w_t_all.is_contiguous() assert add_input == True lora_indices = lora_indices_tensor.cpu().tolist() lora_num = w_t_all.shape[0] ## 单一lora model, 且所有request均使用lora if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): if lora_indices[0] != -1: w_t = w_t_all[0] y += torch.matmul(x, w_t.t()) ## 多个lora model else: ## prefill if batches != -1: for i, lora_id, start, seq_len in zip( range(batches), lora_indices, b_seq_start_loc, seq_len_tensor ): if lora_id != -1: xi = x[start : start + seq_len] w_t = w_t_all[lora_id] y[start : start + seq_len] += xi @ w_t.t() ## decode else: batches = x.shape[0] for i, lora_id in zip(range(batches), lora_indices): if lora_id != -1: xi = x[i].unsqueeze(0) w_t = w_t_all[lora_id] y[i] += (xi @ w_t.t()).squeeze(0) return y def sbgmv_shrink( x: torch.Tensor, w_t_all: torch.Tensor, y: torch.Tensor, b_seq_start_loc: torch.Tensor = None, seq_len_tensor: torch.Tensor = None, lora_indices_tensor: torch.Tensor = None, batches: int = -1, max_seq_length: int = -1, token_nums: int = -1, scale: float = 1.0, ): """ xx: inputs w_t_all: lora weight y: output scale: float y = x@w_t_all * scale """ assert x.dtype == w_t_all.dtype assert x.dtype in [torch.float16, torch.bfloat16] assert x.is_contiguous() assert y.is_contiguous() if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) assert w_t_all.size(1) == 1 w_t_all = w_t_all.squeeze(dim=1) else: assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) assert w_t_all.is_contiguous() lora_num = w_t_all.shape[0] lora_indices = lora_indices_tensor.cpu().tolist() ## 单一lora model, 且所有request均使用lora if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): if lora_indices[0] != -1: w_t = w_t_all[0] y = torch.matmul(x, w_t.t()) * scale ## 多个lora model else: ## prefill if batches != -1: for i, lora_id, start, seq_len in zip( range(batches), lora_indices, b_seq_start_loc, seq_len_tensor ): if lora_id != -1: xi = x[start : start + seq_len] w_t = w_t_all[lora_id] y[start : start + seq_len] = 
(xi @ w_t.t()) * scale ## decode else: batches = x.shape[0] for i, lora_id in zip(range(batches), lora_indices): if lora_id != -1: xi = x[i].unsqueeze(0) w_t = w_t_all[lora_id] y[i] = (xi @ w_t.t()).squeeze(0) * scale return y def dynamic_scaled_quant_dynamic_int8(x, input_scales=None, int8_out=None, scales=None): return ops.dynamic_scaled_quant_smoothquant(x, input_scales, int8_out, scales) class CPUDNNLGEMMHandler: def __init__(self) -> None: self.handler_tensor: torch.Tensor | None = None self.n = -1 self.k = -1 def __del__(self): if self.handler_tensor is not None: torch.ops._C.release_dnnl_matmul_handler(self.handler_tensor.item()) _supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler")) def is_onednn_acl_supported(): return torch.ops._C.is_onednn_acl_supported() def create_onednn_mm( weight: torch.Tensor, # [K, N] primitive_cache_size: int = 128, ) -> CPUDNNLGEMMHandler: handler = CPUDNNLGEMMHandler() handler.k, handler.n = weight.size() # store the handler pointer in a tensor it doesn't get inlined handler.handler_tensor = torch.tensor( torch.ops._C.create_onednn_mm_handler(weight, primitive_cache_size), dtype=torch.int64, ) return handler def onednn_mm( dnnl_handler: CPUDNNLGEMMHandler, x: torch.Tensor, bias: torch.Tensor | None, ) -> torch.Tensor: output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) torch.ops._C.onednn_mm( output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor ) return output def create_onednn_scaled_mm( weight: torch.Tensor, # [K, N] weight_scales: torch.Tensor, output_type: torch.dtype, dynamic_quant: bool, use_azp: bool, primitive_cache_size: int = 128, ) -> CPUDNNLGEMMHandler: handler = CPUDNNLGEMMHandler() handler.k, handler.n = weight.size() # store the handler pointer in a tensor so it doesn't get inlined handler.handler_tensor = torch.tensor( torch.ops._C.create_onednn_scaled_mm_handler( weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size, ), 
dtype=torch.int64, ) return handler def onednn_scaled_int8_quant( input: torch.Tensor, scale: torch.Tensor | None = None, azp: torch.Tensor | None = None, symmetric: bool = True, ): """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. Args: input: The input tensor to be quantized to int8. scale: Optional scaling factor for the int8 quantization. When not provided, we invoke dynamic-per-token quantization. azp: Optional zero-point for the int8 quantization. Must be provided for asymmetric quantization if `scale` is provided. symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) token_num = input.numel() // input.shape[-1] input = input.view((token_num, input.shape[-1])) if scale is not None: # static-per-tensor quantization. assert symmetric == (azp is None), ( "azp must only be provided for asymmetric quantization." ) torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
input_scales = torch.empty((token_num, 1), device=input.device, dtype=torch.float32) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) return output, input_scales, input_azp def onednn_scaled_mm( dnnl_handler: CPUDNNLGEMMHandler, x: torch.Tensor, output: torch.Tensor, input_scale: torch.Tensor | None, input_zp: torch.Tensor | None, input_zp_adj: torch.Tensor | None, bias: torch.Tensor | None, ) -> torch.Tensor: torch.ops._C.onednn_scaled_mm( output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler_tensor, ) return output def cpu_attn_get_scheduler_metadata( num_reqs: int, num_heads: int, num_kv_heads: int, head_dim: int, seq_lens: torch.Tensor, dtype: torch.dtype, query_start_loc: torch.Tensor, causal: bool, sliding_window_size: int, isa: str, enable_kv_split: bool, ) -> torch.Tensor: sheduler_metadata = torch.ops._C.get_scheduler_metadata( num_reqs, num_heads, num_kv_heads, head_dim, seq_lens, dtype, query_start_loc, causal, sliding_window_size, isa, enable_kv_split, ) return sheduler_metadata def cpu_attn_reshape_and_cache( key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, slot_mapping: torch.Tensor, isa: str, ) -> None: torch.ops._C.cpu_attn_reshape_and_cache( key, value, key_cache, value_cache, slot_mapping, isa, ) def cpu_attention_with_kv_cache( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, output: torch.Tensor, query_start_loc: torch.Tensor, seq_lens: torch.Tensor, scale: float, causal: bool, alibi_slopes: torch.Tensor | None, sliding_window: tuple[int, int], block_table: torch.Tensor, softcap: float, scheduler_metadata: torch.Tensor, s_aux: torch.Tensor | None, ) -> None: torch.ops._C.cpu_attention_with_kv_cache( query, key_cache, value_cache, output, query_start_loc, seq_lens, scale, causal, alibi_slopes, sliding_window[0], sliding_window[1], block_table, 
softcap, scheduler_metadata, s_aux, ) def cpu_gemm_wna16( input: torch.Tensor, q_weight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor | None, g_idx: torch.Tensor | None, bias: torch.Tensor | None, pack_factor: int, isa_hint: str, ) -> torch.Tensor: output = torch.empty((input.size(0), scales.size(1)), dtype=input.dtype) torch.ops._C.cpu_gemm_wna16( input, q_weight, output, scales, zeros, g_idx, bias, pack_factor, isa_hint, ) return output def cpu_prepack_moe_weight( weight: torch.Tensor, isa: str, ) -> torch.Tensor: output = torch.empty_like(weight) torch.ops._C.prepack_moe_weight(weight, output, isa) return output def cpu_fused_moe( input: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor, w13_bias: torch.Tensor | None, w2_bias: torch.Tensor | None, topk_weights: torch.Tensor, topk_ids: torch.Tensor, act: str, isa: str, skip_weighted: bool = False, ) -> torch.Tensor: output = torch.empty_like(input) torch.ops._C.cpu_fused_moe( output, input, w13, w2, w13_bias, w2_bias, topk_weights, topk_ids, skip_weighted, act, isa, ) return output if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): @register_fake("_qutlass_C::matmul_mxf4_bf16_tn") def _fake_matmul_mxf4_bf16_tn( a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, b_sf: torch.Tensor, alpha: torch.Tensor, ): return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16) def matmul_mxf4_bf16_tn( a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, b_sf: torch.Tensor, alpha: torch.Tensor, ) -> torch.Tensor: return torch.ops._qutlass_C.matmul_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha) if hasattr(torch.ops._qutlass_C, "matmul_ada_mxf4_bf16_tn"): @register_fake("_qutlass_C::matmul_ada_mxf4_bf16_tn") def _fake_matmul_ada_mxf4_bf16_tn( a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, b_sf: torch.Tensor, alpha: torch.Tensor, ): return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16) def matmul_ada_mxf4_bf16_tn( a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, b_sf: torch.Tensor, 
alpha: torch.Tensor, ) -> torch.Tensor: return torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha) if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxQuest"): @register_fake("_qutlass_C::fusedQuantizeMxQuest") def _fake_fused_quantize_mx_quest( a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor ): return xh_e2m1, xh_e8m0 if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxAbsMax"): @register_fake("_qutlass_C::fusedQuantizeMxAbsMax") def _fake_fused_quantize_mx_absmax( a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor ): return xh_e2m1, xh_e8m0 def fusedQuantizeMx( a: torch.Tensor, b: torch.Tensor, *, method: Literal["quest", "abs_max"] = "quest" ) -> tuple[torch.Tensor, torch.Tensor]: if a.dim() == 0: raise ValueError("`a` must have at least 1 dimension.") if a.size(-1) % 32 != 0: raise ValueError(f"last dim of `a` must be divisible by 32, got {a.size(-1)}.") if b.device != a.device: raise ValueError("`a` and `b` must be on the same device.") xh_e2m1 = torch.empty( *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device ) rows, cols = a.numel() // a.size(-1), a.size(-1) // 32 n_row_blocks = cdiv(rows, 128) n_col_blocks = cdiv(cols, 4) padded_rows = n_row_blocks * 128 padded_cols = n_col_blocks * 4 xh_e8m0 = torch.empty( padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device ) if not hasattr(torch.ops, "_qutlass_C"): raise RuntimeError( "The `_qutlass_C` extension is not loaded. " "Make sure your custom op library is imported before calling fusedQuantizeMx." 
) if method == "quest": return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0) elif method == "abs_max": return torch.ops._qutlass_C.fusedQuantizeMxAbsMax(a, b, xh_e2m1, xh_e8m0) else: raise ValueError(f"invalid method {method!r}, must be 'quest' or 'abs_max'") if hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"): @register_fake("_qutlass_C::fusedQuantizeNv") def _fake_fused_quantize_nv( a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e4m3: torch.Tensor, global_scale: torch.Tensor, ): return xh_e2m1, xh_e4m3 def fusedQuantizeNv( a: torch.Tensor, b: torch.Tensor, global_scale: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: xh_e2m1 = torch.empty( *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device ) rows, cols = a.numel() // a.size(-1), a.size(-1) // 16 n_row_blocks = cdiv(rows, 128) n_col_blocks = cdiv(cols, 4) padded_rows = n_row_blocks * 128 padded_cols = n_col_blocks * 4 xh_e4m3 = torch.empty( padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=a.device ) return torch.ops._qutlass_C.fusedQuantizeNv(a, b, xh_e2m1, xh_e4m3, global_scale) def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor: """ Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832) kernels. Note that these kernels exploit the recursive properties of Sylvester Hadamards, and therefore do not require transform weight data Note that sylvester hadamard transforms are also symmetric, which means that this function is also applies the (transpose <=> inverse) transform. 
:param x: value to be transformed inplace :param inplace: modify value in place :return: value after transformation """ return torch.ops._C.hadacore_transform(x, inplace) if hasattr(torch.ops._C, "hadacore_transform"): @register_fake("_C::hadacore_transform") def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor: return torch.empty_like(x) if not inplace else x def gather_cache( src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...] block_table: torch.Tensor, # [BATCH, BLOCK_INDICES] cu_seq_lens: torch.Tensor, # [BATCH+1] batch_size: int, seq_starts: torch.Tensor = None, ): ops.vllm_gather_cache( src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts ) def gather_cache_int8( src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] src_cache_scale: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, 2] kv_lora_rank: int, dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...] block_table: torch.Tensor, # [BATCH, BLOCK_INDICES] cu_seq_lens: torch.Tensor, # [BATCH+1] batch_size: int, seq_starts: torch.Tensor = None, ): ops.vllm_gather_cache_int8( src_cache, src_cache_scale, kv_lora_rank, dst, block_table, cu_seq_lens, batch_size, seq_starts, ) def rejection_greedy_sample_torch( output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1] cu_num_draft_tokens: torch.Tensor, # [batch_size] (前缀和形式) draft_token_ids: torch.Tensor, # [num_tokens] target_argmax: torch.Tensor, # [num_tokens] bonus_token_ids: torch.Tensor, # [batch_size] is_greedy: torch.Tensor = None, # [batch_size] 或 None ): """ 完全等价于 rejection_greedy_sample_kernel 的 PyTorch 实现 接口参数与 Triton 核完全一致 """ batch_size = output_token_ids.size(0) device = output_token_ids.device # 处理 is_greedy 为 None 的情况(保持与 Triton 核相同行为) if is_greedy is None: is_greedy_mask = torch.ones(batch_size, dtype=torch.bool, device=device) else: is_greedy_mask = is_greedy.to(device) for req_idx in range(batch_size): if not is_greedy_mask[req_idx]: continue # 
非贪婪请求直接跳过 # 计算当前请求的token范围(前缀和转实际数量) start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1] end_idx = cu_num_draft_tokens[req_idx] num_draft_tokens = end_idx - start_idx rejected = False for pos in range(num_draft_tokens): if not rejected: global_pos = start_idx + pos draft_token = draft_token_ids[global_pos] target_token = target_argmax[global_pos] # 存储目标token(与Triton核完全一致的行为) output_token_ids[req_idx, pos] = target_token # 检查是否拒绝 if draft_token != target_token: rejected = True # 全部接受时追加bonus token if not rejected and num_draft_tokens < output_token_ids.size(1): output_token_ids[req_idx, num_draft_tokens] = bonus_token_ids[req_idx] return output_token_ids # 原位修改


def rejection_random_sample_torch(
    output_token_ids: torch.Tensor,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size], inclusive prefix sum
    draft_token_ids: torch.Tensor,  # [num_tokens]
    draft_probs: torch.Tensor | None,  # [num_tokens, vocab_size] or None
    target_probs: torch.Tensor,  # [num_tokens, vocab_size]
    bonus_token_ids: torch.Tensor,  # [batch_size]
    recovered_token_ids: torch.Tensor,  # [num_tokens]
    uniform_probs: torch.Tensor,  # [num_tokens], uniform samples
    is_greedy: torch.Tensor | None,  # [batch_size] or None
    NO_DRAFT_PROBS: bool = False,  # treat every draft probability as 1.0
):
    """PyTorch reference for random (non-greedy) rejection sampling.

    For each non-greedy request, the request's draft tokens are visited in
    order.  Draft token ``t`` at flat position ``p`` is accepted when
    ``draft_prob > 0`` and ``target_probs[p, t] / draft_prob >=
    uniform_probs[p]``, where ``draft_prob`` is ``draft_probs[p, t]`` (or 1.0
    when ``NO_DRAFT_PROBS``).  The first rejected position is filled with
    ``recovered_token_ids[p]`` and later positions of that request are not
    written.  If every draft token is accepted and the row has room, the
    request's bonus token is written right after the drafts.

    Requests with a true ``is_greedy`` flag are skipped; ``is_greedy=None``
    means no request is greedy.

    Returns:
        ``output_token_ids``, which is also modified in place.

    Raises:
        AssertionError: a draft probability is needed, ``NO_DRAFT_PROBS`` is
            False and ``draft_probs`` is None.
    """
    batch_size = output_token_ids.size(0)
    max_spec_len_plus_1 = output_token_ids.size(1)

    # Pull the small per-request metadata to the host once instead of
    # indexing device tensors element by element inside the loop.
    cu_ends = cu_num_draft_tokens.tolist()
    if is_greedy is None:
        greedy_flags = [False] * batch_size
    else:
        greedy_flags = is_greedy.tolist()

    for req_idx in range(batch_size):
        if greedy_flags[req_idx]:
            continue  # greedy requests are not touched here

        # Convert the inclusive prefix sum into this request's token range.
        start_idx = 0 if req_idx == 0 else cu_ends[req_idx - 1]
        num_draft_tokens = cu_ends[req_idx] - start_idx

        rejected = False
        for pos in range(num_draft_tokens):
            global_pos = start_idx + pos
            draft_token_id = draft_token_ids[global_pos]

            if NO_DRAFT_PROBS:
                draft_prob = 1.0
            else:
                assert (
                    draft_probs is not None
                ), "draft_probs不能为None当NO_DRAFT_PROBS=False"
                draft_prob = draft_probs[global_pos, draft_token_id]

            target_prob = target_probs[global_pos, draft_token_id]
            uniform_prob = uniform_probs[global_pos]

            # A zero draft probability is rejected rather than divided by.
            if draft_prob > 0 and (target_prob / draft_prob) >= uniform_prob:
                output_token_ids[req_idx, pos] = draft_token_id
            else:
                output_token_ids[req_idx, pos] = recovered_token_ids[global_pos]
                rejected = True
                break  # positions after the first rejection are left as-is

        # All drafts accepted: append the bonus token if the row has room.
        if not rejected and num_draft_tokens < max_spec_len_plus_1:
            output_token_ids[req_idx, num_draft_tokens] = bonus_token_ids[req_idx]

    return output_token_ids


weak_ref_tensor = ops.weak_ref_tensor


def indexer_k_cache(
    k: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor
) -> None:
    """Scatter per-token keys into a paged cache, in place.

    Args:
        k: ``[num_tokens, head_dim]`` keys to store.
        kv_cache: ``[num_blocks, block_size, head_dim]`` cache, written in
            place at ``[slot // block_size, slot % block_size]``.
        slot_mapping: flat cache slot for each token; only the first
            ``num_tokens`` entries are read.

    Negative slot ids never address a real cache location (as Python indices
    they would silently wrap to the last block), so those tokens are skipped.
    NOTE(review): negative slots presumably mark padded tokens; confirm
    against the caller's padding convention.
    """
    num_tokens, head_dim = k.shape
    _, block_size, cache_stride = kv_cache.shape
    assert head_dim == cache_stride

    slots = slot_mapping[:num_tokens].to(device=kv_cache.device, dtype=torch.long)
    valid = slots >= 0
    slots = slots[valid]
    block_idx = torch.div(slots, block_size, rounding_mode="floor")
    block_offset = slots % block_size
    # Indexed assignment requires matching dtypes, so cast explicitly
    # (the original per-token copy cast implicitly).
    kv_cache[block_idx, block_offset] = k[valid].to(kv_cache.dtype)


def ref_mqa_logits(
    q: torch.Tensor,  # [num_tokens, n_head, head_dim]
    k: torch.Tensor,  # [num_kv_tokens, head_dim] (2-D; flattened keys)
    weights: torch.Tensor,  # [num_tokens, n_head] or [num_tokens, n_head, 1]
    cu_seqlen_ks: torch.Tensor,  # [num_tokens], first visible key per query
    cu_seqlen_ke: torch.Tensor,  # [num_tokens], end (exclusive) per query
) -> torch.Tensor:
    """PyTorch reference for multi-query indexer logits.

    Query ``i`` sees keys ``k[cu_seqlen_ks[i]:cu_seqlen_ke[i]]``.  For each
    visible key ``j`` its logit is ``sum_h weights[i, h] * relu(q[i, h] @
    k[j])``.  The dot products run in the input dtype; the weighting and head
    reduction run in float32.  Every other column stays ``-inf``, as does
    every column of a query whose range is empty.

    Returns:
        float32 tensor ``[num_tokens, num_kv_tokens]``.
    """
    num_queries, num_heads, _ = q.shape
    logits = torch.full(
        (num_queries, k.shape[0]), -float("inf"), device=q.device, dtype=torch.float32
    )
    starts = cu_seqlen_ks.tolist()
    ends = cu_seqlen_ke.tolist()

    for i in range(num_queries):
        seq_start, seq_end = starts[i], ends[i]
        if seq_start >= seq_end:
            continue
        # [H, D] @ [D, seq_len] -> [H, seq_len], in the input precision.
        scores = F.relu(torch.matmul(q[i], k[seq_start:seq_end].T))
        # reshape (not unsqueeze) so both [H] and [H, 1] weight rows give [H, 1].
        head_weights = weights[i].reshape(num_heads, 1)
        logits[i, seq_start:seq_end] = (scores.float() * head_weights).sum(dim=0)

    return logits


def _paged_reassemble_k(
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    context_len: int,
    head_dim: int,
    block_size: int,
) -> torch.Tensor:
    """Gather the first ``context_len`` keys of one sequence from a paged cache.

    ``kv_cache`` is indexed as ``[num_blocks, block_size, 1, >=head_dim]``.
    Blocks are read in ``block_table`` order, stopping at the first negative
    block id.  Rows that are never filled stay zero.

    Returns:
        ``[context_len, head_dim]`` tensor in the cache dtype.
    """
    num_blocks_needed = (context_len + block_size - 1) // block_size
    k_sequence = torch.zeros(
        context_len, head_dim, device=kv_cache.device, dtype=kv_cache.dtype
    )
    token_offset = 0
    for block_id in block_table[:num_blocks_needed].tolist():
        if block_id < 0:
            break
        n_tokens = min(block_size, context_len - token_offset)
        k_sequence[token_offset:token_offset + n_tokens] = kv_cache[
            block_id, :n_tokens, 0, :head_dim
        ]
        token_offset += n_tokens
    return k_sequence


def _paged_mqa_logits_one_seq(
    q: torch.Tensor,  # [next_n, H, D]
    k: torch.Tensor,  # [context_len, D]
    weights: torch.Tensor,  # [next_n, H]
    context_len: int,
    max_model_len: int,
) -> torch.Tensor:
    """Return ``[next_n, max_model_len]`` float32 logits for one sequence.

    Column ``j < min(context_len, max_model_len)`` holds
    ``sum_h weights[:, h] * relu(q[:, h] @ k[j])``; all other columns are -inf.
    """
    next_n, num_heads, _ = q.shape
    out = torch.full(
        (next_n, max_model_len), -float("inf"), device=q.device, dtype=torch.float32
    )
    # Clamp so a context longer than the output row cannot cause a shape error.
    visible = min(context_len, max_model_len)
    if visible <= 0:
        return out

    k_heads = k[:visible].unsqueeze(0).expand(num_heads, -1, -1)  # [H, visible, D]
    # [H, next_n, D] @ [H, D, visible] -> [H, next_n, visible]
    scores = F.relu(torch.bmm(q.transpose(0, 1), k_heads.transpose(1, 2)))
    head_weights = weights.transpose(0, 1).unsqueeze(2)  # [H, next_n, 1]
    out[:, :visible] = (scores * head_weights).sum(dim=0)
    return out


def ref_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
    clean_logits: bool = True,
) -> torch.Tensor:
    """PyTorch reference for multi-query indexer logits over a paged KV cache.

    Args:
        q: queries ``[B, next_n, H, D]``.
        kv_cache: paged keys ``[num_blocks, block_size, 1, D]``.
        weights: per-head weights ``[B * next_n, H]`` (float32 upstream).
        context_lens: ``[B]`` context length of each sequence.
        block_tables: ``[B, max_blocks]`` block ids; a negative id ends a table.
        max_model_len: width of the output logits.
        clean_logits: when True, force every column at or past a sequence's
            context length to -inf.

    Returns:
        float32 logits ``[B * next_n, max_model_len]``.  Rows of sequence
        ``b`` are ``b * next_n`` to ``(b + 1) * next_n - 1``.  Columns past the
        context length, and all columns of empty sequences, are -inf.

    NOTE(review): all ``next_n`` query rows of a sequence see the same
    ``context_len`` keys; no per-row causal offset is applied.  Fused paged-MQA
    kernels may mask later draft positions when ``next_n > 1``; confirm this
    reference matches the kernel it is compared against.
    """
    batch_size, next_n, _, head_dim = q.shape
    block_size = kv_cache.shape[1]
    logits = torch.full(
        (batch_size * next_n, max_model_len),
        -float("inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens_list = context_lens.tolist()

    for batch_idx in range(batch_size):
        context_len = context_lens_list[batch_idx]
        if context_len <= 0:
            continue
        row_start = batch_idx * next_n
        row_end = row_start + next_n
        batch_k = _paged_reassemble_k(
            kv_cache, block_tables[batch_idx], context_len, head_dim, block_size
        )
        logits[row_start:row_end] = _paged_mqa_logits_one_seq(
            q[batch_idx], batch_k, weights[row_start:row_end], context_len, max_model_len
        )

    if clean_logits:
        # Mask every column >= the row's context length.  This is already true
        # after the fill above, but is enforced explicitly as part of the
        # output contract.
        positions = torch.arange(max_model_len, device=logits.device)
        row_lens = context_lens[:batch_size].to(logits.device).repeat_interleave(next_n)
        logits.masked_fill_(positions[None, :] >= row_lens[:, None], -float("inf"))

    return logits


def _sparse_masked_attention(
    query: torch.Tensor,  # [s, h, d_qk]
    key: torch.Tensor,  # [kv_len, h, d_qk]
    value: torch.Tensor,  # [kv_len, h, d_v]
    sm_scale: float,
) -> torch.Tensor:
    """Dense softmax attention of ``query`` over ``key``/``value``.

    The query is scaled and its scores computed in the query dtype.  Softmax
    and the value reduction run in float32.  The result ``[s, h, d_v]`` is
    cast back to the query dtype.
    """
    query = query * sm_scale
    attn = torch.einsum("qhd,khd->hqk", query, key).float().softmax(dim=-1)
    out = torch.einsum("hqk,khd->qhd", attn, value.float())
    return out.to(query.dtype)


def sparse_prefill_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
):
    """PyTorch reference for sparse-attention prefill.

    Each query position ``i`` attends only to the kv rows listed in
    ``indices[i, 0]``.  Entries that are negative or ``>= s_kv`` are ignored.
    A query with no valid entry produces zeros.  Keys are the full ``d_qk``
    channels of ``kv``; values are its leading ``d_v`` channels.

    Args:
        q: ``[s_q, h_q, d_qk]`` queries (bfloat16 upstream).
        kv: ``[s_kv, h_kv, d_qk]`` shared keys/values; ``h_kv`` must be 1.
        indices: ``[s_q, h_kv, topk]`` int kv-row indices.
        sm_scale: softmax scale applied to the queries.
        d_v: value dimension (documented upstream as 512).

    Returns:
        ``[s_q, h_q, d_v]`` output in ``q``'s dtype.  Only the output is
        returned; no max-logits or LSE tensors are produced.

    Raises:
        ValueError: ``kv`` has more than one head.
    """
    s_q, h_q, _ = q.shape
    s_kv, h_kv, _ = kv.shape
    if h_kv != 1:
        # Keys are broadcast to every query head, which is only valid for one kv head.
        raise ValueError(f"sparse_prefill_fwd expects h_kv == 1, got h_kv={h_kv}")

    values = kv[:, :, :d_v]  # [s_kv, 1, d_v]
    output = torch.zeros(s_q, h_q, d_v, device=q.device, dtype=q.dtype)

    for i in range(s_q):
        row_idx = indices[i, 0]  # [topk]
        row_idx = row_idx[(row_idx >= 0) & (row_idx < s_kv)]
        # Broadcast the single kv head to every query head without copying.
        keys_i = kv[row_idx].expand(-1, h_q, -1)  # [valid_len, h_q, d_qk]
        values_i = values[row_idx].expand(-1, h_q, -1)  # [valid_len, h_q, d_v]
        out = _sparse_masked_attention(q[i].unsqueeze(0), keys_i, values_i, sm_scale)
        output[i].copy_(out[0], non_blocking=True)

    return output