drbh committed
Commit 3bdb4b8 · 1 Parent(s): 89e2950
feat: bump build for shared experts
Browse files
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_76c7de7.abi3.so +0 -3
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/{_megablocks_9a1816c.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py +277 -1
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py +277 -1
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py +277 -1
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py +277 -1
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:070067fec0e735e865610caf4fc33b384fe8c9c47a002c365f740c82c5af1bab
 size 10517576
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4e4c48e189572141f6a140dd83f9eca19eaebbc20c5cd686aa0263aafec14533
-size 10517576
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
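The _ops.py change above only rebinds the Python-side op namespace to the freshly built extension. As a quick illustration (not part of the diff), the helper can be used to form fully qualified custom-op names; the import path assumes the package is importable as `megablocks` from one of the build directories above, and the op name "sort" is only a placeholder:

# Hypothetical usage sketch for the namespace helper shown above.
from megablocks._ops import ops, add_op_namespace_prefix

# Builds a fully qualified op name such as "_megablocks_89e2950::sort".
# "sort" is an illustrative op name, not necessarily one exported by this build.
qualified_name = add_op_namespace_prefix("sort")

# The same op would be reachable through the bound torch.ops namespace:
# op_handle = getattr(ops, "sort")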
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
         return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
+        return output, expert_weights_out
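The layers.py additions above wire an optional shared expert alongside the routed experts: shared_mlp_forward runs a dense up/down projection over every token, and combine_expert_shared_outputs either adds its output to the routed output or, with weighted summing enabled, mixes them as 1/(top_k+1) versus top_k/(top_k+1). A minimal usage sketch of the new wrapper class follows; it assumes the module's `router` and `experts` submodules are already attached exactly as the base MegaBlocksMoeMLP expects, that the package is importable as `megablocks`, and the sizes, device, and init function below are illustrative only.

import torch
from megablocks.layers import (
    MegaBlocksMoeMLPWithSharedExpert,
    create_shared_expert_weights,
)

hidden_size = 1024          # illustrative model width
shared_ffn_size = 4096      # illustrative shared-expert FFN width

mlp = MegaBlocksMoeMLPWithSharedExpert()
# ... router/experts submodules are assumed to be attached here,
# just as they would be for the base MegaBlocksMoeMLP ...

# Allocate and initialize the shared-expert projection weights.
up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=hidden_size,
    shared_expert_hidden_size=shared_ffn_size,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    init_method=torch.nn.init.xavier_uniform_,
)

# Register them on the wrapper; weighted_sum=True blends the shared expert
# with the routed experts as 1/(top_k + 1) vs. top_k/(top_k + 1).
mlp.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    up_proj_bias=up_b,
    down_proj_bias=down_b,
    weighted_sum=True,
    activation_fn=torch.nn.functional.silu,  # defaults to gelu when omitted
)

# x: (batch, seq, hidden_size); forward returns (output, expert_weights).
# output, expert_weights = mlp(x)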
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:02dffd561ef226c1ec17c99e462c3c771879f078dde9b1e5cd8bd5992be5b3da
 size 11869392
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3d958a0c77589a5ede72336d1cab80ea9d6324ef6f8a9a187af2da4db74e1894
-size 11869392
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py
CHANGED
Identical changes to those shown above for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py (+277 -1).
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:b5aa4e066ddbd863693ca8a5ec37fba34996226442dfa407e4a49b779497001d
 size 11931048
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d41a4f5bbc160f51b058d3ba36e9087e9f15d35ae4782f36c984dd7199ee8ede
-size 11931048
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py
CHANGED
Identical changes to those shown above for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py (+277 -1).
build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:fababa7e0d2c20c98afaebef6165a8145b33d80cdadba28f895c14dd2a7b2823
 size 10510040
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:01f0c774e900380d3c0721dfe15591c67be5d5eb5ad687af6c89a88ecdff4f2a
-size 10510040
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
+        return output, expert_weights_out
         return output, expert_weights_out
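Aside (illustrative only, not code from this repo): a self-contained sketch of the combination rule that the new combine_expert_shared_outputs implements, with the function re-declared locally and arbitrary tensor values so it runs without megablocks. With weighted_sum enabled, the shared expert counts as one extra expert next to the moe_top_k routed experts; otherwise the two outputs are simply added.

import torch

def combine(shared_out: torch.Tensor, expert_out: torch.Tensor,
            weighted_sum: bool = False, moe_top_k: int = 1) -> torch.Tensor:
    # Same rule as combine_expert_shared_outputs in the diff above.
    if weighted_sum:
        total_experts = moe_top_k + 1
        return shared_out * (1.0 / total_experts) + expert_out * (moe_top_k / total_experts)
    return shared_out + expert_out

shared = torch.ones(2, 4)
routed = torch.full((2, 4), 3.0)
print(combine(shared, routed))                                  # all 4.0
print(combine(shared, routed, weighted_sum=True, moe_top_k=2))  # all (1 + 2*3) / 3, i.e. about 2.33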
build/torch26-cxx98-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9e3663f46030f07e030efe94c26495d17b2703551a46c0ca3acf8b25ecb2a238
 size 11857920
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:09a5f57ae37af9f5b14c4a0f21d1679e32f5b7424973c36dac9bbbecbfbf7374
-size 11857920
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
@@ -680,6 +740,136 @@ def moe_forward(
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
(Same +277/-1 change as the layers.py diff reconstructed in full above for build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py: adds shared_mlp_forward, combine_expert_shared_outputs, moe_forward_with_shared_expert, create_shared_expert_weights, and get_device_mesh, makes MegaBlocksMoeMLP.forward fall back to get_device_mesh when no expert_parallel_group is set, and adds the MegaBlocksMoeMLPWithSharedExpert class.)
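Aside (illustrative only): with gradient scaling, dtensor resolution, and biases stripped away, the new shared_mlp_forward reduces to an up projection, the default GELU activation, and a down projection. The sizes and the xavier_uniform_ initializer below are arbitrary choices for this sketch, not values taken from this commit.

import torch
import torch.nn.functional as F

hidden_size, shared_expert_hidden_size = 8, 16

# Weight shapes follow create_shared_expert_weights: (out_features, in_features).
up_proj_weight = torch.empty(shared_expert_hidden_size, hidden_size)
down_proj_weight = torch.empty(hidden_size, shared_expert_hidden_size)
torch.nn.init.xavier_uniform_(up_proj_weight)
torch.nn.init.xavier_uniform_(down_proj_weight)

x = torch.randn(2, hidden_size)
# Up projection, activation, down projection, as in shared_mlp_forward.
y = F.linear(F.gelu(F.linear(x, up_proj_weight)), down_proj_weight)
print(y.shape)  # torch.Size([2, 8]): same hidden size as the input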
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_76c7de7.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a3f893773ec7b8157a4531a57821807f5f27ac48ceaa695c342cc7a39ad318dc
-size 11927768
build/torch26-cxx98-cu126-x86_64-linux/megablocks/{_megablocks_9a1816c.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:d1571732c5954914d5ddf0f12ebc4074d88d907130d71d898de43958e3b9a5d1
 size 11923672
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
    """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
@@ -680,6 +740,136 @@ def moe_forward(
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
(Same +277/-1 change as the layers.py diff reconstructed in full above: adds the shared expert helpers and MegaBlocksMoeMLPWithSharedExpert, and wires the get_device_mesh fallback into MegaBlocksMoeMLP.forward.)
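Aside (illustrative only, not megablocks code): the get_device_mesh helper added in layers.py relies on CPython closure introspection. The toy below shows the same mechanism on a hand-built hook: a value captured in a closure can be read back through __code__.co_freevars and __closure__, which is exactly why the committed helper is flagged as fragile.

def make_pre_hook(device_mesh):
    def hook(module, args):
        # Reference device_mesh so it becomes a free variable captured in the closure.
        _ = device_mesh
        return None
    return hook

hook = make_pre_hook(device_mesh={"name": "toy-mesh"})
idx = hook.__code__.co_freevars.index("device_mesh")
print(hook.__closure__[idx].cell_contents)  # {'name': 'toy-mesh'}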
build/torch27-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:a39b315c5359b79a67282160b5b344853aa06b5a5c9d8efafb903eb4f249b645
 size 10517816
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:002c2687dbc5693308fe32eaebe2f45ed3c85454fd45bc06d7b30e9c1a6d8949
-size 10517816
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
@@ -680,6 +740,136 @@ def moe_forward(
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
(Same +277/-1 change as the layers.py diff reconstructed in full above: adds the shared expert helpers and MegaBlocksMoeMLPWithSharedExpert, and wires the get_device_mesh fallback into MegaBlocksMoeMLP.forward.)
w2_bias=self.experts.down_proj_bias,
|
984 |
+
gradient_scale=gradient_scale,
|
985 |
+
alpha=alpha,
|
986 |
+
sort_end_bit=sort_end_bit,
|
987 |
+
expert_parallel_group=expert_parallel_group,
|
988 |
+
moe_capacity_factor=moe_capacity_factor,
|
989 |
+
moe_expert_model_parallelism=has_parallel,
|
990 |
+
forward_fn=forward_fn,
|
991 |
+
hidden_size=self.experts.hidden_size,
|
992 |
+
mlp_impl=mlp_impl,
|
993 |
+
# Shared expert parameters
|
994 |
+
shared_up_proj_weight=self.shared_up_proj_weight,
|
995 |
+
shared_down_proj_weight=self.shared_down_proj_weight,
|
996 |
+
shared_up_proj_bias=self.shared_up_proj_bias,
|
997 |
+
shared_down_proj_bias=self.shared_down_proj_bias,
|
998 |
+
shared_expert_weighted_sum=self.shared_expert_weighted_sum,
|
999 |
+
shared_activation_fn=self.shared_activation_fn,
|
1000 |
+
)
|
1001 |
return output, expert_weights_out
|
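For reference, a minimal usage sketch of the shared-expert path added above. This is not part of the diff: the import path mirrors the layers.py location, the sizes, device, dtype, init function, and activation are illustrative assumptions, and the usual router/experts wiring of MegaBlocksMoeMLP is elided.

import torch
from megablocks.layers import (
    MegaBlocksMoeMLPWithSharedExpert,
    create_shared_expert_weights,
)

# Hypothetical sizes, chosen only for illustration.
hidden_size, shared_hidden_size = 1024, 4096

moe = MegaBlocksMoeMLPWithSharedExpert()
# ... attach moe.router and moe.experts exactly as for MegaBlocksMoeMLP ...

# Allocate and initialize the shared-expert projections (biases stay None).
up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=hidden_size,
    shared_expert_hidden_size=shared_hidden_size,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    init_method=torch.nn.init.xavier_uniform_,
)

# weighted_sum=True blends the shared expert with the routed experts as
# 1 / (top_k + 1) vs. top_k / (top_k + 1); False simply adds the two outputs.
moe.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    up_proj_bias=up_b,
    down_proj_bias=down_b,
    weighted_sum=True,
    activation_fn=torch.nn.functional.silu,
)

# output, expert_weights = moe(x)   # x: [..., hidden_size]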
build/torch27-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:4870e4a9a831c30c7177b9b23b2b20d64f47242f16d818be1884b4e130e063c1
 size 11931080
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ef9197ea269734d4e0528887ab3c353fa8ba10ccf9a82c9abe85b72bc0ea3553
-size 11931080
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
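A small sketch of how the regenerated ops namespace is typically used. The op name below is hypothetical and shown only to illustrate the prefixing; the real kernels live in the compiled _megablocks_89e2950 extension.

from megablocks._ops import ops, add_op_namespace_prefix

# add_op_namespace_prefix qualifies an op name with the build-specific
# namespace, e.g. "_megablocks_89e2950::exclusive_cumsum".
qualified_name = add_op_namespace_prefix("exclusive_cumsum")

# `ops` is an alias for torch.ops._megablocks_89e2950, so kernels are
# invoked as ops.<op_name>(...) once the shared object is loaded.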
build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
        )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
build/torch27-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:37844f7b2972aae75a1eeb8cda3b573a93ef27dd5a73b2cfb95fca1f41da07d9
 size 17892624
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b071dec56af72c9e6b8408106b97fb42355b08e94cc1200bb6f4d3f42ba0e97e
-size 17892624
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
@@ -680,6 +740,136 @@ def moe_forward(
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
(identical to the build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py changes above)