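"""vLLM-compatible operator shims backed by ixformer kernels.

The ``ops``, ``cache_ops``, ``custom_ar`` and ``cuda_utils`` classes below
mirror the interfaces of the corresponding vLLM C extension modules and
dispatch to ixformer implementations, raising ``NotImplementedError`` where no
kernel is wired up.
"""
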
from typing import Optional

import torch
import torch.nn.functional as F

import ixformer
import ixformer.functions as ixf_F
from ixformer._C import ReduceOp
from ixformer._C import _distributed as cdist
from ixformer._C._distributed import is_initialized, get_default_comm_group
from ixformer.contrib.torch.extension import ixformer_torch as ixft
from ixformer.contrib.torch.data_type_mapping import torch_to_ixformer_dtype


class ops:

    # activations
    @staticmethod
    def silu_and_mul(output, x):
        ixf_F.silu_and_mul(x, output)

    @staticmethod
    def gelu_and_mul(output, x):
        ixf_F.gelu_and_mul(x, output)

    # gelu_new and gelu_fast are both served by PyTorch's tanh approximation
    @staticmethod
    def gelu_new(output, x):
        return F.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_fast(output, x):
        return F.gelu(x, approximate="tanh")

    # rms norm
    @staticmethod
    def rms_norm(output, x, weight, epsilon):
        ixf_F.rms_norm(x, weight, output, epsilon)

    @staticmethod
    def fused_add_rms_norm(input, residual, weight, epsilon, scale):
        ixf_F.fused_add_rms_norm(input, residual, weight, epsilon, scale)

    # rotary embedding
    @staticmethod
    def rotary_embedding(positions, query, key, head_size,
                         cos_sin_cache, is_neox_style):
        ixf_F.vllm_rotary_embedding_neox(positions, query, key, head_size,
                                         cos_sin_cache, is_neox_style)

    # paged attention
    @staticmethod
    def paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes=None,
        kv_cache_dtype=None,  # accepted for interface compatibility; unused here
    ):
        return ixf_F.vllm_single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        output,
        exp_sums,
        max_logits,
        tmp_output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes=None,
        kv_cache_dtype=None,  # accepted for interface compatibility; unused here
        use_sqrt_alibi=False,
    ):
        return ixf_F.vllm_single_query_cached_kv_attention_v2(
            output,
            256,  # hard-coded constant; presumably the split-KV partition size
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            use_sqrt_alibi,
        )

    # awq
    @staticmethod
    def awq_gemm(x, qweight, scales, qzeros, pack_factor):
        return ixf_F.quantized_linear(x, qweight, scales, "awq",
                                      32 // pack_factor, qzeros, None,
                                      group_size=128)

    @staticmethod
    def awq_dequantize(qweight, scales, qzeros, holder1, holder2, holder3):
        raise NotImplementedError()

    # gptq
    @staticmethod
    def gptq_shuffle(qweights, g_idx, weight_bits):
        return ixf_F.vllm_gptq_shuffle(qweights, g_idx)

    @staticmethod
    def gptq_gemm(x, qweight, qzeros, scales, idx, status, weight_bits):
        batch = x.shape[0]
        if batch <= 8:
            # small batches go through the fused quantized GEMM
            return ixf_F.quantized_linear(x, qweight, scales, "gptq", 4,
                                          qzeros, None, group_size=128)
        # larger batches: dequantize the weight and fall back to a dense matmul
        o_dtype_str = "fp16" if x.dtype == torch.half else "bf16"
        deq_w = ixf_F.quantized_weight_dequant(qweight, scales, "gptq",
                                               o_dtype_str, 4, qzeros,
                                               group_size=128)
        return torch.matmul(x, deq_w)

    # squeezellm
    @staticmethod
    def squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table):
        raise NotImplementedError()

    # marlin
    @staticmethod
    def marlin_gemm(x_2d, qweight, scales, workspace, size_m, size_n, size_k):
        raise NotImplementedError()

    # moe
    @staticmethod
    def moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                             expert_ids, num_tokens_post_pad):
        raise NotImplementedError()

    # smoothquant
    @staticmethod
    def quant(output, input, scale):
        ixf_F.vllm_smooth_quant(output, input, scale)
        return output

    @staticmethod
    def dequant(output, x, scale, global_scale):
        ixf_F.vllm_smooth_dequant(output, x, scale, global_scale)
        return output

    @staticmethod
    def dequant_add_residual(output, x, residual, scale, global_scale):
        if isinstance(x, torch.Tensor):
            ixf_F.vllm_smooth_dequant_add_residual(output, x, residual, scale,
                                                   global_scale)
        return output

    @staticmethod
    def dequant_silu_and_mul_quant(output, x, gate_scale, up_scale, scale,
                                   temp=None):
        ixf_F.vllm_smooth_dequant_silu_and_mul_quant(output, x, gate_scale,
                                                     up_scale, scale, temp)

    @staticmethod
    def rms_norm_quant(output, input, weight, epsilon):
        return ixf_F.vllm_smooth_rms_norm_quant(output, input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm_quant(output, input, residual, weight, epsilon):
        ixf_F.vllm_smooth_fused_add_rms_norm_quant(output, input, residual,
                                                   weight, epsilon)

    @staticmethod
    def dequant_fused_add_rms_norm_quant(output, input, residual, weight,
                                         epsilon, scale, global_scale):
        ixf_F.vllm_smooth_dequant_fused_add_rms_norm_quant(
            output, input, residual, weight, epsilon, scale, global_scale)

    @staticmethod
    def dequant_rotary_embedding(positions, query, key, head_size,
                                 cos_sin_cache, query_out, key_out,
                                 query_scale, key_scale, is_neox_style):
        ixf_F.vllm_smooth_dequant_rotary_embedding_neox(
            positions, query, key, head_size, cos_sin_cache, query_out,
            key_out, query_scale, key_scale, is_neox_style)

    @staticmethod
    def linear_a8_w8_o32_(x, weight, output):
        return ixf_F.linear_i8w8o32(x, weight, output)
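
# Illustrative use of the activation shims (sketch only; tensor names and
# shapes below are hypothetical, and `gate_up` is assumed to hold the gate and
# up projections concatenated along the last dimension):
#
#   gate_up = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
#   out = torch.empty(num_tokens, d, dtype=torch.float16, device="cuda")
#   ops.silu_and_mul(out, gate_up)  # writes silu(gate) * up into `out`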


class cache_ops:

    @staticmethod
    def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping):
        ixf_F.vllm_cache_ops_reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping
        )

    @staticmethod
    def copy_blocks(key_caches, value_caches, block_mapping):
        ixf_F.vllm_copy_cache(
            key_caches, value_caches, block_mapping
        )

    @staticmethod
    def swap_blocks(src_key_cache, dst_key_cache, src_to_dst):
        ixf_F.vllm_swap_blocks(
            src_key_cache, dst_key_cache, src_to_dst
        )
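
# Illustrative KV-cache maintenance calls (sketch only; the cache layouts and
# mapping tensors are assumed to match what the ixformer kernels expect):
#
#   cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
#   cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
#   cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)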


class custom_ar:

    IS_INIT: bool = False

    @staticmethod
    def is_init():
        # reports whether initialization was already requested and latches the
        # flag to True, so only the first caller sees False
        return_status = custom_ar.IS_INIT
        custom_ar.IS_INIT = True
        return return_status

    @staticmethod
    def init_cumtom_ar():
        if not is_initialized(get_default_comm_group()):
            group = ixft.create_ixformer_group_from_pg()
            ixformer.cuda.set_device(torch.cuda.current_device())
            cdist.update_default_comm_group(group)
            cdist.ipc.init_communicator_by_nccl()

    @staticmethod
    def all_reduce_reg(ptr, tensor, out=None):
        raise NotImplementedError()

    @staticmethod
    def all_reduce_unreg(ptr, tensor, buffer, out=None):
        dtype = tensor.dtype
        if torch.is_tensor(tensor):
            dtype = torch_to_ixformer_dtype(dtype)

        if out is None:
            out = tensor
        cdist.ipc.allreduce(
            tensor.data_ptr(), out.data_ptr(), dtype, tensor.numel(), ReduceOp.SUM
        )
        return out

    @staticmethod
    def dispose():
        ixformer.distributed.destroy_process_group()

    @staticmethod
    def should_custom_ar(tensor: torch.Tensor, max_size, world_size, full_nvlink):
        return cdist.ipc.should_custom_ar(tensor.numel(), tensor.element_size(),
                                          max_size, world_size)
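
# Illustrative custom all-reduce flow (sketch only; assumes the torch process
# group has already been created by the caller):
#
#   if not custom_ar.is_init():
#       custom_ar.init_cumtom_ar()
#   reduced = custom_ar.all_reduce_unreg(None, tensor, None)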


class cuda_utils:

    @staticmethod
    def get_max_shared_memory_per_block_device_attribute(gpu):
        # returns a fixed large value instead of querying the device attribute
        return 100000000