from typing import List, Optional

import torch
from torch.nn import functional as F


def split_last_two_dims_into_blocks(x, h, w):
    """Split the last two dims (H, W) into non-overlapping (h, w) blocks.

    Returns a tensor of shape leading_dims + (H // h, W // w, h, w).
    """
    leading_dims = x.shape[:-2]
    H, W = x.shape[-2:]
    assert (
        H % h == 0 and W % w == 0
    ), "The last two dimensions must be divisible by the block size."
    x_reshaped = x.view(-1, 1, H, W)
    unfolded = F.unfold(x_reshaped, kernel_size=(h, w), stride=(h, w))
    unfolded = unfolded.view(-1, 1, h, w, H // h, W // w)
    unfolded = unfolded.permute(0, 1, 4, 5, 2, 3)
    final_shape = leading_dims + (H // h, W // w, h, w)
    # reshape, not view: the permute above makes the tensor non-contiguous.
    result = unfolded.reshape(final_shape)
    return result


def merge_blocks_to_original_layout(x, h, w):
    """Inverse of split_last_two_dims_into_blocks.

    Merges leading_dims + (H // h, W // w, h, w) back into leading_dims + (H, W).
    """
    leading_dims = x.shape[:-4]
    H_div_h, W_div_w = x.shape[-4:-2]
    assert x.shape[-2:] == (h, w), "Block size must match the last two dims."
    H = H_div_h * h
    W = W_div_w * w
    x_reshaped = x.view(-1, 1, H_div_h, W_div_w, h, w)
    x_reshaped = x_reshaped.permute(0, 1, 4, 5, 2, 3)
    # reshape, not view: the permute above makes the tensor non-contiguous.
    x_reshaped = x_reshaped.reshape(-1, h * w, H_div_h * W_div_w)
    folded = F.fold(x_reshaped, output_size=(H, W), kernel_size=(h, w), stride=(h, w))
    final_shape = leading_dims + (H, W)
    result = folded.view(final_shape)
    return result


def w8a8_block_fp8_matmul(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    is_linear_weight: bool = False,
    output_opt: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Dequantize a block-wise quantized weight and run the matmul in input.dtype."""
    b0, b1 = block_size
    dim0, dim1 = weight.shape
    # Pad both weight dims up to a multiple of the block size.
    dim0pad, dim1pad = 0, 0
    if dim0 % b0 != 0:
        dim0pad = b0 - dim0 % b0
    if dim1 % b1 != 0:
        dim1pad = b1 - dim1 % b1
    dim0_origin, dim1_origin = dim0, dim1
    dim0 += dim0pad
    dim1 += dim1pad
    bs0, bs1 = dim0 // b0, dim1 // b1
    weight_dequant = F.pad(weight, (0, dim1pad, 0, dim0pad), value=0)
    # Group the padded weight into a (bs0, bs1) grid of blocks, upcast, and
    # apply one scale per block.
    weight_dequant = weight_dequant.cpu().view(bs0, b0, bs1, b1).permute(
        0, 2, 1, 3
    ).reshape(bs0, bs1, -1).float().to(input.device) * weight_scale.unsqueeze(-1)
    weight_dequant = (
        weight_dequant.reshape(bs0, bs1, b0, b1)
        .permute(0, 2, 1, 3)
        .reshape(dim0, dim1)
        .to(input.dtype)
    )
    # Drop the padding again before the matmul.
    weight_dequant = weight_dequant[:dim0_origin, :dim1_origin]
    output = torch.matmul(
        input, weight_dequant.T if is_linear_weight else weight_dequant
    )
    if output_opt is not None:
        output = output_opt.copy_(output)
    return output


def w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    **kwargs,
):
    assert input_scale is None, "w8a8_block_fp8_matmul only supports weight quantization for now"
    return w8a8_block_fp8_matmul(
        input, weight, None, weight_scale, block_size, is_linear_weight=True
    )
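# --- Illustrative sanity checks (not part of the kernel API). ---
# The helper names (_sanity_check_*) and all shapes below are assumptions for
# this sketch. Plain float32 tensors stand in for the quantized weights: the
# dequant path upcasts with .float() regardless of the storage dtype, so the
# block and scale layouts are exercised without requiring FP8 support.
def _sanity_check_block_helpers():
    x = torch.randn(3, 8, 12)
    blocks = split_last_two_dims_into_blocks(x, 4, 4)
    assert blocks.shape == (3, 2, 3, 4, 4)
    # Non-overlapping unfold/fold is an exact round trip.
    assert torch.allclose(merge_blocks_to_original_layout(blocks, 4, 4), x)


def _sanity_check_w8a8_linear():
    # Weight (out=6, in=10) with 4x4 blocks => padded to 8x12 => 2x3 scale grid.
    w = torch.randn(6, 10)
    w_scale = torch.ones(2, 3)
    x = torch.randn(2, 10)
    y = w8a8_block_fp8_linear(x, w, None, w_scale, block_size=[4, 4])
    assert y.shape == (2, 6)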
def fused_experts(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    decode_with_batch: bool = False,
) -> torch.Tensor:
    batch_seq_all, hidden_dims = hidden_states.shape
    intermediate_size = w2_weight.shape[-1]
    num_experts = w13_weight.shape[0]
    w13_weight = w13_weight.contiguous()
    w2_weight = w2_weight.contiguous()
    w13_scale = w13_scale.contiguous()
    w2_scale = w2_scale.contiguous()
    final_hidden_states = torch.zeros_like(hidden_states)
    w1_scale = w13_scale
    _, bs0_w13, bs1_w13 = w1_scale.shape
    _, bs0_w2, bs1_w2 = w2_scale.shape
    sel_experts = topk_ids.shape[1]
    if hidden_states.shape[0] == 1:
        # Decode path: a single token; loop over its selected experts.
        for k in range(sel_experts):
            expert_idx = topk_ids[0][k]
            expert_w1 = w13_weight[expert_idx].contiguous()
            expert_w2 = w2_weight[expert_idx].contiguous()
            ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
            ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()
            # Block-wise dequant of w1: one scale per (b0, b1) block.
            dim0, dim1 = expert_w1.shape
            b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
            expert_w1 = (
                expert_w1.view(bs0_w13, b0, bs1_w13, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w13, bs1_w13, -1)
                .float()
                .to(hidden_states.device)
                * ws1
            )
            expert_w1 = (
                expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )
            # Block-wise dequant of w2.
            dim0, dim1 = expert_w2.shape
            b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
            expert_w2 = (
                expert_w2.view(bs0_w2, b0, bs1_w2, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w2, bs1_w2, -1)
                .float()
                .to(hidden_states.device)
                * ws2
            )
            expert_w2 = (
                expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )
            expert_weights = topk_weights[0][k].to(hidden_states.dtype)
            # SwiGLU MLP: silu(gate) * up, then the down projection.
            x = hidden_states
            x = F.linear(x, expert_w1)
            gate = F.silu(x[:, :intermediate_size])
            x = x[:, intermediate_size:] * gate
            x = F.linear(x, expert_w2)
            current_hidden_states = x * expert_weights
            current_hidden_states = current_hidden_states.to(x.dtype)
            final_hidden_states += current_hidden_states
    else:
        # Prefill path: loop over experts, gathering the tokens routed to each.
        for expert_idx in range(num_experts):
            # topk_ids / expert_mask: [tokens, experts], e.g. [10, 8]
            expert_mask = topk_ids == expert_idx
            idx = torch.where(expert_mask)[0]
            if idx.numel() == 0:
                continue
            expert_w1 = w13_weight[expert_idx].contiguous()
            expert_w2 = w2_weight[expert_idx].contiguous()
            ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
            ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()
            # Block-wise dequant of w1: one scale per (b0, b1) block.
            dim0, dim1 = expert_w1.shape
            b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
            expert_w1 = (
                expert_w1.view(bs0_w13, b0, bs1_w13, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w13, bs1_w13, -1)
                .float()
                .to(hidden_states.device)
                * ws1
            )
            expert_w1 = (
                expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )
            # Block-wise dequant of w2.
            dim0, dim1 = expert_w2.shape
            b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
            expert_w2 = (
                expert_w2.view(bs0_w2, b0, bs1_w2, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w2, bs1_w2, -1)
                .float()
                .to(hidden_states.device)
                * ws2
            )
            expert_w2 = (
                expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )
            # masked_select walks the mask row-major, so the selected routing
            # weights line up with the token indices in idx.
            expert_weights = (
                topk_weights.masked_select(expert_mask)
                .unsqueeze(1)
                .to(hidden_states.dtype)
            )
            x = hidden_states[idx]
            x = F.linear(x, expert_w1)
            gate = F.silu(x[:, :intermediate_size])
            x = x[:, intermediate_size:] * gate
            x = F.linear(x, expert_w2)
            current_hidden_states = x * expert_weights
            current_hidden_states = current_hidden_states.to(x.dtype)
            # index_add_ accumulates contributions for tokens that appear under
            # multiple experts.
            final_hidden_states.index_add_(0, idx, current_hidden_states)
    final_hidden_states = final_hidden_states.reshape(batch_seq_all, hidden_dims)
    return final_hidden_states
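# Hedged usage sketch for fused_experts (illustrative only): the shapes, the
# _demo_fused_experts name, and the float32 stand-ins for quantized expert
# weights are assumptions. E experts, hidden size H, intermediate size I, and
# scale grids matching 4x4 weight blocks.
def _demo_fused_experts():
    E, H, I, T = 4, 16, 8, 5
    w13 = torch.randn(E, 2 * I, H)  # per-expert fused gate+up projection
    w2 = torch.randn(E, H, I)       # per-expert down projection
    w13_scale = torch.ones(E, (2 * I) // 4, H // 4)
    w2_scale = torch.ones(E, H // 4, I // 4)
    topk_ids = torch.randint(0, E, (T, 2))
    topk_weights = torch.softmax(torch.rand(T, 2), dim=-1)
    out = fused_experts(
        torch.randn(T, H), w13, w2, topk_weights, topk_ids,
        w13_scale=w13_scale, w2_scale=w2_scale,
    )
    assert out.shape == (T, H)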
def fused_mlp_mm_fp8(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape_w13: Optional[List[int]] = None,
    block_shape_w2: Optional[List[int]] = None,
):
    def fp8_to_fp16(inp, scale, block_size, trans_type):
        # Upcast, apply one scale per (block_size[0], block_size[1]) block,
        # then merge the blocks back into the original 2D layout.
        inp_t = inp.to(trans_type)
        inp_t = split_last_two_dims_into_blocks(inp_t, block_size[0], block_size[1])
        assert scale.size(0) == inp_t.size(-4)
        assert scale.size(1) == inp_t.size(-3)
        inp_t = inp_t * scale.unsqueeze(-1).unsqueeze(-1)
        inp_t = merge_blocks_to_original_layout(inp_t, block_size[0], block_size[1])
        return inp_t.to(trans_type)

    w13_weight = w13_weight.contiguous()
    w2_weight = w2_weight.contiguous()
    w13_scale = w13_scale.contiguous()
    w2_scale = w2_scale.contiguous()
    w13_fp = fp8_to_fp16(w13_weight, w13_scale, block_shape_w13, hidden_states.dtype)
    w2_fp = fp8_to_fp16(w2_weight, w2_scale, block_shape_w2, hidden_states.dtype)
    out = hidden_states @ w13_fp
    out = torch.chunk(out, 2, dim=-1)
    out = F.silu(out[0]) * out[1]
    out = out @ w2_fp
    return out


def mla_matmul_scale(input: torch.Tensor, weight: torch.Tensor, scale: float):
    output = torch.matmul(input, weight)
    output = output * scale
    output = output.to(input.dtype)
    return output


def mla_matmul(input: torch.Tensor, weight: torch.Tensor):
    output = torch.matmul(input, weight)
    output = output.to(input.dtype)
    return output
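# Smoke-test runner (an illustrative addition, not part of the module's API):
# chains the sketches above and exercises the dense FP8 MLP path with float32
# stand-ins. Note the w13 weight here is laid out (H, 2I), since
# fused_mlp_mm_fp8 computes hidden_states @ w13_fp rather than F.linear.
if __name__ == "__main__":
    _sanity_check_block_helpers()
    _sanity_check_w8a8_linear()
    _demo_fused_experts()
    H, I, T = 16, 8, 5
    x = torch.randn(T, H)
    out = fused_mlp_mm_fp8(
        x,
        torch.randn(H, 2 * I),
        torch.randn(I, H),
        w13_scale=torch.ones(H // 4, (2 * I) // 4),
        w2_scale=torch.ones(I // 4, H // 4),
        block_shape_w13=[4, 4],
        block_shape_w2=[4, 4],
    )
    assert out.shape == (T, H)
    assert mla_matmul_scale(x, torch.randn(H, H), 0.5).shape == (T, H)
    print("sanity checks passed")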