[DSv32] Use torch.compile for _get_logits_head_gate (#11565)

This commit is contained in:
Trevor Morris
2025-10-13 18:38:39 -07:00
committed by GitHub
parent aaf7af1b17
commit 384733639a

View File

@@ -205,6 +205,7 @@ class Indexer(CustomOp):
return ans
@torch.compile(dynamic=True)
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x)
weights = weights * self.n_heads**-0.5