[DSv32] Use torch.compile for _get_logits_head_gate (#11565)
This commit is contained in:
@@ -205,6 +205,7 @@ class Indexer(CustomOp):
|
||||
|
||||
return ans
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
|
||||
weights, _ = self.weights_proj(x)
|
||||
weights = weights * self.n_heads**-0.5
|
||||
|
||||
Reference in New Issue
Block a user