import tensorflow as tf
from tensorflow.keras import layers, Model

from dataclasses import dataclass
from typing import List, Optional
import json
import logging
import math

# Tokenizer support is optional: guard the imports so the model still works
# when the `tokenizers`/`transformers` libraries are not installed.
try:
    from tokenizers import ByteLevelBPETokenizer
    from transformers import PreTrainedTokenizerFast
    HAS_TRANSFORMERS = True
except ImportError:
    print("Warning: tokenizers/transformers not found. Tokenizer features will be disabled.")
    HAS_TRANSFORMERS = False

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


@dataclass
class MoEConfig:
    """Configuration for the MoE MiniGPT model."""
    vocab_size: int = 10000
    max_seq_len: int = 256
    embed_dim: int = 512
    num_heads: int = 8
    num_layers: int = 8
    ffn_dim: int = 2048
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-5
    use_rotary_embeddings: bool = True
    learning_rate: float = 2e-4
    batch_size: int = 32
    seq_len: int = 256

    # MoE-specific settings.
    num_experts: int = 4
    top_k_experts: int = 1
    expert_capacity: Optional[int] = None
    load_balancing_loss_weight: float = 0.01
    router_z_loss_weight: float = 0.001
    use_moe_layers: Optional[List[int]] = None

    def __post_init__(self):
        if self.use_moe_layers is None:
            # Default: use MoE in every other layer, starting at layer 2.
            self.use_moe_layers = list(range(2, self.num_layers, 2))
        if self.expert_capacity is None:
            # Tokens per expert per batch, assuming an even split across experts.
            self.expert_capacity = max(4, (self.seq_len * self.batch_size * self.top_k_experts) // self.num_experts)
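

# A minimal usage sketch (the values asserted here are just the defaults
# above, shown for illustration): __post_init__ derives both the MoE layer
# placement and the per-expert token capacity.
def _demo_config():
    cfg = MoEConfig()
    assert cfg.use_moe_layers == [2, 4, 6]             # every other layer from 2
    assert cfg.expert_capacity == (256 * 32 * 1) // 4  # = 2048 tokens per expert
    return cfg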


class RotaryPositionalEmbedding(layers.Layer):
    """Rotary Positional Embedding layer with proper weight management."""

    def __init__(self, dim=512, max_seq_len=256, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_seq_len = max_seq_len

    def build(self, input_shape):
        super().build(input_shape)
        # Precompute the sin/cos tables once, using the standard inverse-
        # frequency schedule (base 10000), and store them as non-trainable
        # weights so they are tracked for saving and restoring.
        position = tf.cast(tf.range(self.max_seq_len, dtype=tf.int32), tf.float32)
        div_term = tf.exp(tf.range(0, self.dim, 2, dtype=tf.float32) * -(tf.math.log(10000.0) / self.dim))
        pos = tf.expand_dims(position, 1) * tf.expand_dims(div_term, 0)  # [max_seq_len, dim // 2]

        self.sin_vals = self.add_weight(
            name='sin_vals',
            shape=(self.max_seq_len, self.dim // 2),
            initializer='zeros',
            trainable=False
        )
        self.cos_vals = self.add_weight(
            name='cos_vals',
            shape=(self.max_seq_len, self.dim // 2),
            initializer='zeros',
            trainable=False
        )
        self.sin_vals.assign(tf.sin(pos))
        self.cos_vals.assign(tf.cos(pos))

    def _apply_rotary_pos_emb(self, x, sin, cos):
        """Apply rotary positional embeddings to x: [batch, heads, seq, head_dim]."""
        seq_len = tf.shape(x)[2]
        sin = sin[:seq_len]
        cos = cos[:seq_len]
        sin = tf.expand_dims(tf.expand_dims(sin, 0), 0)  # [1, 1, seq, head_dim // 2]
        cos = tf.expand_dims(tf.expand_dims(cos, 0), 0)

        # Split-halves rotation (GPT-NeoX style): rotate the two halves of the
        # head dimension against each other.
        x1, x2 = tf.split(x, 2, axis=-1)
        rotated_x1 = cos * x1 - sin * x2
        rotated_x2 = sin * x1 + cos * x2
        return tf.concat([rotated_x1, rotated_x2], axis=-1)

    def call(self, x):
        return self._apply_rotary_pos_emb(x, self.sin_vals, self.cos_vals)

    def get_config(self):
        config = super().get_config()
        config.update({
            'dim': self.dim,
            'max_seq_len': self.max_seq_len
        })
        return config
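

# A small sanity-check sketch for the rotary embedding (the toy shapes are
# assumptions for illustration): the rotation mixes the two halves of each
# head vector position-by-position but preserves its norm, since each paired
# (x1_i, x2_i) undergoes a plain 2-D rotation.
def _demo_rotary():
    rope = RotaryPositionalEmbedding(dim=8, max_seq_len=16)
    x = tf.random.normal((1, 2, 4, 8))  # [batch, heads, seq, head_dim]
    y = rope(x)
    tf.debugging.assert_near(tf.norm(x, axis=-1), tf.norm(y, axis=-1), atol=1e-4)
    return y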


class Expert(layers.Layer):
    """Individual expert network (FFN)."""

    def __init__(self, embed_dim=512, ffn_dim=2048, dropout=0.1, expert_id=0, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.expert_id = expert_id

        self.dense1 = layers.Dense(ffn_dim, activation='gelu', name=f'expert_{expert_id}_dense1')
        self.dense2 = layers.Dense(embed_dim, name=f'expert_{expert_id}_dense2')
        self.dropout = layers.Dropout(dropout)

    def call(self, x, training=False):
        x = self.dense1(x)
        x = self.dropout(x, training=training)
        x = self.dense2(x)
        return x


class Router(layers.Layer):
    """Router network for MoE that decides which experts to use."""

    def __init__(self, embed_dim=512, num_experts=4, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_experts = num_experts

        self.router_weights = layers.Dense(
            num_experts,
            use_bias=False,
            kernel_initializer='truncated_normal',
            name='router_weights'
        )

    def call(self, x):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
        Returns:
            router_logits: [batch_size, seq_len, num_experts]
        """
        router_logits = self.router_weights(x)
        return router_logits
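

# A routing sketch on toy inputs (shapes are illustrative assumptions): the
# router produces one logit per expert per token; softmax followed by top_k
# then yields the expert assignment used by the MoE layer below.
def _demo_router():
    router = Router(embed_dim=8, num_experts=4)
    x = tf.random.normal((2, 5, 8))                 # [batch, seq, embed_dim]
    logits = router(x)                              # [2, 5, 4]
    probs = tf.nn.softmax(logits, axis=-1)
    top_probs, top_idx = tf.nn.top_k(probs, k=1)    # winning expert per token
    return top_idx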


class MixtureOfExperts(layers.Layer):
    """Mixture of Experts layer with efficient routing and load balancing."""

    def __init__(self, embed_dim=512, ffn_dim=2048, num_experts=4, top_k=1,
                 expert_capacity=None, dropout=0.1, load_balancing_loss_weight=0.01,
                 router_z_loss_weight=0.001, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.top_k = top_k
        # Note: the capacity is stored for reporting but is not enforced by
        # this simple masking-based implementation.
        self.expert_capacity = expert_capacity or 32
        self.load_balancing_loss_weight = load_balancing_loss_weight
        self.router_z_loss_weight = router_z_loss_weight

        self.router = Router(embed_dim, num_experts, name='router')
        self.experts = [
            Expert(embed_dim, ffn_dim, dropout, expert_id=i, name=f'expert_{i}')
            for i in range(num_experts)
        ]

    def _compute_routing_probabilities(self, router_logits):
        """Compute routing probabilities and top-k expert assignments."""
        router_probs = tf.nn.softmax(router_logits, axis=-1)
        top_k_probs, top_k_indices = tf.nn.top_k(router_probs, k=self.top_k)
        # Renormalize so the selected experts' weights sum to 1 per token.
        top_k_probs = top_k_probs / (tf.reduce_sum(top_k_probs, axis=-1, keepdims=True) + 1e-8)
        return router_probs, top_k_probs, top_k_indices

    def _compute_auxiliary_losses(self, router_logits, router_probs):
        """Compute load balancing and router z-loss."""
        # Mean routing probability per expert, averaged over batch and sequence.
        expert_usage = tf.reduce_mean(router_probs, axis=[0, 1])
        # Squared coefficient of variation: penalizes uneven expert usage.
        mean_usage = tf.reduce_mean(expert_usage)
        variance_usage = tf.reduce_mean(tf.square(expert_usage - mean_usage))
        load_balancing_loss = variance_usage / (mean_usage * mean_usage + 1e-8)
        load_balancing_loss = load_balancing_loss * tf.cast(self.num_experts, tf.float32)
        # Z-loss keeps router logits small and the softmax well-conditioned.
        router_z_loss = tf.reduce_mean(tf.square(tf.reduce_logsumexp(router_logits, axis=-1)))
        return load_balancing_loss, router_z_loss

    # Note: `call` is intentionally not wrapped in @tf.function; Keras traces
    # layer calls itself, and an extra wrapper can cause retracing issues.
    def call(self, x, training=False):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
        Returns:
            output: [batch_size, seq_len, embed_dim]
            aux_losses: dict with auxiliary losses
        """
        router_logits = self.router(x)
        router_probs, top_k_probs, top_k_indices = self._compute_routing_probabilities(router_logits)
        load_balancing_loss, router_z_loss = self._compute_auxiliary_losses(router_logits, router_probs)

        output = tf.zeros_like(x)

        # Dense dispatch: every expert sees the full input, and its output is
        # masked down to the tokens routed to it. Simple and correct, but it
        # does not realize the compute savings of capacity-based dispatch.
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, :, k]
            expert_weights = top_k_probs[:, :, k]

            for expert_idx in range(self.num_experts):
                expert_mask = tf.equal(expert_indices, expert_idx)

                def apply_expert():
                    expert_output = self.experts[expert_idx](x, training=training)
                    expert_mask_expanded = tf.expand_dims(tf.cast(expert_mask, tf.float32), -1)
                    weights_expanded = tf.expand_dims(expert_weights, -1)
                    return expert_output * expert_mask_expanded * weights_expanded

                def skip_expert():
                    return tf.zeros_like(x)

                # Skip the expert entirely when no token routes to it.
                has_tokens = tf.reduce_any(expert_mask)
                expert_contribution = tf.cond(has_tokens, apply_expert, skip_expert)
                output += expert_contribution

        aux_losses = {
            'load_balancing_loss': load_balancing_loss * self.load_balancing_loss_weight,
            'router_z_loss': router_z_loss * self.router_z_loss_weight
        }
        return output, aux_losses
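

# A minimal sketch of calling the MoE layer directly (toy sizes assumed): it
# returns the combined expert output plus the weighted auxiliary losses, which
# the enclosing model later adds to the language-modeling loss.
def _demo_moe_layer():
    moe = MixtureOfExperts(embed_dim=8, ffn_dim=16, num_experts=2, top_k=1)
    x = tf.random.normal((2, 5, 8))
    output, aux = moe(x, training=False)
    # output: [2, 5, 8]; aux: {'load_balancing_loss': ..., 'router_z_loss': ...}
    return output, aux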


class MultiHeadAttention(layers.Layer):
    """Multi-head attention layer with optional rotary embeddings."""

    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1, max_seq_len=256,
                 use_rotary_embeddings=True, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.use_rotary_embeddings = use_rotary_embeddings

        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}")

        self.head_dim = embed_dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.query_proj = layers.Dense(embed_dim, use_bias=False, name='query_proj')
        self.key_proj = layers.Dense(embed_dim, use_bias=False, name='key_proj')
        self.value_proj = layers.Dense(embed_dim, use_bias=False, name='value_proj')
        self.output_proj = layers.Dense(embed_dim, use_bias=False, name='output_proj')

        if use_rotary_embeddings:
            self.rope = RotaryPositionalEmbedding(
                dim=self.head_dim,
                max_seq_len=max_seq_len,
                name='rope'
            )

        self.dropout = layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]

        q = self.query_proj(x)
        k = self.key_proj(x)
        v = self.value_proj(x)

        # [batch, seq, embed] -> [batch, heads, seq, head_dim]
        q = tf.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
        k = tf.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
        v = tf.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim))
        q = tf.transpose(q, perm=[0, 2, 1, 3])
        k = tf.transpose(k, perm=[0, 2, 1, 3])
        v = tf.transpose(v, perm=[0, 2, 1, 3])

        if self.use_rotary_embeddings:
            q = self.rope(q)
            k = self.rope(k)

        scores = tf.matmul(q, k, transpose_b=True) * self.scale

        # Causal mask: attention to future positions is pushed to -1e9 so it
        # vanishes under the softmax.
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), dtype=tf.float32), -1, 0)
        causal_mask = tf.where(tf.equal(causal_mask, 0), -1e9, 0.0)
        causal_mask = tf.expand_dims(tf.expand_dims(causal_mask, 0), 0)
        scores = scores + causal_mask

        # Optional padding mask: 1 for real tokens, 0 for padding.
        if mask is not None:
            mask = tf.expand_dims(tf.expand_dims(mask, 1), 1)
            mask = tf.cast(mask, scores.dtype)
            scores = scores + (1.0 - mask) * -1e9

        attention_weights = tf.nn.softmax(scores, axis=-1)
        attention_weights = self.dropout(attention_weights, training=training)

        attention_output = tf.matmul(attention_weights, v)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        attention_output = tf.reshape(attention_output, (batch_size, seq_len, self.embed_dim))

        output = self.output_proj(attention_output)
        return output
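

# A causal-masking sketch (toy sizes assumed): because future positions are
# masked before the softmax, the output at position t cannot depend on any
# token after t. Perturbing only the last token must therefore leave all
# earlier outputs unchanged.
def _demo_causal_attention():
    attn = MultiHeadAttention(embed_dim=8, num_heads=2, max_seq_len=16)
    x = tf.random.normal((1, 6, 8))
    y_full = attn(x)
    x_perturbed = tf.concat([x[:, :-1], tf.random.normal((1, 1, 8))], axis=1)
    y_pert = attn(x_perturbed)
    tf.debugging.assert_near(y_full[:, :-1], y_pert[:, :-1], atol=1e-4)
    return y_full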


class FeedForward(layers.Layer):
    """Standard feed-forward network (used in non-MoE layers)."""

    def __init__(self, embed_dim=512, ffn_dim=2048, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim

        self.dense1 = layers.Dense(ffn_dim, activation='gelu', name='dense1')
        self.dense2 = layers.Dense(embed_dim, name='dense2')
        self.dropout = layers.Dropout(dropout)

    def call(self, x, training=False):
        x = self.dense1(x)
        x = self.dropout(x, training=training)
        x = self.dense2(x)
        return x


class MoETransformerBlock(layers.Layer):
    """Transformer block with optional MoE FFN."""

    def __init__(self, embed_dim=512, num_heads=8, ffn_dim=2048, dropout=0.1,
                 layer_norm_epsilon=1e-5, max_seq_len=256, block_id=0,
                 use_moe=False, num_experts=4, top_k_experts=1,
                 load_balancing_loss_weight=0.01, router_z_loss_weight=0.001,
                 use_rotary_embeddings=True, **kwargs):
        super().__init__(**kwargs)
        self.use_moe = use_moe

        self.attention = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            max_seq_len=max_seq_len,
            use_rotary_embeddings=use_rotary_embeddings,
            name=f'attention_{block_id}'
        )

        if use_moe:
            self.ffn = MixtureOfExperts(
                embed_dim=embed_dim,
                ffn_dim=ffn_dim,
                num_experts=num_experts,
                top_k=top_k_experts,
                dropout=dropout,
                load_balancing_loss_weight=load_balancing_loss_weight,
                router_z_loss_weight=router_z_loss_weight,
                name=f'moe_ffn_{block_id}'
            )
        else:
            self.ffn = FeedForward(
                embed_dim=embed_dim,
                ffn_dim=ffn_dim,
                dropout=dropout,
                name=f'ffn_{block_id}'
            )

        self.ln1 = layers.LayerNormalization(epsilon=layer_norm_epsilon, name=f'ln1_{block_id}')
        self.ln2 = layers.LayerNormalization(epsilon=layer_norm_epsilon, name=f'ln2_{block_id}')
        self.dropout1 = layers.Dropout(dropout, name=f'dropout1_{block_id}')
        self.dropout2 = layers.Dropout(dropout, name=f'dropout2_{block_id}')

    def call(self, x, mask=None, training=False):
        # Pre-norm attention with residual connection.
        attn_input = self.ln1(x)
        attn_output = self.attention(attn_input, mask=mask, training=training)
        attn_output = self.dropout1(attn_output, training=training)
        x = x + attn_output

        # Pre-norm FFN (dense or MoE) with residual connection. Both variants
        # return a (hidden_states, aux_losses) pair so callers can treat them
        # uniformly; dense blocks return an empty dict.
        ffn_input = self.ln2(x)
        if self.use_moe:
            ffn_output, aux_losses = self.ffn(ffn_input, training=training)
        else:
            ffn_output = self.ffn(ffn_input, training=training)
            aux_losses = {}
        ffn_output = self.dropout2(ffn_output, training=training)
        x = x + ffn_output
        return x, aux_losses
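

# A block-level sketch (toy sizes assumed): both variants return a
# (hidden_states, aux_losses) pair, so the model loop below can treat MoE and
# dense blocks uniformly.
def _demo_transformer_block():
    moe_block = MoETransformerBlock(embed_dim=8, num_heads=2, ffn_dim=16,
                                    max_seq_len=16, use_moe=True, num_experts=2)
    dense_block = MoETransformerBlock(embed_dim=8, num_heads=2, ffn_dim=16,
                                      max_seq_len=16, use_moe=False)
    x = tf.random.normal((2, 5, 8))
    x, aux = moe_block(x)       # aux holds load-balancing and z-loss entries
    x, empty = dense_block(x)   # empty == {}
    return x, aux, empty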


class MoEMiniGPT(Model):
    """MiniGPT model with Mixture of Experts."""

    def __init__(self, config: MoEConfig = None, tokenizer_path="my-10k-bpe-tokenizer", **kwargs):
        super().__init__(**kwargs)
        self.config = config or MoEConfig()
        self._tokenizer = None

        # Load the BPE tokenizer if the optional libraries and files exist;
        # the model itself works without one.
        if HAS_TRANSFORMERS:
            try:
                tokenizer = ByteLevelBPETokenizer(
                    vocab=f"{tokenizer_path}/vocab.json",
                    merges=f"{tokenizer_path}/merges.txt",
                )
                self._tokenizer = PreTrainedTokenizerFast(
                    tokenizer_object=tokenizer,
                    unk_token="<unk>",
                    pad_token="<pad>",
                    bos_token="<s>",
                    eos_token="</s>",
                    mask_token="<mask>",
                )
            except Exception as e:
                logger.warning(f"Failed to load custom tokenizer: {e}")
                self._tokenizer = None

        self._build_layers()

    def _build_layers(self):
        """Build the model architecture."""
        self.token_embedding = layers.Embedding(
            self.config.vocab_size,
            self.config.embed_dim,
            name="token_embedding"
        )

        # Learned absolute positions are only needed when rotary embeddings
        # are disabled.
        if not self.config.use_rotary_embeddings:
            self.pos_embedding = layers.Embedding(
                self.config.max_seq_len,
                self.config.embed_dim,
                name="position_embedding"
            )

        self.transformer_blocks = []
        for i in range(self.config.num_layers):
            use_moe = i in self.config.use_moe_layers
            block = MoETransformerBlock(
                embed_dim=self.config.embed_dim,
                num_heads=self.config.num_heads,
                ffn_dim=self.config.ffn_dim,
                dropout=self.config.dropout,
                layer_norm_epsilon=self.config.layer_norm_epsilon,
                max_seq_len=self.config.max_seq_len,
                block_id=i,
                use_moe=use_moe,
                num_experts=self.config.num_experts,
                top_k_experts=self.config.top_k_experts,
                load_balancing_loss_weight=self.config.load_balancing_loss_weight,
                router_z_loss_weight=self.config.router_z_loss_weight,
                use_rotary_embeddings=self.config.use_rotary_embeddings,
                name=f"transformer_block_{i}"
            )
            self.transformer_blocks.append(block)

        self.final_layer_norm = layers.LayerNormalization(
            epsilon=self.config.layer_norm_epsilon,
            name="final_layer_norm"
        )
        self.output_projection = layers.Dense(
            self.config.vocab_size,
            use_bias=False,
            name="output_projection"
        )
        self.dropout = layers.Dropout(self.config.dropout, name="embedding_dropout")

    @property
    def tokenizer(self):
        return self._tokenizer

    def call(self, inputs, mask=None, training=False):
        """Forward pass that accumulates auxiliary losses from MoE blocks."""
        if isinstance(inputs, dict):
            input_ids = inputs['input_ids']
            mask = inputs.get('attention_mask', mask)
        else:
            input_ids = inputs

        input_ids = tf.cast(input_ids, tf.int32)
        seq_len = tf.shape(input_ids)[1]

        x = self.token_embedding(input_ids)
        if not self.config.use_rotary_embeddings:
            positions = tf.range(seq_len, dtype=tf.int32)
            x = x + self.pos_embedding(positions)
        x = self.dropout(x, training=training)

        # Every block returns (hidden_states, aux_losses); dense blocks return
        # an empty dict, so one accumulation loop handles both cases.
        total_aux_losses = {}
        for block in self.transformer_blocks:
            x, aux_losses = block(x, mask=mask, training=training)
            for loss_name, loss_value in aux_losses.items():
                total_aux_losses[loss_name] = total_aux_losses.get(loss_name, 0.0) + loss_value

        x = self.final_layer_norm(x)
        logits = self.output_projection(x)
        return logits, total_aux_losses

    def compute_loss_method(self, input_ids, labels=None):
        """Compute the language-modeling loss plus MoE auxiliary losses."""
        input_ids = tf.cast(input_ids, tf.int32)
        if len(input_ids.shape) == 1:
            input_ids = tf.expand_dims(input_ids, 0)

        if labels is None:
            # Standard next-token objective: shift inputs by one position.
            input_for_model = input_ids[:, :-1]
            labels = input_ids[:, 1:]
        else:
            labels = tf.cast(labels, tf.int32)
            if len(labels.shape) == 1:
                labels = tf.expand_dims(labels, 0)
            input_for_model = input_ids

        logits, aux_losses = self(input_for_model, training=True)

        vocab_size = tf.shape(logits)[-1]
        labels_flat = tf.reshape(labels, [-1])
        logits_flat = tf.reshape(logits, [-1, vocab_size])

        # HF-style ignore index (not the tokenizer pad id): positions labeled
        # -100 are excluded from the loss.
        ignore_index = -100
        valid_mask = tf.not_equal(labels_flat, ignore_index)
        # Replace ignored labels with 0 so the cross-entropy gather stays in
        # range; their contribution is masked out below.
        safe_labels = tf.where(valid_mask, labels_flat, tf.zeros_like(labels_flat))

        losses = tf.keras.losses.sparse_categorical_crossentropy(
            safe_labels,
            logits_flat,
            from_logits=True
        )

        # Graph-safe masked mean (avoids a Python-level `if` on a tensor).
        masked_losses = tf.where(valid_mask, losses, tf.zeros_like(losses))
        num_valid = tf.reduce_sum(tf.cast(valid_mask, tf.float32))
        main_loss = tf.reduce_sum(masked_losses) / tf.maximum(num_valid, 1.0)

        auxiliary_loss = 0.0
        if aux_losses:
            auxiliary_loss = tf.add_n(list(aux_losses.values()))
        total_loss = main_loss + auxiliary_loss
        return total_loss

    def get_expert_utilization(self):
        """Get expert utilization statistics from MoE layers."""
        utilization_stats = {}
        for i, block in enumerate(self.transformer_blocks):
            layer_name = f"layer_{i}"
            if hasattr(block, 'ffn') and hasattr(block.ffn, 'router'):
                utilization_stats[layer_name] = {
                    'is_moe_layer': True,
                    'num_experts': block.ffn.num_experts,
                    'top_k': block.ffn.top_k
                }
            else:
                utilization_stats[layer_name] = {'is_moe_layer': False}
        return utilization_stats

    def get_model_info(self):
        """Get comprehensive model information."""
        if not self.built:
            # Build variables with a dummy forward pass before counting them.
            dummy_input = tf.ones((1, self.config.max_seq_len), dtype=tf.int32)
            _ = self(dummy_input)
        total_params = sum(var.shape.num_elements() for var in self.trainable_variables)

        moe_layers = [i for i in range(self.config.num_layers) if i in self.config.use_moe_layers]
        standard_layers = [i for i in range(self.config.num_layers) if i not in self.config.use_moe_layers]

        info = {
            'model_type': 'MoE MiniGPT',
            'total_parameters': int(total_params),
            'vocab_size': self.config.vocab_size,
            'embed_dim': self.config.embed_dim,
            'num_layers': self.config.num_layers,
            'num_heads': self.config.num_heads,
            'max_seq_len': self.config.max_seq_len,
            'moe_config': {
                'num_experts': self.config.num_experts,
                'top_k_experts': self.config.top_k_experts,
                'moe_layers': moe_layers,
                'standard_layers': standard_layers,
                'expert_capacity': self.config.expert_capacity
            },
            'tokenizer_available': self.tokenizer is not None
        }
        return info

    def save_model(self, filepath: str):
        """Save model weights and configuration."""
        self.save_weights(filepath + '_weights')

        config_dict = {
            'vocab_size': self.config.vocab_size,
            'max_seq_len': self.config.max_seq_len,
            'embed_dim': self.config.embed_dim,
            'num_heads': self.config.num_heads,
            'num_layers': self.config.num_layers,
            'ffn_dim': self.config.ffn_dim,
            'dropout': self.config.dropout,
            'layer_norm_epsilon': self.config.layer_norm_epsilon,
            'use_rotary_embeddings': self.config.use_rotary_embeddings,
            'num_experts': self.config.num_experts,
            'top_k_experts': self.config.top_k_experts,
            'expert_capacity': self.config.expert_capacity,
            'load_balancing_loss_weight': self.config.load_balancing_loss_weight,
            'router_z_loss_weight': self.config.router_z_loss_weight,
            'use_moe_layers': self.config.use_moe_layers
        }
        with open(filepath + '_config.json', 'w') as f:
            json.dump(config_dict, f, indent=2)
        logger.info(f"Model saved to {filepath}")

    @classmethod
    def load_model(cls, filepath: str):
        """Load model from saved weights and configuration."""
        with open(filepath + '_config.json', 'r') as f:
            config_dict = json.load(f)
        config = MoEConfig(**config_dict)
        model = cls(config)

        # Build variables with a dummy forward pass before restoring weights.
        dummy_input = tf.ones((1, config.max_seq_len), dtype=tf.int32)
        _ = model(dummy_input)
        model.load_weights(filepath + '_weights')

        logger.info(f"Model loaded from {filepath}")
        return model

    def generate_text(self, prompt, max_length=50, temperature=1.0):
        """
        Simple autoregressive text generation: greedy when temperature <= 0,
        otherwise sampled from the temperature-scaled distribution.
        """
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer is not available for text generation.")
        input_ids = self.tokenizer.encode(
            prompt, return_tensors="tf", max_length=self.config.max_seq_len, truncation=True
        )
        for _ in range(max_length):
            # Crop the context so it never exceeds the positional-table length.
            context = input_ids[:, -self.config.max_seq_len:]
            logits, _ = self(context)
            logits = logits[:, -1, :]
            if temperature <= 0:
                next_token_id = tf.argmax(logits, axis=-1, output_type=tf.int32)
                next_token_id = tf.expand_dims(next_token_id, axis=-1)
            else:
                next_token_id = tf.random.categorical(logits / temperature, num_samples=1, dtype=tf.int32)
            input_ids = tf.concat([input_ids, next_token_id], axis=-1)

            eos_id = getattr(self.tokenizer, "eos_token_id", None)
            if eos_id is not None and int(next_token_id[0, 0]) == eos_id:
                break
        output_ids = input_ids.numpy()[0]
        return self.tokenizer.decode(output_ids, skip_special_tokens=True)
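

# A minimal end-to-end sketch (illustrative assumptions: tiny hyperparameters,
# TF 2.x checkpoint naming for save_weights): compute the combined loss on
# random token ids, then round-trip the weights through save_model/load_model.
def _demo_loss_and_save_roundtrip(path="./_demo_moe_minigpt"):
    cfg = MoEConfig(vocab_size=100, max_seq_len=16, embed_dim=32, num_heads=2,
                    num_layers=2, ffn_dim=64, use_moe_layers=[1],
                    batch_size=2, seq_len=16)
    model = MoEMiniGPT(cfg, tokenizer_path="nonexistent")  # tokenizer is optional here
    ids = tf.random.uniform((2, 16), maxval=cfg.vocab_size, dtype=tf.int32)
    loss = model.compute_loss_method(ids)  # next-token loss + MoE aux losses
    model.save_model(path)
    restored = MoEMiniGPT.load_model(path)
    return loss, restored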


def create_sample_model():
    """Create a sample MoE MiniGPT model for testing."""
    config = MoEConfig(
        vocab_size=10000,
        max_seq_len=256,
        embed_dim=512,
        num_heads=8,
        num_layers=8,
        ffn_dim=2048,
        num_experts=4,
        top_k_experts=1,
        use_moe_layers=[2, 4, 6],
        batch_size=32,
        seq_len=256
    )
    model = MoEMiniGPT(config, tokenizer_path="my-10k-bpe-tokenizer")
    # Build variables with a dummy forward pass.
    dummy_input = tf.ones((1, config.max_seq_len), dtype=tf.int32)
    _ = model(dummy_input)
    logger.info("Sample MoE MiniGPT model created successfully!")
    logger.info(f"Model info: {model.get_model_info()}")
    return model


def create_dummy_dataset(vocab_size: int = 10000, seq_len: int = 256,
                         batch_size: int = 48, num_batches: int = 100):
    """Create a dummy dataset for testing, yielding (input_ids, labels) pairs."""
    def generator():
        for _ in range(num_batches):
            batch = tf.random.uniform(
                (batch_size, seq_len),
                minval=0,
                maxval=vocab_size,
                dtype=tf.int32
            )
            # Shift by one position for next-token prediction.
            input_ids = batch[:, :-1]
            labels = batch[:, 1:]
            yield input_ids, labels

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            tf.TensorSpec(shape=(batch_size, seq_len - 1), dtype=tf.int32),
            tf.TensorSpec(shape=(batch_size, seq_len - 1), dtype=tf.int32)
        )
    )
    return dataset
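

# A minimal training-loop sketch under stated assumptions (tiny sizes, plain
# Adam, no gradient clipping or LR schedule): it pairs create_dummy_dataset
# with compute_loss_method, which already folds in the MoE auxiliary losses.
def _demo_training_steps(num_steps=2):
    cfg = MoEConfig(vocab_size=100, max_seq_len=16, embed_dim=32, num_heads=2,
                    num_layers=2, ffn_dim=64, use_moe_layers=[1],
                    batch_size=2, seq_len=16)
    model = MoEMiniGPT(cfg, tokenizer_path="nonexistent")
    optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)
    dataset = create_dummy_dataset(vocab_size=100, seq_len=16, batch_size=2,
                                   num_batches=num_steps)
    for input_ids, labels in dataset:
        with tf.GradientTape() as tape:
            loss = model.compute_loss_method(input_ids, labels)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        logger.info(f"demo step loss: {float(loss):.4f}")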


if __name__ == "__main__":
    model = create_sample_model()

    if model.tokenizer is not None:
        print("MiniGPT Chat! Type 'quit' to exit.")
        while True:
            user_input = input("You: ")
            if user_input.strip().lower() == "quit":
                print("Exiting chat.")
                break
            response = model.generate_text(user_input, max_length=50)
            print(f"MiniGPT: {response}")
    else:
        print("Tokenizer not available. The model was created successfully, but text "
              "generation requires the tokenizers/transformers libraries and the "
              "tokenizer files on disk.")

    print(f"Model info: {model.get_model_info()}")
    print(f"Expert utilization: {model.get_expert_utilization()}")