from typing import Optional

import torch
import torch.nn.functional as F

import ixformer
import ixformer.functions as ixf_F
from ixformer._C import ReduceOp
from ixformer._C import _distributed as cdist
from ixformer._C._distributed import is_initialized, get_default_comm_group
from ixformer.contrib.torch.extension import ixformer_torch as ixft
from ixformer.contrib.torch.data_type_mapping import torch_to_ixformer_dtype


class ops():
    """vLLM-compatible custom-op shims backed by ixformer kernels.

    Each static method mirrors a vLLM ``_custom_ops`` entry point; most
    delegate directly to the corresponding ``ixformer.functions`` kernel.
    Signatures (names, order, defaults) are kept identical to vLLM's so
    callers need no changes.
    """

    # ---- activations ----

    @staticmethod
    def silu_and_mul(output, x):
        # In-place: writes the SiLU-gated product of the two halves of `x`
        # into `output` (ixformer kernel takes input first, output second).
        ixf_F.silu_and_mul(x, output)

    @staticmethod
    def gelu_and_mul(output, x):
        # In-place GELU-gated multiply, same calling convention as above.
        ixf_F.gelu_and_mul(x, output)

    @staticmethod
    def gelu_new(output, x):
        # NOTE(review): `output` is ignored and the result returned instead;
        # callers of this shim appear to use the return value — confirm.
        # Fix: `approximate` is keyword-only in torch.nn.functional.gelu
        # (schema: gelu(Tensor self, *, str approximate='none')), so the
        # original positional call F.gelu(x, "tanh") raised TypeError.
        return F.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_fast(output, x):
        # Same tanh-approximation (and same keyword-only fix) as gelu_new.
        return F.gelu(x, approximate="tanh")

    # ---- rms norm ----

    @staticmethod
    def rms_norm(output, x, weight, epsilon):
        # In-place RMSNorm of `x` with `weight`, written into `output`.
        ixf_F.rms_norm(x, weight, output, epsilon)

    @staticmethod
    def fused_add_rms_norm(input, residual, weight, epsilon, scale):
        # Fused residual-add + RMSNorm; mutates `input`/`residual` in place.
        ixf_F.fused_add_rms_norm(input, residual, weight, epsilon, scale)

    # ---- rotary embedding ----

    @staticmethod
    def rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                         is_neox_style):
        # Applies rotary position embedding to `query`/`key` in place using
        # the precomputed cos/sin cache.
        ixf_F.vllm_rotary_embedding_neox(positions, query, key, head_size,
                                         cos_sin_cache, is_neox_style)

    # ---- paged attention ----

    @staticmethod
    def paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes=None,
        kv_cache_dtype=None,
    ):
        # Single-query (decode) attention over the paged KV cache.
        # `kv_cache_dtype` is accepted for vLLM interface compatibility but
        # not forwarded to the ixformer kernel.
        return ixf_F.vllm_single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        output,
        exp_sums,
        max_logits,
        tmp_output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes=None,
        kv_cache_dtype=None,
        use_sqrt_alibi=False,
    ):
        # V2 (split/reduce) paged attention. The literal 256 is presumably
        # the kernel's partition size — TODO confirm against the ixformer
        # kernel signature. `kv_cache_dtype` is accepted but unused.
        return ixf_F.vllm_single_query_cached_kv_attention_v2(
            output,
            256,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            use_sqrt_alibi,
        )

    # ---- awq ----

    @staticmethod
    def awq_gemm(x, qweight, scales, qzeros, pack_factor):
        # AWQ quantized GEMM; weight bit-width is derived from the packing
        # factor (32 // pack_factor). group_size=128 is hard-coded here.
        return ixf_F.quantized_linear(x, qweight, scales, "awq",
                                      32 // pack_factor, qzeros, None,
                                      group_size=128)

    @staticmethod
    def awq_dequantize(qweight, scales, qzeros, holder1, holder2, holder3):
        raise NotImplementedError()

    # ---- gptq ----

    @staticmethod
    def gptq_shuffle(qweights, g_idx, weight_bits):
        # `weight_bits` is accepted for interface compatibility but not
        # forwarded (ixformer shuffle only needs the weights and g_idx).
        return ixf_F.vllm_gptq_shuffle(qweights, g_idx)

    @staticmethod
    def gptq_gemm(x, qweight, qzeros, scales, idx, status, weight_bits):
        # Small batches use the fused quantized GEMM; larger batches
        # dequantize the weight once and fall back to a dense matmul,
        # which is presumably faster at that size — TODO confirm threshold.
        batch = x.shape[0]
        if batch <= 8:
            return ixf_F.quantized_linear(x, qweight, scales, "gptq", 4,
                                          qzeros, None, group_size=128)
        o_dtype_str = "fp16" if x.dtype == torch.half else "bf16"
        deq_w = ixf_F.quantized_weight_dequant(qweight, scales, "gptq",
                                               o_dtype_str, 4, qzeros,
                                               group_size=128)
        return torch.matmul(x, deq_w)

    # ---- squeezellm ----

    @staticmethod
    def squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table):
        raise NotImplementedError()

    # ---- marlin ----

    @staticmethod
    def marlin_gemm(x_2d, qweight, scales, workspace, size_m, size_n, size_k):
        raise NotImplementedError()

    # ---- moe ----

    @staticmethod
    def moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                             expert_ids, num_tokens_post_pad):
        raise NotImplementedError()

    # ---- smoothquant ----

    @staticmethod
    def quant(output, input, scale):
        # Quantizes `input` into `output` in place; also returns `output`
        # for convenience.
        ixf_F.vllm_smooth_quant(output, input, scale)
        return output

    @staticmethod
    def dequant(output, x, scale, global_scale):
        ixf_F.vllm_smooth_dequant(output, x, scale, global_scale)
        return output

    @staticmethod
    def dequant_add_residual(output, x, residual, scale, global_scale):
        # NOTE(review): when `x` is not a torch.Tensor the kernel is skipped
        # and `output` is returned unmodified — confirm this silent no-op is
        # intentional for the non-tensor call path.
        if isinstance(x, torch.Tensor):
            ixf_F.vllm_smooth_dequant_add_residual(output, x, residual,
                                                   scale, global_scale)
        return output

    @staticmethod
    def dequant_silu_and_mul_quant(output, x, gate_scale, up_scale, scale,
                                   temp=None):
        # Fused dequant -> SiLU-gated multiply -> requant, in place.
        ixf_F.vllm_smooth_dequant_silu_and_mul_quant(output, x, gate_scale,
                                                     up_scale, scale, temp)

    @staticmethod
    def rms_norm_quant(output, input, weight, epsilon):
        return ixf_F.vllm_smooth_rms_norm_quant(output, input, weight,
                                                epsilon)

    @staticmethod
    def fused_add_rms_norm_quant(output, input, residual, weight, epsilon):
        ixf_F.vllm_smooth_fused_add_rms_norm_quant(output, input, residual,
                                                   weight, epsilon)

    @staticmethod
    def dequant_fused_add_rms_norm_quant(output, input, residual, weight,
                                         epsilon, scale, global_scale):
        ixf_F.vllm_smooth_dequant_fused_add_rms_norm_quant(
            output, input, residual, weight, epsilon, scale, global_scale)

    @staticmethod
    def dequant_rotary_embedding(positions, query, key, head_size,
                                 cos_sin_cache, query_out, key_out,
                                 query_scale, key_scale, is_neox_style):
        # Dequantizing rotary embedding: reads quantized `query`/`key`,
        # writes dequantized+rotated results into `query_out`/`key_out`.
        ixf_F.vllm_smooth_dequant_rotary_embedding_neox(
            positions, query, key, head_size, cos_sin_cache, query_out,
            key_out, query_scale, key_scale, is_neox_style)

    @staticmethod
    def linear_a8_w8_o32_(x, weight, output):
        # int8 activation x int8 weight -> int32 output GEMM.
        return ixf_F.linear_i8w8o32(x, weight, output)


