kernel
drbh committed on
Commit
3bdb4b8
·
1 Parent(s): 89e2950

feat: bump build for shared experts

Files changed (36)
  1. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py +277 -1
  5. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  6. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
  7. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  8. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py +277 -1
  9. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  10. build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
  11. build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  12. build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py +277 -1
  13. build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  14. build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
  15. build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  16. build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py +277 -1
  17. build/torch26-cxx98-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  18. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
  19. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  20. build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py +277 -1
  21. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_76c7de7.abi3.so +0 -3
  22. build/torch26-cxx98-cu126-x86_64-linux/megablocks/{_megablocks_9a1816c.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  23. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  24. build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py +277 -1
  25. build/torch27-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  26. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
  27. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  28. build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py +277 -1
  29. build/torch27-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  30. build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
  31. build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  32. build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py +277 -1
  33. build/torch27-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
  34. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
  35. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +3 -3
  36. build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py +277 -1
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3c5605ba50f2661b9dc4c5609572323fb4f52787181109c5900c261c5e2bf602
+oid sha256:070067fec0e735e865610caf4fc33b384fe8c9c47a002c365f740c82c5af1bab
 size 10517576
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4e4c48e189572141f6a140dd83f9eca19eaebbc20c5cd686aa0263aafec14533
-size 10517576
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_76c7de7
-ops = torch.ops._megablocks_76c7de7
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_76c7de7::{op_name}"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
+        return output, expert_weights_out
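Taken together, the layers.py additions give every build variant a drop-in shared-expert MoE block. The snippet below is a minimal usage sketch, not code from this commit: the sizes, dtype, initializer, and the assumption that `router`/`experts` sub-modules are attached the same way as for the plain MegaBlocksMoeMLP are illustrative. Note the combination rule from combine_expert_shared_outputs: with the default weighted_sum=False the shared-expert output is simply added to the routed output, while with weighted_sum=True and moe_top_k=4 the shared expert contributes 1/(4+1) = 0.2 and the routed experts 4/5 = 0.8 of the result.

    import torch
    from megablocks.layers import (
        MegaBlocksMoeMLPWithSharedExpert,
        create_shared_expert_weights,
    )

    hidden_size = 1024                # illustrative; taken from the host model config in practice
    shared_expert_hidden_size = 2048  # illustrative

    mlp = MegaBlocksMoeMLPWithSharedExpert()
    # ... attach mlp.router and mlp.experts here, exactly as for MegaBlocksMoeMLP ...

    up_w, down_w, up_b, down_b = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
        init_method=torch.nn.init.xavier_uniform_,  # assumed initializer
    )
    mlp.set_shared_expert_weights(
        up_proj_weight=up_w,
        down_proj_weight=down_w,
        up_proj_bias=up_b,    # None: create_shared_expert_weights returns no biases
        down_proj_bias=down_b,
        weighted_sum=False,   # False -> shared output is added to the routed output
        activation_fn=torch.nn.functional.silu,  # leave as None to fall back to gelu
    )

    # x: (batch, seq, hidden_size); forward returns (output, expert_weights)
    # output, expert_weights = mlp(x)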
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:12c21d6f72f90950adbda156534691dd753476a18719b416541e8d6920a173b4
+oid sha256:02dffd561ef226c1ec17c99e462c3c771879f078dde9b1e5cd8bd5992be5b3da
 size 11869392
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3d958a0c77589a5ede72336d1cab80ea9d6324ef6f8a9a187af2da4db74e1894
-size 11869392
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_76c7de7
-ops = torch.ops._megablocks_76c7de7
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_76c7de7::{op_name}"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py CHANGED
(Identical diff to build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above: +277 -1.)
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:db7e3b7c3c15af78fe9ef0ba50c33cb2cb988bdf5dfb1f46807b7871e7c8e70e
+oid sha256:b5aa4e066ddbd863693ca8a5ec37fba34996226442dfa407e4a49b779497001d
 size 11931048
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d41a4f5bbc160f51b058d3ba36e9087e9f15d35ae4782f36c984dd7199ee8ede
-size 11931048
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_76c7de7
-ops = torch.ops._megablocks_76c7de7
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_76c7de7::{op_name}"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py CHANGED
(Identical diff to build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above: +277 -1.)
build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d9971a30d397598ee0a58118b8cca337d142de1ca34404532dfda6328122ab11
+oid sha256:fababa7e0d2c20c98afaebef6165a8145b33d80cdadba28f895c14dd2a7b2823
 size 10510040
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:01f0c774e900380d3c0721dfe15591c67be5d5eb5ad687af6c89a88ecdff4f2a
-size 10510040
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_76c7de7
-ops = torch.ops._megablocks_76c7de7
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_76c7de7::{op_name}"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py CHANGED
(Identical diff to build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above: +277 -1.)
964
+
965
+ has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
+ forward_fn = parallel_forward_once if has_parallel else forward_once
967
+
968
+ sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
969
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
970
+
971
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
+ x=x,
973
+ router_weight=self.router.weight,
974
+ moe_top_k=moe_top_k,
975
+ moe_num_experts=moe_num_experts,
976
+ moe_jitter_eps=moe_jitter_eps,
977
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
978
+ uniform_expert_assignment=uniform_expert_assignment,
979
+ training=self.training,
980
+ w1=self.experts.gate_up_proj,
981
+ w2=self.experts.down_proj,
982
+ w1_bias=self.experts.gate_up_proj_bias,
983
+ w2_bias=self.experts.down_proj_bias,
984
+ gradient_scale=gradient_scale,
985
+ alpha=alpha,
986
+ sort_end_bit=sort_end_bit,
987
+ expert_parallel_group=expert_parallel_group,
988
+ moe_capacity_factor=moe_capacity_factor,
989
+ moe_expert_model_parallelism=has_parallel,
990
+ forward_fn=forward_fn,
991
+ hidden_size=self.experts.hidden_size,
992
+ mlp_impl=mlp_impl,
993
+ # Shared expert parameters
994
+ shared_up_proj_weight=self.shared_up_proj_weight,
995
+ shared_down_proj_weight=self.shared_down_proj_weight,
996
+ shared_up_proj_bias=self.shared_up_proj_bias,
997
+ shared_down_proj_bias=self.shared_down_proj_bias,
998
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
+ shared_activation_fn=self.shared_activation_fn,
1000
+ )
1001
  return output, expert_weights_out
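
The layers.py hunks above add the shared-expert path: shared_mlp_forward applies a dense up-projection, activation, and down-projection to every token, and combine_expert_shared_outputs folds that result into the routed MoE output either as a plain sum or as a weighted sum. The Python sketch below mirrors that data flow with plain torch calls so it runs without the built kernel; it skips the gradient-scaling and DTensor-resolution steps, and the tensor shapes and random weights are made-up example values, not anything taken from this commit.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tokens, hidden_size, shared_hidden, top_k = 8, 64, 128, 4

x = torch.randn(tokens, hidden_size)           # token activations
routed_out = torch.randn(tokens, hidden_size)  # stand-in for the moe_forward output
up_proj = torch.randn(shared_hidden, hidden_size) * 0.02
down_proj = torch.randn(hidden_size, shared_hidden) * 0.02

# Dense shared expert over every token, the same shape of computation as shared_mlp_forward
# (gelu is the default activation there when activation_fn is None).
shared_out = F.linear(F.gelu(F.linear(x, up_proj)), down_proj)

# combine_expert_shared_outputs: plain sum, or weight the shared expert as one extra expert.
combined_sum = shared_out + routed_out
combined_weighted = shared_out * (1.0 / (top_k + 1)) + routed_out * (top_k / (top_k + 1))

print(combined_sum.shape, combined_weighted.shape)  # both torch.Size([8, 64])
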
build/torch26-cxx98-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so β†’ _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c136e90b35e7fd43fcc4d987588f68b3f4cfea295a00f1fda343acc9c8848577
+oid sha256:9e3663f46030f07e030efe94c26495d17b2703551a46c0ca3acf8b25ecb2a238
 size 11857920
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:09a5f57ae37af9f5b14c4a0f21d1679e32f5b7424973c36dac9bbbecbfbf7374
-size 11857920
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_76c7de7
-ops = torch.ops._megablocks_76c7de7
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_76c7de7::{op_name}"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
@@ -680,6 +740,136 @@ def moe_forward(
680
  return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
691
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
692
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
693
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
694
-
695
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
 
 
 
696
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
697
  forward_fn = parallel_forward_once if has_parallel else forward_once
698
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
722
  hidden_size=self.experts.hidden_size,
723
  mlp_impl=mlp_impl,
724
  )
725
  return output, expert_weights_out
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
+ # Shared expert MLP forward pass
156
+ def shared_mlp_forward(
157
+ x: torch.Tensor,
158
+ up_proj_weight: torch.Tensor,
159
+ down_proj_weight: torch.Tensor,
160
+ up_proj_bias: Optional[torch.Tensor] = None,
161
+ down_proj_bias: Optional[torch.Tensor] = None,
162
+ activation_fn: Optional[Any] = None,
163
+ gradient_scale: Optional[float] = None,
164
+ ) -> torch.Tensor:
165
+ # Default activation function
166
+ if activation_fn is None:
167
+ activation_fn = torch.nn.functional.gelu
168
+
169
+ # Scale weights
170
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
171
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
172
+ if up_proj_bias is not None:
173
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
174
+ if down_proj_bias is not None:
175
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
176
+
177
+ # Resolve dtensors
178
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
179
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
180
+ if up_proj_bias is not None:
181
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
182
+ if down_proj_bias is not None:
183
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
184
+
185
+ # Up projection
186
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
+
188
+ # Activation
189
+ x = activation_fn(x)
190
+
191
+ # Down projection
192
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
+
194
+ return x
195
+
196
+
197
+ # Combine outputs from shared expert and regular experts
198
+ def combine_expert_shared_outputs(
199
+ shared_expert_out: torch.Tensor,
200
+ expert_out: torch.Tensor,
201
+ shared_expert_weighted_sum: bool = False,
202
+ moe_top_k: int = 1,
203
+ ) -> torch.Tensor:
204
+ if shared_expert_weighted_sum:
205
+ # Weighted sum based on number of experts used
206
+ total_experts = moe_top_k + 1
207
+ shared_weight = 1.0 / total_experts
208
+ expert_weight = moe_top_k / total_experts
209
+ return shared_expert_out * shared_weight + expert_out * expert_weight
210
+ else:
211
+ # Simple addition
212
+ return shared_expert_out + expert_out
213
+
214
+
215
  # Global variable to store load balancing loss
216
  _LOAD_BALANCING_LOSS = []
217
 
 
740
  return x, expert_weights, router_scores
741
 
742
 
743
+ def moe_forward_with_shared_expert(
744
+ x: torch.Tensor,
745
+ router_weight: torch.Tensor,
746
+ moe_top_k: int,
747
+ moe_num_experts: int,
748
+ moe_jitter_eps: float = None,
749
+ moe_normalize_expert_weights: int = None,
750
+ uniform_expert_assignment: bool = False,
751
+ training: bool = False,
752
+ w1: torch.Tensor = None,
753
+ w2: torch.Tensor = None,
754
+ w1_bias: torch.Tensor = None,
755
+ w2_bias: torch.Tensor = None,
756
+ gradient_scale: Optional[float] = None,
757
+ alpha: float = 1.702,
758
+ sort_end_bit: int = 0,
759
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
760
+ moe_capacity_factor: float = 1.0,
761
+ moe_expert_model_parallelism: bool = False,
762
+ forward_fn: Any = None,
763
+ hidden_size: int = None,
764
+ mlp_impl: str = "grouped",
765
+ # Shared expert parameters
766
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
767
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
768
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
769
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
770
+ shared_expert_weighted_sum: bool = False,
771
+ shared_activation_fn: Optional[Any] = None,
772
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773
+
774
+ # First, compute regular MoE forward pass
775
+ expert_out, expert_weights, router_scores = moe_forward(
776
+ x=x,
777
+ router_weight=router_weight,
778
+ moe_top_k=moe_top_k,
779
+ moe_num_experts=moe_num_experts,
780
+ moe_jitter_eps=moe_jitter_eps,
781
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
782
+ uniform_expert_assignment=uniform_expert_assignment,
783
+ training=training,
784
+ w1=w1,
785
+ w2=w2,
786
+ w1_bias=w1_bias,
787
+ w2_bias=w2_bias,
788
+ gradient_scale=gradient_scale,
789
+ alpha=alpha,
790
+ sort_end_bit=sort_end_bit,
791
+ expert_parallel_group=expert_parallel_group,
792
+ moe_capacity_factor=moe_capacity_factor,
793
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
794
+ forward_fn=forward_fn,
795
+ hidden_size=hidden_size,
796
+ mlp_impl=mlp_impl,
797
+ )
798
+
799
+ # If shared expert weights provided, compute shared expert output
800
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
+ shared_expert_out = shared_mlp_forward(
802
+ x=x,
803
+ up_proj_weight=shared_up_proj_weight,
804
+ down_proj_weight=shared_down_proj_weight,
805
+ up_proj_bias=shared_up_proj_bias,
806
+ down_proj_bias=shared_down_proj_bias,
807
+ activation_fn=shared_activation_fn,
808
+ gradient_scale=gradient_scale,
809
+ )
810
+
811
+ # Combine expert outputs
812
+ combined_out = combine_expert_shared_outputs(
813
+ shared_expert_out=shared_expert_out,
814
+ expert_out=expert_out,
815
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
816
+ moe_top_k=moe_top_k,
817
+ )
818
+
819
+ return combined_out, expert_weights, router_scores
820
+
821
+ # Return regular MoE output if no shared expert
822
+ return expert_out, expert_weights, router_scores
823
+
824
+
825
+ def create_shared_expert_weights(
826
+ hidden_size: int,
827
+ shared_expert_hidden_size: int,
828
+ device: torch.device,
829
+ dtype: torch.dtype,
830
+ init_method: Any,
831
+ output_layer_init_method: Any = None,
832
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
833
+
834
+ if output_layer_init_method is None:
835
+ output_layer_init_method = init_method
836
+
837
+ # Create weight tensors
838
+ up_proj_weight = torch.empty(
839
+ shared_expert_hidden_size,
840
+ hidden_size,
841
+ device=device,
842
+ dtype=dtype,
843
+ )
844
+ down_proj_weight = torch.empty(
845
+ hidden_size,
846
+ shared_expert_hidden_size,
847
+ device=device,
848
+ dtype=dtype,
849
+ )
850
+
851
+ # Initialize weights
852
+ init_method(up_proj_weight)
853
+ output_layer_init_method(down_proj_weight)
854
+
855
+ # No bias by default
856
+ return up_proj_weight, down_proj_weight, None, None
857
+
858
+ # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
+ # This exists because device_mesh is trapped in hook closures with no model attribute
860
+ # Fragile - breaks if hook structure changes or Python internals change
861
+ # TODO: Replace with a more robust solution when available
862
+ def get_device_mesh(model):
863
+ # Extract device_mesh from child's unused pre_hook closure
864
+ try:
865
+ # Find the pre-hook that contains 'device_mesh' in its closure
866
+ hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
867
+ # Extract the device_mesh from the closure
868
+ return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
869
+ except Exception:
870
+ return None
871
+
872
+
873
  class MegaBlocksMoeMLP(torch.nn.Module):
874
 
875
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
+
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
886
+ if expert_parallel_group is None:
887
+ device_mesh = get_device_mesh(self)
888
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
+
890
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
 
 
916
  hidden_size=self.experts.hidden_size,
917
  mlp_impl=mlp_impl,
918
  )
919
+ return output, expert_weights_out
920
+
921
+
922
+ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
+
924
+ def __init__(self):
925
+ super().__init__()
926
+ # Shared expert weights will be set by the user
927
+ self.shared_up_proj_weight = None
928
+ self.shared_down_proj_weight = None
929
+ self.shared_up_proj_bias = None
930
+ self.shared_down_proj_bias = None
931
+ self.shared_expert_weighted_sum = False
932
+ self.shared_activation_fn = None
933
+
934
+ def set_shared_expert_weights(
935
+ self,
936
+ up_proj_weight: torch.Tensor,
937
+ down_proj_weight: torch.Tensor,
938
+ up_proj_bias: Optional[torch.Tensor] = None,
939
+ down_proj_bias: Optional[torch.Tensor] = None,
940
+ weighted_sum: bool = False,
941
+ activation_fn: Optional[Any] = None,
942
+ ):
943
+ self.shared_up_proj_weight = up_proj_weight
944
+ self.shared_down_proj_weight = down_proj_weight
945
+ self.shared_up_proj_bias = up_proj_bias
946
+ self.shared_down_proj_bias = down_proj_bias
947
+ self.shared_expert_weighted_sum = weighted_sum
948
+ self.shared_activation_fn = activation_fn
949
+
950
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
951
+ moe_top_k = getattr(self.router, "top_k", 4)
952
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
953
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
954
+ alpha = getattr(self.experts, "alpha", 1.0)
955
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
958
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
+
960
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
961
+ if expert_parallel_group is None:
962
+ device_mesh = get_device_mesh(self)
963
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
+
965
+ has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
+ forward_fn = parallel_forward_once if has_parallel else forward_once
967
+
968
+ sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
969
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
970
+
971
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
+ x=x,
973
+ router_weight=self.router.weight,
974
+ moe_top_k=moe_top_k,
975
+ moe_num_experts=moe_num_experts,
976
+ moe_jitter_eps=moe_jitter_eps,
977
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
978
+ uniform_expert_assignment=uniform_expert_assignment,
979
+ training=self.training,
980
+ w1=self.experts.gate_up_proj,
981
+ w2=self.experts.down_proj,
982
+ w1_bias=self.experts.gate_up_proj_bias,
983
+ w2_bias=self.experts.down_proj_bias,
984
+ gradient_scale=gradient_scale,
985
+ alpha=alpha,
986
+ sort_end_bit=sort_end_bit,
987
+ expert_parallel_group=expert_parallel_group,
988
+ moe_capacity_factor=moe_capacity_factor,
989
+ moe_expert_model_parallelism=has_parallel,
990
+ forward_fn=forward_fn,
991
+ hidden_size=self.experts.hidden_size,
992
+ mlp_impl=mlp_impl,
993
+ # Shared expert parameters
994
+ shared_up_proj_weight=self.shared_up_proj_weight,
995
+ shared_down_proj_weight=self.shared_down_proj_weight,
996
+ shared_up_proj_bias=self.shared_up_proj_bias,
997
+ shared_down_proj_bias=self.shared_down_proj_bias,
998
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
+ shared_activation_fn=self.shared_activation_fn,
1000
+ )
1001
  return output, expert_weights_out
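
As a quick note on the weighted-sum branch repeated in this copy of layers.py: combine_expert_shared_outputs treats the shared expert as one extra expert alongside the moe_top_k routed ones, so the two coefficients are 1/(top_k+1) and top_k/(top_k+1) and always sum to one. A tiny arithmetic check (the top_k values below are arbitrary examples):

# Coefficients used by combine_expert_shared_outputs when shared_expert_weighted_sum=True.
for top_k in (1, 2, 4, 8):
    shared_w = 1.0 / (top_k + 1)
    expert_w = top_k / (top_k + 1)
    assert abs(shared_w + expert_w - 1.0) < 1e-12
    print(f"top_k={top_k}: shared={shared_w:.3f}, routed={expert_w:.3f}")
# e.g. top_k=4 -> shared=0.200, routed=0.800
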
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_76c7de7.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a3f893773ec7b8157a4531a57821807f5f27ac48ceaa695c342cc7a39ad318dc
-size 11927768
build/torch26-cxx98-cu126-x86_64-linux/megablocks/{_megablocks_9a1816c.abi3.so β†’ _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d155b22a3a413d23e1d1b6f65fd3700b2e004e45daf1cca1b397b8e0b4d68616
+oid sha256:d1571732c5954914d5ddf0f12ebc4074d88d907130d71d898de43958e3b9a5d1
 size 11923672
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_76c7de7
-ops = torch.ops._megablocks_76c7de7
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_76c7de7::{op_name}"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
@@ -680,6 +740,136 @@ def moe_forward(
680
  return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
691
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
692
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
693
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
694
-
695
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
 
 
 
696
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
697
  forward_fn = parallel_forward_once if has_parallel else forward_once
698
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
722
  hidden_size=self.experts.hidden_size,
723
  mlp_impl=mlp_impl,
724
  )
725
  return output, expert_weights_out
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
+ # Shared expert MLP forward pass
156
+ def shared_mlp_forward(
157
+ x: torch.Tensor,
158
+ up_proj_weight: torch.Tensor,
159
+ down_proj_weight: torch.Tensor,
160
+ up_proj_bias: Optional[torch.Tensor] = None,
161
+ down_proj_bias: Optional[torch.Tensor] = None,
162
+ activation_fn: Optional[Any] = None,
163
+ gradient_scale: Optional[float] = None,
164
+ ) -> torch.Tensor:
165
+ # Default activation function
166
+ if activation_fn is None:
167
+ activation_fn = torch.nn.functional.gelu
168
+
169
+ # Scale weights
170
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
171
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
172
+ if up_proj_bias is not None:
173
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
174
+ if down_proj_bias is not None:
175
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
176
+
177
+ # Resolve dtensors
178
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
179
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
180
+ if up_proj_bias is not None:
181
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
182
+ if down_proj_bias is not None:
183
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
184
+
185
+ # Up projection
186
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
+
188
+ # Activation
189
+ x = activation_fn(x)
190
+
191
+ # Down projection
192
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
+
194
+ return x
195
+
196
+
197
+ # Combine outputs from shared expert and regular experts
198
+ def combine_expert_shared_outputs(
199
+ shared_expert_out: torch.Tensor,
200
+ expert_out: torch.Tensor,
201
+ shared_expert_weighted_sum: bool = False,
202
+ moe_top_k: int = 1,
203
+ ) -> torch.Tensor:
204
+ if shared_expert_weighted_sum:
205
+ # Weighted sum based on number of experts used
206
+ total_experts = moe_top_k + 1
207
+ shared_weight = 1.0 / total_experts
208
+ expert_weight = moe_top_k / total_experts
209
+ return shared_expert_out * shared_weight + expert_out * expert_weight
210
+ else:
211
+ # Simple addition
212
+ return shared_expert_out + expert_out
213
+
214
+
215
  # Global variable to store load balancing loss
216
  _LOAD_BALANCING_LOSS = []
217
 
 
740
  return x, expert_weights, router_scores
741
 
742
 
743
+ def moe_forward_with_shared_expert(
744
+ x: torch.Tensor,
745
+ router_weight: torch.Tensor,
746
+ moe_top_k: int,
747
+ moe_num_experts: int,
748
+ moe_jitter_eps: float = None,
749
+ moe_normalize_expert_weights: int = None,
750
+ uniform_expert_assignment: bool = False,
751
+ training: bool = False,
752
+ w1: torch.Tensor = None,
753
+ w2: torch.Tensor = None,
754
+ w1_bias: torch.Tensor = None,
755
+ w2_bias: torch.Tensor = None,
756
+ gradient_scale: Optional[float] = None,
757
+ alpha: float = 1.702,
758
+ sort_end_bit: int = 0,
759
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
760
+ moe_capacity_factor: float = 1.0,
761
+ moe_expert_model_parallelism: bool = False,
762
+ forward_fn: Any = None,
763
+ hidden_size: int = None,
764
+ mlp_impl: str = "grouped",
765
+ # Shared expert parameters
766
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
767
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
768
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
769
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
770
+ shared_expert_weighted_sum: bool = False,
771
+ shared_activation_fn: Optional[Any] = None,
772
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773
+
774
+ # First, compute regular MoE forward pass
775
+ expert_out, expert_weights, router_scores = moe_forward(
776
+ x=x,
777
+ router_weight=router_weight,
778
+ moe_top_k=moe_top_k,
779
+ moe_num_experts=moe_num_experts,
780
+ moe_jitter_eps=moe_jitter_eps,
781
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
782
+ uniform_expert_assignment=uniform_expert_assignment,
783
+ training=training,
784
+ w1=w1,
785
+ w2=w2,
786
+ w1_bias=w1_bias,
787
+ w2_bias=w2_bias,
788
+ gradient_scale=gradient_scale,
789
+ alpha=alpha,
790
+ sort_end_bit=sort_end_bit,
791
+ expert_parallel_group=expert_parallel_group,
792
+ moe_capacity_factor=moe_capacity_factor,
793
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
794
+ forward_fn=forward_fn,
795
+ hidden_size=hidden_size,
796
+ mlp_impl=mlp_impl,
797
+ )
798
+
799
+ # If shared expert weights provided, compute shared expert output
800
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
+ shared_expert_out = shared_mlp_forward(
802
+ x=x,
803
+ up_proj_weight=shared_up_proj_weight,
804
+ down_proj_weight=shared_down_proj_weight,
805
+ up_proj_bias=shared_up_proj_bias,
806
+ down_proj_bias=shared_down_proj_bias,
807
+ activation_fn=shared_activation_fn,
808
+ gradient_scale=gradient_scale,
809
+ )
810
+
811
+ # Combine expert outputs
812
+ combined_out = combine_expert_shared_outputs(
813
+ shared_expert_out=shared_expert_out,
814
+ expert_out=expert_out,
815
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
816
+ moe_top_k=moe_top_k,
817
+ )
818
+
819
+ return combined_out, expert_weights, router_scores
820
+
821
+ # Return regular MoE output if no shared expert
822
+ return expert_out, expert_weights, router_scores
823
+
824
+
825
+ def create_shared_expert_weights(
826
+ hidden_size: int,
827
+ shared_expert_hidden_size: int,
828
+ device: torch.device,
829
+ dtype: torch.dtype,
830
+ init_method: Any,
831
+ output_layer_init_method: Any = None,
832
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
833
+
834
+ if output_layer_init_method is None:
835
+ output_layer_init_method = init_method
836
+
837
+ # Create weight tensors
838
+ up_proj_weight = torch.empty(
839
+ shared_expert_hidden_size,
840
+ hidden_size,
841
+ device=device,
842
+ dtype=dtype,
843
+ )
844
+ down_proj_weight = torch.empty(
845
+ hidden_size,
846
+ shared_expert_hidden_size,
847
+ device=device,
848
+ dtype=dtype,
849
+ )
850
+
851
+ # Initialize weights
852
+ init_method(up_proj_weight)
853
+ output_layer_init_method(down_proj_weight)
854
+
855
+ # No bias by default
856
+ return up_proj_weight, down_proj_weight, None, None
857
+
858
+ # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
+ # This exists because device_mesh is trapped in hook closures with no model attribute
860
+ # Fragile - breaks if hook structure changes or Python internals change
861
+ # TODO: Replace with a more robust solution when available
862
+ def get_device_mesh(model):
863
+ # Extract device_mesh from child's unused pre_hook closure
864
+ try:
865
+ # Find the pre-hook that contains 'device_mesh' in its closure
866
+ hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
867
+ # Extract the device_mesh from the closure
868
+ return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
869
+ except Exception:
870
+ return None
871
+
872
+
873
  class MegaBlocksMoeMLP(torch.nn.Module):
874
 
875
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
+
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
886
+ if expert_parallel_group is None:
887
+ device_mesh = get_device_mesh(self)
888
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
+
890
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
 
 
916
  hidden_size=self.experts.hidden_size,
917
  mlp_impl=mlp_impl,
918
  )
919
+ return output, expert_weights_out
920
+
921
+
922
+ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
+
924
+ def __init__(self):
925
+ super().__init__()
926
+ # Shared expert weights will be set by the user
927
+ self.shared_up_proj_weight = None
928
+ self.shared_down_proj_weight = None
929
+ self.shared_up_proj_bias = None
930
+ self.shared_down_proj_bias = None
931
+ self.shared_expert_weighted_sum = False
932
+ self.shared_activation_fn = None
933
+
934
+ def set_shared_expert_weights(
935
+ self,
936
+ up_proj_weight: torch.Tensor,
937
+ down_proj_weight: torch.Tensor,
938
+ up_proj_bias: Optional[torch.Tensor] = None,
939
+ down_proj_bias: Optional[torch.Tensor] = None,
940
+ weighted_sum: bool = False,
941
+ activation_fn: Optional[Any] = None,
942
+ ):
943
+ self.shared_up_proj_weight = up_proj_weight
944
+ self.shared_down_proj_weight = down_proj_weight
945
+ self.shared_up_proj_bias = up_proj_bias
946
+ self.shared_down_proj_bias = down_proj_bias
947
+ self.shared_expert_weighted_sum = weighted_sum
948
+ self.shared_activation_fn = activation_fn
949
+
950
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
951
+ moe_top_k = getattr(self.router, "top_k", 4)
952
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
953
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
954
+ alpha = getattr(self.experts, "alpha", 1.0)
955
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
958
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
+
960
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
961
+ if expert_parallel_group is None:
962
+ device_mesh = get_device_mesh(self)
963
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
+
965
+ has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
+ forward_fn = parallel_forward_once if has_parallel else forward_once
967
+
968
+ sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
969
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
970
+
971
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
+ x=x,
973
+ router_weight=self.router.weight,
974
+ moe_top_k=moe_top_k,
975
+ moe_num_experts=moe_num_experts,
976
+ moe_jitter_eps=moe_jitter_eps,
977
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
978
+ uniform_expert_assignment=uniform_expert_assignment,
979
+ training=self.training,
980
+ w1=self.experts.gate_up_proj,
981
+ w2=self.experts.down_proj,
982
+ w1_bias=self.experts.gate_up_proj_bias,
983
+ w2_bias=self.experts.down_proj_bias,
984
+ gradient_scale=gradient_scale,
985
+ alpha=alpha,
986
+ sort_end_bit=sort_end_bit,
987
+ expert_parallel_group=expert_parallel_group,
988
+ moe_capacity_factor=moe_capacity_factor,
989
+ moe_expert_model_parallelism=has_parallel,
990
+ forward_fn=forward_fn,
991
+ hidden_size=self.experts.hidden_size,
992
+ mlp_impl=mlp_impl,
993
+ # Shared expert parameters
994
+ shared_up_proj_weight=self.shared_up_proj_weight,
995
+ shared_down_proj_weight=self.shared_down_proj_weight,
996
+ shared_up_proj_bias=self.shared_up_proj_bias,
997
+ shared_down_proj_bias=self.shared_down_proj_bias,
998
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
+ shared_activation_fn=self.shared_activation_fn,
1000
+ )
1001
  return output, expert_weights_out
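
MegaBlocksMoeMLPWithSharedExpert only stores the shared-expert tensors; the caller creates and attaches them, for example with the create_shared_expert_weights helper added in the same hunk. A minimal wiring sketch follows. It assumes the built megablocks package from this repo is importable as megablocks.layers, uses torch.nn.init.xavier_uniform_ purely as an example init method, and leaves out the router/experts setup that the inherited forward still requires, so it only shows the attachment step.

import torch
from megablocks.layers import (
    MegaBlocksMoeMLPWithSharedExpert,
    create_shared_expert_weights,
)

layer = MegaBlocksMoeMLPWithSharedExpert()

# Unsharded shared-expert projections on the local device (biases come back as None).
up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=1024,
    shared_expert_hidden_size=4096,
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.xavier_uniform_,
)

# Attach the tensors; weighted_sum=True blends the shared expert in as one extra expert.
layer.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    up_proj_bias=up_b,
    down_proj_bias=down_b,
    weighted_sum=True,
    activation_fn=torch.nn.functional.silu,  # example choice; gelu is the default when None
)
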
build/torch27-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so β†’ _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:aff7108245384777d22e9023ae3fd4cf2bcb0015a0938e314d556dbd3e59fe00
+oid sha256:a39b315c5359b79a67282160b5b344853aa06b5a5c9d8efafb903eb4f249b645
 size 10517816
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:002c2687dbc5693308fe32eaebe2f45ed3c85454fd45bc06d7b30e9c1a6d8949
-size 10517816
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_76c7de7
-ops = torch.ops._megablocks_76c7de7
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_76c7de7::{op_name}"
+    return f"_megablocks_89e2950::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
@@ -680,6 +740,136 @@ def moe_forward(
680
  return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
691
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
692
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
693
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
694
-
695
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
 
 
 
696
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
697
  forward_fn = parallel_forward_once if has_parallel else forward_once
698
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
722
  hidden_size=self.experts.hidden_size,
723
  mlp_impl=mlp_impl,
724
  )
725
  return output, expert_weights_out
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
+ # Shared expert MLP forward pass
156
+ def shared_mlp_forward(
157
+ x: torch.Tensor,
158
+ up_proj_weight: torch.Tensor,
159
+ down_proj_weight: torch.Tensor,
160
+ up_proj_bias: Optional[torch.Tensor] = None,
161
+ down_proj_bias: Optional[torch.Tensor] = None,
162
+ activation_fn: Optional[Any] = None,
163
+ gradient_scale: Optional[float] = None,
164
+ ) -> torch.Tensor:
165
+ # Default activation function
166
+ if activation_fn is None:
167
+ activation_fn = torch.nn.functional.gelu
168
+
169
+ # Scale weights
170
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
171
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
172
+ if up_proj_bias is not None:
173
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
174
+ if down_proj_bias is not None:
175
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
176
+
177
+ # Resolve dtensors
178
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
179
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
180
+ if up_proj_bias is not None:
181
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
182
+ if down_proj_bias is not None:
183
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
184
+
185
+ # Up projection
186
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
+
188
+ # Activation
189
+ x = activation_fn(x)
190
+
191
+ # Down projection
192
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
+
194
+ return x
195
+
196
+
197
+ # Combine outputs from shared expert and regular experts
198
+ def combine_expert_shared_outputs(
199
+ shared_expert_out: torch.Tensor,
200
+ expert_out: torch.Tensor,
201
+ shared_expert_weighted_sum: bool = False,
202
+ moe_top_k: int = 1,
203
+ ) -> torch.Tensor:
204
+ if shared_expert_weighted_sum:
205
+ # Weighted sum based on number of experts used
206
+ total_experts = moe_top_k + 1
207
+ shared_weight = 1.0 / total_experts
208
+ expert_weight = moe_top_k / total_experts
209
+ return shared_expert_out * shared_weight + expert_out * expert_weight
210
+ else:
211
+ # Simple addition
212
+ return shared_expert_out + expert_out
213
+
214
+
215
  # Global variable to store load balancing loss
216
  _LOAD_BALANCING_LOSS = []
217
 
 
740
  return x, expert_weights, router_scores
741
 
742
 
743
+ def moe_forward_with_shared_expert(
744
+ x: torch.Tensor,
745
+ router_weight: torch.Tensor,
746
+ moe_top_k: int,
747
+ moe_num_experts: int,
748
+ moe_jitter_eps: float = None,
749
+ moe_normalize_expert_weights: int = None,
750
+ uniform_expert_assignment: bool = False,
751
+ training: bool = False,
752
+ w1: torch.Tensor = None,
753
+ w2: torch.Tensor = None,
754
+ w1_bias: torch.Tensor = None,
755
+ w2_bias: torch.Tensor = None,
756
+ gradient_scale: Optional[float] = None,
757
+ alpha: float = 1.702,
758
+ sort_end_bit: int = 0,
759
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
760
+ moe_capacity_factor: float = 1.0,
761
+ moe_expert_model_parallelism: bool = False,
762
+ forward_fn: Any = None,
763
+ hidden_size: int = None,
764
+ mlp_impl: str = "grouped",
765
+ # Shared expert parameters
766
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
767
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
768
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
769
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
770
+ shared_expert_weighted_sum: bool = False,
771
+ shared_activation_fn: Optional[Any] = None,
772
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773
+
774
+ # First, compute regular MoE forward pass
775
+ expert_out, expert_weights, router_scores = moe_forward(
776
+ x=x,
777
+ router_weight=router_weight,
778
+ moe_top_k=moe_top_k,
779
+ moe_num_experts=moe_num_experts,
780
+ moe_jitter_eps=moe_jitter_eps,
781
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
782
+ uniform_expert_assignment=uniform_expert_assignment,
783
+ training=training,
784
+ w1=w1,
785
+ w2=w2,
786
+ w1_bias=w1_bias,
787
+ w2_bias=w2_bias,
788
+ gradient_scale=gradient_scale,
789
+ alpha=alpha,
790
+ sort_end_bit=sort_end_bit,
791
+ expert_parallel_group=expert_parallel_group,
792
+ moe_capacity_factor=moe_capacity_factor,
793
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
794
+ forward_fn=forward_fn,
795
+ hidden_size=hidden_size,
796
+ mlp_impl=mlp_impl,
797
+ )
798
+
799
+ # If shared expert weights provided, compute shared expert output
800
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
+ shared_expert_out = shared_mlp_forward(
802
+ x=x,
803
+ up_proj_weight=shared_up_proj_weight,
804
+ down_proj_weight=shared_down_proj_weight,
805
+ up_proj_bias=shared_up_proj_bias,
806
+ down_proj_bias=shared_down_proj_bias,
807
+ activation_fn=shared_activation_fn,
808
+ gradient_scale=gradient_scale,
809
+ )
810
+
811
+ # Combine expert outputs
812
+ combined_out = combine_expert_shared_outputs(
813
+ shared_expert_out=shared_expert_out,
814
+ expert_out=expert_out,
815
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
816
+ moe_top_k=moe_top_k,
817
+ )
818
+
819
+ return combined_out, expert_weights, router_scores
820
+
821
+ # Return regular MoE output if no shared expert
822
+ return expert_out, expert_weights, router_scores
823
+
824
+
825
+ def create_shared_expert_weights(
826
+ hidden_size: int,
827
+ shared_expert_hidden_size: int,
828
+ device: torch.device,
829
+ dtype: torch.dtype,
830
+ init_method: Any,
831
+ output_layer_init_method: Any = None,
832
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
833
+
834
+ if output_layer_init_method is None:
835
+ output_layer_init_method = init_method
836
+
837
+ # Create weight tensors
838
+ up_proj_weight = torch.empty(
839
+ shared_expert_hidden_size,
840
+ hidden_size,
841
+ device=device,
842
+ dtype=dtype,
843
+ )
844
+ down_proj_weight = torch.empty(
845
+ hidden_size,
846
+ shared_expert_hidden_size,
847
+ device=device,
848
+ dtype=dtype,
849
+ )
850
+
851
+ # Initialize weights
852
+ init_method(up_proj_weight)
853
+ output_layer_init_method(down_proj_weight)
854
+
855
+ # No bias by default
856
+ return up_proj_weight, down_proj_weight, None, None
857
+
858
+ # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
+ # This exists because device_mesh is trapped in hook closures with no model attribute
860
+ # Fragile - breaks if hook structure changes or Python internals change
861
+ # TODO: Replace with a more robust solution when available
862
+ def get_device_mesh(model):
863
+ # Extract device_mesh from child's unused pre_hook closure
864
+ try:
865
+ # Find the pre-hook that contains 'device_mesh' in its closure
866
+ hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
867
+ # Extract the device_mesh from the closure
868
+ return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
869
+ except Exception:
870
+ return None
871
+
872
+
873
  class MegaBlocksMoeMLP(torch.nn.Module):
874
 
875
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
+
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
886
+ if expert_parallel_group is None:
887
+ device_mesh = get_device_mesh(self)
888
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
+
890
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
 
 
916
  hidden_size=self.experts.hidden_size,
917
  mlp_impl=mlp_impl,
918
  )
919
+ return output, expert_weights_out
920
+
921
+
922
+ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
+
924
+ def __init__(self):
925
+ super().__init__()
926
+ # Shared expert weights will be set by the user
927
+ self.shared_up_proj_weight = None
928
+ self.shared_down_proj_weight = None
929
+ self.shared_up_proj_bias = None
930
+ self.shared_down_proj_bias = None
931
+ self.shared_expert_weighted_sum = False
932
+ self.shared_activation_fn = None
933
+
934
+ def set_shared_expert_weights(
935
+ self,
936
+ up_proj_weight: torch.Tensor,
937
+ down_proj_weight: torch.Tensor,
938
+ up_proj_bias: Optional[torch.Tensor] = None,
939
+ down_proj_bias: Optional[torch.Tensor] = None,
940
+ weighted_sum: bool = False,
941
+ activation_fn: Optional[Any] = None,
942
+ ):
943
+ self.shared_up_proj_weight = up_proj_weight
944
+ self.shared_down_proj_weight = down_proj_weight
945
+ self.shared_up_proj_bias = up_proj_bias
946
+ self.shared_down_proj_bias = down_proj_bias
947
+ self.shared_expert_weighted_sum = weighted_sum
948
+ self.shared_activation_fn = activation_fn
949
+
950
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
951
+ moe_top_k = getattr(self.router, "top_k", 4)
952
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
953
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
954
+ alpha = getattr(self.experts, "alpha", 1.0)
955
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
958
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
+
960
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
961
+ if expert_parallel_group is None:
962
+ device_mesh = get_device_mesh(self)
963
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
+
965
+ has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
+ forward_fn = parallel_forward_once if has_parallel else forward_once
967
+
968
+ sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
969
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
970
+
971
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
+ x=x,
973
+ router_weight=self.router.weight,
974
+ moe_top_k=moe_top_k,
975
+ moe_num_experts=moe_num_experts,
976
+ moe_jitter_eps=moe_jitter_eps,
977
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
978
+ uniform_expert_assignment=uniform_expert_assignment,
979
+ training=self.training,
980
+ w1=self.experts.gate_up_proj,
981
+ w2=self.experts.down_proj,
982
+ w1_bias=self.experts.gate_up_proj_bias,
983
+ w2_bias=self.experts.down_proj_bias,
984
+ gradient_scale=gradient_scale,
985
+ alpha=alpha,
986
+ sort_end_bit=sort_end_bit,
987
+ expert_parallel_group=expert_parallel_group,
988
+ moe_capacity_factor=moe_capacity_factor,
989
+ moe_expert_model_parallelism=has_parallel,
990
+ forward_fn=forward_fn,
991
+ hidden_size=self.experts.hidden_size,
992
+ mlp_impl=mlp_impl,
993
+ # Shared expert parameters
994
+ shared_up_proj_weight=self.shared_up_proj_weight,
995
+ shared_down_proj_weight=self.shared_down_proj_weight,
996
+ shared_up_proj_bias=self.shared_up_proj_bias,
997
+ shared_down_proj_bias=self.shared_down_proj_bias,
998
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
+ shared_activation_fn=self.shared_activation_fn,
1000
+ )
1001
  return output, expert_weights_out
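
The get_device_mesh HACK repeated above works by reading a free variable out of a forward-pre-hook's closure. The snippet below is a self-contained illustration of just that CPython mechanism (__code__.co_freevars plus __closure__[...].cell_contents), using a hypothetical stand-in hook and a placeholder device_mesh value; it is not the transformers hook itself, only the lookup trick the helper relies on.

def make_pre_hook(device_mesh):
    # device_mesh is captured as a free variable in the hook's closure,
    # mirroring how the parallelization pre-hook holds onto it.
    def pre_hook(module, args):
        _ = device_mesh  # keep the reference alive inside the closure
        return args
    return pre_hook

hook = make_pre_hook(device_mesh={"mesh": "placeholder"})

# The same lookup get_device_mesh performs on the real hook:
if "device_mesh" in hook.__code__.co_freevars:
    idx = hook.__code__.co_freevars.index("device_mesh")
    recovered = hook.__closure__[idx].cell_contents
    print(recovered)  # {'mesh': 'placeholder'}
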
build/torch27-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so β†’ _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:aa1eeccba0a3a26435538a2aa87bc22a40c0201a79979872f6296af984e7bf1e
+oid sha256:4870e4a9a831c30c7177b9b23b2b20d64f47242f16d818be1884b4e130e063c1
 size 11931080
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ef9197ea269734d4e0528887ab3c353fa8ba10ccf9a82c9abe85b72bc0ea3553
-size 11931080
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_76c7de7
-ops = torch.ops._megablocks_76c7de7
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_76c7de7::{op_name}"
+    return f"_megablocks_89e2950::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
@@ -680,6 +740,136 @@ def moe_forward(
680
  return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
691
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
692
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
693
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
694
-
695
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
 
 
 
696
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
697
  forward_fn = parallel_forward_once if has_parallel else forward_once
698
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
722
  hidden_size=self.experts.hidden_size,
723
  mlp_impl=mlp_impl,
724
  )
725
  return output, expert_weights_out
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
+ # Shared expert MLP forward pass
156
+ def shared_mlp_forward(
157
+ x: torch.Tensor,
158
+ up_proj_weight: torch.Tensor,
159
+ down_proj_weight: torch.Tensor,
160
+ up_proj_bias: Optional[torch.Tensor] = None,
161
+ down_proj_bias: Optional[torch.Tensor] = None,
162
+ activation_fn: Optional[Any] = None,
163
+ gradient_scale: Optional[float] = None,
164
+ ) -> torch.Tensor:
165
+ # Default activation function
166
+ if activation_fn is None:
167
+ activation_fn = torch.nn.functional.gelu
168
+
169
+ # Scale weights
170
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
171
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
172
+ if up_proj_bias is not None:
173
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
174
+ if down_proj_bias is not None:
175
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
176
+
177
+ # Resolve dtensors
178
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
179
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
180
+ if up_proj_bias is not None:
181
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
182
+ if down_proj_bias is not None:
183
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
184
+
185
+ # Up projection
186
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
+
188
+ # Activation
189
+ x = activation_fn(x)
190
+
191
+ # Down projection
192
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
+
194
+ return x
195
+
196
+
197
+ # Combine outputs from shared expert and regular experts
198
+ def combine_expert_shared_outputs(
199
+ shared_expert_out: torch.Tensor,
200
+ expert_out: torch.Tensor,
201
+ shared_expert_weighted_sum: bool = False,
202
+ moe_top_k: int = 1,
203
+ ) -> torch.Tensor:
204
+ if shared_expert_weighted_sum:
205
+ # Weighted sum based on number of experts used
206
+ total_experts = moe_top_k + 1
207
+ shared_weight = 1.0 / total_experts
208
+ expert_weight = moe_top_k / total_experts
209
+ return shared_expert_out * shared_weight + expert_out * expert_weight
210
+ else:
211
+ # Simple addition
212
+ return shared_expert_out + expert_out
213
+
214
+
215
  # Global variable to store load balancing loss
216
  _LOAD_BALANCING_LOSS = []
217
 
 
740
  return x, expert_weights, router_scores
741
 
742
 
743
+ def moe_forward_with_shared_expert(
744
+ x: torch.Tensor,
745
+ router_weight: torch.Tensor,
746
+ moe_top_k: int,
747
+ moe_num_experts: int,
748
+ moe_jitter_eps: float = None,
749
+ moe_normalize_expert_weights: int = None,
750
+ uniform_expert_assignment: bool = False,
751
+ training: bool = False,
752
+ w1: torch.Tensor = None,
753
+ w2: torch.Tensor = None,
754
+ w1_bias: torch.Tensor = None,
755
+ w2_bias: torch.Tensor = None,
756
+ gradient_scale: Optional[float] = None,
757
+ alpha: float = 1.702,
758
+ sort_end_bit: int = 0,
759
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
760
+ moe_capacity_factor: float = 1.0,
761
+ moe_expert_model_parallelism: bool = False,
762
+ forward_fn: Any = None,
763
+ hidden_size: int = None,
764
+ mlp_impl: str = "grouped",
765
+ # Shared expert parameters
766
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
767
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
768
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
769
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
770
+ shared_expert_weighted_sum: bool = False,
771
+ shared_activation_fn: Optional[Any] = None,
772
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773
+
774
+ # First, compute regular MoE forward pass
775
+ expert_out, expert_weights, router_scores = moe_forward(
776
+ x=x,
777
+ router_weight=router_weight,
778
+ moe_top_k=moe_top_k,
779
+ moe_num_experts=moe_num_experts,
780
+ moe_jitter_eps=moe_jitter_eps,
781
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
782
+ uniform_expert_assignment=uniform_expert_assignment,
783
+ training=training,
784
+ w1=w1,
785
+ w2=w2,
786
+ w1_bias=w1_bias,
787
+ w2_bias=w2_bias,
788
+ gradient_scale=gradient_scale,
789
+ alpha=alpha,
790
+ sort_end_bit=sort_end_bit,
791
+ expert_parallel_group=expert_parallel_group,
792
+ moe_capacity_factor=moe_capacity_factor,
793
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
794
+ forward_fn=forward_fn,
795
+ hidden_size=hidden_size,
796
+ mlp_impl=mlp_impl,
797
+ )
798
+
799
+ # If shared expert weights provided, compute shared expert output
800
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
+ shared_expert_out = shared_mlp_forward(
802
+ x=x,
803
+ up_proj_weight=shared_up_proj_weight,
804
+ down_proj_weight=shared_down_proj_weight,
805
+ up_proj_bias=shared_up_proj_bias,
806
+ down_proj_bias=shared_down_proj_bias,
807
+ activation_fn=shared_activation_fn,
808
+ gradient_scale=gradient_scale,
809
+ )
810
+
811
+ # Combine expert outputs
812
+ combined_out = combine_expert_shared_outputs(
813
+ shared_expert_out=shared_expert_out,
814
+ expert_out=expert_out,
815
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
816
+ moe_top_k=moe_top_k,
817
+ )
818
+
819
+ return combined_out, expert_weights, router_scores
820
+
821
+ # Return regular MoE output if no shared expert
822
+ return expert_out, expert_weights, router_scores
823
+
824
+
825
+ def create_shared_expert_weights(
826
+ hidden_size: int,
827
+ shared_expert_hidden_size: int,
828
+ device: torch.device,
829
+ dtype: torch.dtype,
830
+ init_method: Any,
831
+ output_layer_init_method: Any = None,
832
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
833
+
834
+ if output_layer_init_method is None:
835
+ output_layer_init_method = init_method
836
+
837
+ # Create weight tensors
838
+ up_proj_weight = torch.empty(
839
+ shared_expert_hidden_size,
840
+ hidden_size,
841
+ device=device,
842
+ dtype=dtype,
843
+ )
844
+ down_proj_weight = torch.empty(
845
+ hidden_size,
846
+ shared_expert_hidden_size,
847
+ device=device,
848
+ dtype=dtype,
849
+ )
850
+
851
+ # Initialize weights
852
+ init_method(up_proj_weight)
853
+ output_layer_init_method(down_proj_weight)
854
+
855
+ # No bias by default
856
+ return up_proj_weight, down_proj_weight, None, None
857
+
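Editorial sketch (sizes and init are arbitrary, not from the commit) showing the shapes returned by create_shared_expert_weights and its bias-free default:

import torch
from functools import partial

up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=1024,
    shared_expert_hidden_size=4096,
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=partial(torch.nn.init.normal_, mean=0.0, std=0.02),
)
assert up_w.shape == (4096, 1024)    # up projection: hidden -> shared hidden
assert down_w.shape == (1024, 4096)  # down projection: shared hidden -> hidden
assert up_b is None and down_b is None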
858
+ # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
+ # This exists because device_mesh is trapped in hook closures with no model attribute
860
+ # Fragile - breaks if hook structure changes or Python internals change
861
+ # TODO: Replace with a more robust solution when available
862
+ def get_device_mesh(model):
863
+ # Extract device_mesh from child's unused pre_hook closure
864
+ try:
865
+ # Find the pre-hook that contains 'device_mesh' in its closure
866
+ hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
867
+ # Extract the device_mesh from the closure
868
+ return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
869
+ except Exception:
870
+ return None
871
+
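A toy illustration (hypothetical module names, not from the commit) of what the closure inspection in get_device_mesh relies on; a real transformers tensor-parallel plan registers a comparable pre-hook on the experts sub-module:

import torch

class _Experts(torch.nn.Module):
    pass

class _ToyMoE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = _Experts()

def _make_pre_hook(device_mesh):
    def _pre_hook(module, args):
        _ = device_mesh  # captured in the closure, which get_device_mesh digs out
        return args
    return _pre_hook

model = _ToyMoE()
model.experts.register_forward_pre_hook(_make_pre_hook("stand-in-mesh"))
assert get_device_mesh(model) == "stand-in-mesh"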
872
+
873
  class MegaBlocksMoeMLP(torch.nn.Module):
874
 
875
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
+
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
886
+ if expert_parallel_group is None:
887
+ device_mesh = get_device_mesh(self)
888
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
+
890
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
 
 
916
  hidden_size=self.experts.hidden_size,
917
  mlp_impl=mlp_impl,
918
  )
919
+ return output, expert_weights_out
920
+
921
+
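For reference (editorial walk-through, not part of the commit): on a single process with no expert_parallel_group attribute and no tensor-parallel pre-hook to mine, the new fallback resolves to None and the dense forward_once path is taken.

import torch.distributed as dist

expert_parallel_group = None                       # no attribute on the module
device_mesh = None                                 # get_device_mesh finds no hook
expert_parallel_group = device_mesh.get_group() if device_mesh else None
has_parallel = (
    expert_parallel_group is not None
    and dist.is_initialized()
    and dist.get_world_size(expert_parallel_group) > 1
)
assert not has_parallel                            # -> forward_once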
922
+ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
+
924
+ def __init__(self):
925
+ super().__init__()
926
+ # Shared expert weights will be set by the user
927
+ self.shared_up_proj_weight = None
928
+ self.shared_down_proj_weight = None
929
+ self.shared_up_proj_bias = None
930
+ self.shared_down_proj_bias = None
931
+ self.shared_expert_weighted_sum = False
932
+ self.shared_activation_fn = None
933
+
934
+ def set_shared_expert_weights(
935
+ self,
936
+ up_proj_weight: torch.Tensor,
937
+ down_proj_weight: torch.Tensor,
938
+ up_proj_bias: Optional[torch.Tensor] = None,
939
+ down_proj_bias: Optional[torch.Tensor] = None,
940
+ weighted_sum: bool = False,
941
+ activation_fn: Optional[Any] = None,
942
+ ):
943
+ self.shared_up_proj_weight = up_proj_weight
944
+ self.shared_down_proj_weight = down_proj_weight
945
+ self.shared_up_proj_bias = up_proj_bias
946
+ self.shared_down_proj_bias = down_proj_bias
947
+ self.shared_expert_weighted_sum = weighted_sum
948
+ self.shared_activation_fn = activation_fn
949
+
950
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
951
+ moe_top_k = getattr(self.router, "top_k", 4)
952
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
953
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
954
+ alpha = getattr(self.experts, "alpha", 1.0)
955
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
958
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
+
960
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
961
+ if expert_parallel_group is None:
962
+ device_mesh = get_device_mesh(self)
963
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
+
965
+ has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
+ forward_fn = parallel_forward_once if has_parallel else forward_once
967
+
968
+ sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
969
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
970
+
971
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
+ x=x,
973
+ router_weight=self.router.weight,
974
+ moe_top_k=moe_top_k,
975
+ moe_num_experts=moe_num_experts,
976
+ moe_jitter_eps=moe_jitter_eps,
977
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
978
+ uniform_expert_assignment=uniform_expert_assignment,
979
+ training=self.training,
980
+ w1=self.experts.gate_up_proj,
981
+ w2=self.experts.down_proj,
982
+ w1_bias=self.experts.gate_up_proj_bias,
983
+ w2_bias=self.experts.down_proj_bias,
984
+ gradient_scale=gradient_scale,
985
+ alpha=alpha,
986
+ sort_end_bit=sort_end_bit,
987
+ expert_parallel_group=expert_parallel_group,
988
+ moe_capacity_factor=moe_capacity_factor,
989
+ moe_expert_model_parallelism=has_parallel,
990
+ forward_fn=forward_fn,
991
+ hidden_size=self.experts.hidden_size,
992
+ mlp_impl=mlp_impl,
993
+ # Shared expert parameters
994
+ shared_up_proj_weight=self.shared_up_proj_weight,
995
+ shared_down_proj_weight=self.shared_down_proj_weight,
996
+ shared_up_proj_bias=self.shared_up_proj_bias,
997
+ shared_down_proj_bias=self.shared_down_proj_bias,
998
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
+ shared_activation_fn=self.shared_activation_fn,
1000
+ )
1001
  return output, expert_weights_out
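A minimal wiring sketch for the new class (editorial, hypothetical sizes; the router/experts sub-modules normally come from the host transformers model, so forward is not called here):

import torch

mlp = MegaBlocksMoeMLPWithSharedExpert()
up_w, down_w, _, _ = create_shared_expert_weights(
    hidden_size=1024,
    shared_expert_hidden_size=4096,
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.xavier_uniform_,
)
mlp.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    weighted_sum=True,                       # average shared + routed outputs
    activation_fn=torch.nn.functional.silu,  # any elementwise activation works
)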
build/torch27-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:64e2fd33ed4a5e9497ad304763c3c174ade26702a8e43fe8e7b3d3e79eb1e021
3
  size 17892624
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37844f7b2972aae75a1eeb8cda3b573a93ef27dd5a73b2cfb95fca1f41da07d9
3
  size 17892624
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b071dec56af72c9e6b8408106b97fb42355b08e94cc1200bb6f4d3f42ba0e97e
3
- size 17892624
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_76c7de7
3
- ops = torch.ops._megablocks_76c7de7
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_76c7de7::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_89e2950
3
+ ops = torch.ops._megablocks_89e2950
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_89e2950::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
@@ -680,6 +740,136 @@ def moe_forward(
680
  return x, expert_weights, router_scores
681
 
682
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
691
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
692
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
693
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
694
-
695
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
696
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
697
  forward_fn = parallel_forward_once if has_parallel else forward_once
698
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
722
  hidden_size=self.experts.hidden_size,
723
  mlp_impl=mlp_impl,
724
  )
725
  return output, expert_weights_out
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
155
+ # Shared expert MLP forward pass
156
+ def shared_mlp_forward(
157
+ x: torch.Tensor,
158
+ up_proj_weight: torch.Tensor,
159
+ down_proj_weight: torch.Tensor,
160
+ up_proj_bias: Optional[torch.Tensor] = None,
161
+ down_proj_bias: Optional[torch.Tensor] = None,
162
+ activation_fn: Optional[Any] = None,
163
+ gradient_scale: Optional[float] = None,
164
+ ) -> torch.Tensor:
165
+ # Default activation function
166
+ if activation_fn is None:
167
+ activation_fn = torch.nn.functional.gelu
168
+
169
+ # Scale weights
170
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
171
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
172
+ if up_proj_bias is not None:
173
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
174
+ if down_proj_bias is not None:
175
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
176
+
177
+ # Resolve dtensors
178
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
179
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
180
+ if up_proj_bias is not None:
181
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
182
+ if down_proj_bias is not None:
183
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
184
+
185
+ # Up projection
186
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
+
188
+ # Activation
189
+ x = activation_fn(x)
190
+
191
+ # Down projection
192
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
+
194
+ return x
195
+
196
+
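As a reference point (editorial sketch, assuming scale_grad and resolve_dtensor are pass-throughs for plain tensors with gradient_scale=None): with no biases and the default activation, shared_mlp_forward reduces to a standard gelu MLP.

import torch
import torch.nn.functional as F

x = torch.randn(3, 8)
up_w = torch.randn(16, 8)
down_w = torch.randn(8, 16)

expected = F.linear(F.gelu(F.linear(x, up_w)), down_w)
actual = shared_mlp_forward(x, up_w, down_w)  # defaults: gelu, no bias, no scaling
assert torch.allclose(actual, expected, atol=1e-6)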
197
+ # Combine outputs from shared expert and regular experts
198
+ def combine_expert_shared_outputs(
199
+ shared_expert_out: torch.Tensor,
200
+ expert_out: torch.Tensor,
201
+ shared_expert_weighted_sum: bool = False,
202
+ moe_top_k: int = 1,
203
+ ) -> torch.Tensor:
204
+ if shared_expert_weighted_sum:
205
+ # Weighted sum based on number of experts used
206
+ total_experts = moe_top_k + 1
207
+ shared_weight = 1.0 / total_experts
208
+ expert_weight = moe_top_k / total_experts
209
+ return shared_expert_out * shared_weight + expert_out * expert_weight
210
+ else:
211
+ # Simple addition
212
+ return shared_expert_out + expert_out
213
+
214
+
215
  # Global variable to store load balancing loss
216
  _LOAD_BALANCING_LOSS = []
217
 
 
740
  return x, expert_weights, router_scores
741
 
742
 
743
+ def moe_forward_with_shared_expert(
744
+ x: torch.Tensor,
745
+ router_weight: torch.Tensor,
746
+ moe_top_k: int,
747
+ moe_num_experts: int,
748
+ moe_jitter_eps: float = None,
749
+ moe_normalize_expert_weights: int = None,
750
+ uniform_expert_assignment: bool = False,
751
+ training: bool = False,
752
+ w1: torch.Tensor = None,
753
+ w2: torch.Tensor = None,
754
+ w1_bias: torch.Tensor = None,
755
+ w2_bias: torch.Tensor = None,
756
+ gradient_scale: Optional[float] = None,
757
+ alpha: float = 1.702,
758
+ sort_end_bit: int = 0,
759
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
760
+ moe_capacity_factor: float = 1.0,
761
+ moe_expert_model_parallelism: bool = False,
762
+ forward_fn: Any = None,
763
+ hidden_size: int = None,
764
+ mlp_impl: str = "grouped",
765
+ # Shared expert parameters
766
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
767
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
768
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
769
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
770
+ shared_expert_weighted_sum: bool = False,
771
+ shared_activation_fn: Optional[Any] = None,
772
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773
+
774
+ # First, compute regular MoE forward pass
775
+ expert_out, expert_weights, router_scores = moe_forward(
776
+ x=x,
777
+ router_weight=router_weight,
778
+ moe_top_k=moe_top_k,
779
+ moe_num_experts=moe_num_experts,
780
+ moe_jitter_eps=moe_jitter_eps,
781
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
782
+ uniform_expert_assignment=uniform_expert_assignment,
783
+ training=training,
784
+ w1=w1,
785
+ w2=w2,
786
+ w1_bias=w1_bias,
787
+ w2_bias=w2_bias,
788
+ gradient_scale=gradient_scale,
789
+ alpha=alpha,
790
+ sort_end_bit=sort_end_bit,
791
+ expert_parallel_group=expert_parallel_group,
792
+ moe_capacity_factor=moe_capacity_factor,
793
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
794
+ forward_fn=forward_fn,
795
+ hidden_size=hidden_size,
796
+ mlp_impl=mlp_impl,
797
+ )
798
+
799
+ # If shared expert weights provided, compute shared expert output
800
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
+ shared_expert_out = shared_mlp_forward(
802
+ x=x,
803
+ up_proj_weight=shared_up_proj_weight,
804
+ down_proj_weight=shared_down_proj_weight,
805
+ up_proj_bias=shared_up_proj_bias,
806
+ down_proj_bias=shared_down_proj_bias,
807
+ activation_fn=shared_activation_fn,
808
+ gradient_scale=gradient_scale,
809
+ )
810
+
811
+ # Combine expert outputs
812
+ combined_out = combine_expert_shared_outputs(
813
+ shared_expert_out=shared_expert_out,
814
+ expert_out=expert_out,
815
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
816
+ moe_top_k=moe_top_k,
817
+ )
818
+
819
+ return combined_out, expert_weights, router_scores
820
+
821
+ # Return regular MoE output if no shared expert
822
+ return expert_out, expert_weights, router_scores
823
+
824
+
825
+ def create_shared_expert_weights(
826
+ hidden_size: int,
827
+ shared_expert_hidden_size: int,
828
+ device: torch.device,
829
+ dtype: torch.dtype,
830
+ init_method: Any,
831
+ output_layer_init_method: Any = None,
832
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
833
+
834
+ if output_layer_init_method is None:
835
+ output_layer_init_method = init_method
836
+
837
+ # Create weight tensors
838
+ up_proj_weight = torch.empty(
839
+ shared_expert_hidden_size,
840
+ hidden_size,
841
+ device=device,
842
+ dtype=dtype,
843
+ )
844
+ down_proj_weight = torch.empty(
845
+ hidden_size,
846
+ shared_expert_hidden_size,
847
+ device=device,
848
+ dtype=dtype,
849
+ )
850
+
851
+ # Initialize weights
852
+ init_method(up_proj_weight)
853
+ output_layer_init_method(down_proj_weight)
854
+
855
+ # No bias by default
856
+ return up_proj_weight, down_proj_weight, None, None
857
+
858
+ # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
+ # This exists because device_mesh is trapped in hook closures with no model attribute
860
+ # Fragile - breaks if hook structure changes or Python internals change
861
+ # TODO: Replace with a more robust solution when available
862
+ def get_device_mesh(model):
863
+ # Extract device_mesh from child's unused pre_hook closure
864
+ try:
865
+ # Find the pre-hook that contains 'device_mesh' in its closure
866
+ hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
867
+ # Extract the device_mesh from the closure
868
+ return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
869
+ except Exception:
870
+ return None
871
+
872
+
873
  class MegaBlocksMoeMLP(torch.nn.Module):
874
 
875
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
  moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
+
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
886
+ if expert_parallel_group is None:
887
+ device_mesh = get_device_mesh(self)
888
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
+
890
  has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
 
 
916
  hidden_size=self.experts.hidden_size,
917
  mlp_impl=mlp_impl,
918
  )
919
+ return output, expert_weights_out
920
+
921
+
922
+ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
+
924
+ def __init__(self):
925
+ super().__init__()
926
+ # Shared expert weights will be set by the user
927
+ self.shared_up_proj_weight = None
928
+ self.shared_down_proj_weight = None
929
+ self.shared_up_proj_bias = None
930
+ self.shared_down_proj_bias = None
931
+ self.shared_expert_weighted_sum = False
932
+ self.shared_activation_fn = None
933
+
934
+ def set_shared_expert_weights(
935
+ self,
936
+ up_proj_weight: torch.Tensor,
937
+ down_proj_weight: torch.Tensor,
938
+ up_proj_bias: Optional[torch.Tensor] = None,
939
+ down_proj_bias: Optional[torch.Tensor] = None,
940
+ weighted_sum: bool = False,
941
+ activation_fn: Optional[Any] = None,
942
+ ):
943
+ self.shared_up_proj_weight = up_proj_weight
944
+ self.shared_down_proj_weight = down_proj_weight
945
+ self.shared_up_proj_bias = up_proj_bias
946
+ self.shared_down_proj_bias = down_proj_bias
947
+ self.shared_expert_weighted_sum = weighted_sum
948
+ self.shared_activation_fn = activation_fn
949
+
950
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
951
+ moe_top_k = getattr(self.router, "top_k", 4)
952
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
953
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
954
+ alpha = getattr(self.experts, "alpha", 1.0)
955
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
958
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
+
960
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
961
+ if expert_parallel_group is None:
962
+ device_mesh = get_device_mesh(self)
963
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
+
965
+ has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
+ forward_fn = parallel_forward_once if has_parallel else forward_once
967
+
968
+ sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
969
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
970
+
971
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
+ x=x,
973
+ router_weight=self.router.weight,
974
+ moe_top_k=moe_top_k,
975
+ moe_num_experts=moe_num_experts,
976
+ moe_jitter_eps=moe_jitter_eps,
977
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
978
+ uniform_expert_assignment=uniform_expert_assignment,
979
+ training=self.training,
980
+ w1=self.experts.gate_up_proj,
981
+ w2=self.experts.down_proj,
982
+ w1_bias=self.experts.gate_up_proj_bias,
983
+ w2_bias=self.experts.down_proj_bias,
984
+ gradient_scale=gradient_scale,
985
+ alpha=alpha,
986
+ sort_end_bit=sort_end_bit,
987
+ expert_parallel_group=expert_parallel_group,
988
+ moe_capacity_factor=moe_capacity_factor,
989
+ moe_expert_model_parallelism=has_parallel,
990
+ forward_fn=forward_fn,
991
+ hidden_size=self.experts.hidden_size,
992
+ mlp_impl=mlp_impl,
993
+ # Shared expert parameters
994
+ shared_up_proj_weight=self.shared_up_proj_weight,
995
+ shared_down_proj_weight=self.shared_down_proj_weight,
996
+ shared_up_proj_bias=self.shared_up_proj_bias,
997
+ shared_down_proj_bias=self.shared_down_proj_bias,
998
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
+ shared_activation_fn=self.shared_activation_fn,
1000
+ )
1001
  return output, expert_weights_out