[DSv32] Use torch.compile for _get_logits_head_gate (#11565)
This commit is contained in:
@@ -205,6 +205,7 @@ class Indexer(CustomOp):
|
|||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
|
||||||
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
|
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
|
||||||
weights, _ = self.weights_proj(x)
|
weights, _ = self.weights_proj(x)
|
||||||
weights = weights * self.n_heads**-0.5
|
weights = weights * self.n_heads**-0.5
|
||||||
|
|||||||
Reference in New Issue
Block a user