class cache_ops():
    """KV-cache management shims (vLLM ``cache_ops`` interface)."""

    @staticmethod
    def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping):
        # Scatters new key/value tokens into the paged caches at the slots
        # given by `slot_mapping`.
        ixf_F.vllm_cache_ops_reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping
        )

    @staticmethod
    def copy_blocks(key_caches, value_caches, block_mapping):
        # Copies cache blocks according to `block_mapping` (src -> dst).
        ixf_F.vllm_copy_cache(
            key_caches, value_caches, block_mapping
        )

    @staticmethod
    def swap_blocks(src_key_cache, dst_key_cache, src_to_dst):
        # Moves blocks between devices (e.g. GPU <-> CPU swap space).
        ixf_F.vllm_swap_blocks(
            src_key_cache, dst_key_cache, src_to_dst
        )


class custom_ar():
    """Custom all-reduce shims backed by ixformer's IPC communicator."""

    # One-shot latch: flipped to True on the first is_init() call.
    IS_INIT: bool = False

    @staticmethod
    def is_init():
        # NOTE(review): this is a query with a side effect — it returns the
        # previous state and then marks the class as initialized, so only
        # the first caller sees False. Confirm callers rely on this latch.
        return_status = custom_ar.IS_INIT
        custom_ar.IS_INIT = True
        return return_status

    @staticmethod
    def init_cumtom_ar():
        # (Name typo "cumtom" kept — it is part of the public interface.)
        # Idempotent: only builds the ixformer group and NCCL-backed IPC
        # communicator if the default comm group is not yet initialized.
        if not is_initialized(get_default_comm_group()):
            group = ixft.create_ixformer_group_from_pg()
            ixformer.cuda.set_device(torch.cuda.current_device())
            cdist.update_default_comm_group(group)
            cdist.ipc.init_communicator_by_nccl()

    @staticmethod
    def all_reduce_reg(ptr, tensor, out=None):
        raise NotImplementedError()

    @staticmethod
    def all_reduce_unreg(ptr, tensor, buffer, out=None):
        # SUM all-reduce over `tensor`; `ptr` and `buffer` are accepted for
        # vLLM interface compatibility but not used. When `out` is None the
        # reduction is performed in place on `tensor` itself.
        dtype = tensor.dtype
        if torch.is_tensor(tensor):
            dtype = torch_to_ixformer_dtype(dtype)
        if out is None:
            out = tensor
        cdist.ipc.allreduce(
            tensor.data_ptr(), out.data_ptr(), dtype, tensor.numel(),
            ReduceOp.SUM
        )
        return out

    @staticmethod
    def dispose():
        ixformer.distributed.destroy_process_group()

    @staticmethod
    def should_custom_ar(tensor: torch.Tensor, max_size, world_size,
                         full_nvlink):
        # `full_nvlink` is accepted for interface compatibility; the
        # ixformer heuristic decides from size/world_size alone.
        return cdist.ipc.should_custom_ar(tensor.numel(),
                                          tensor.element_size(), max_size,
                                          world_size)


class cuda_utils():
    """Device-attribute shims."""

    @staticmethod
    def get_max_shared_memory_per_block_device_attribute(gpu):
        # NOTE(review): hard-coded stub (100 MB) rather than a real device
        # query; the `gpu` index is ignored. Confirm this upper bound is
        # safe for all targeted devices.
        return 100000000