# music-aesthetics / model_architecture.py
# Uploaded by ChristophSchuhmann with huggingface_hub (commit 1fd2f37, verified).
import torch
import torch.nn as nn
class MusicAestheticsModel(nn.Module):
    """Multi-head MLP that predicts one scalar score per aesthetic metric.

    A shared bottleneck MLP maps the input features to a ``bottleneck_dim``
    vector. One small MLP head per metric then maps that vector to a single
    value.

    Args:
        input_dim: Size of the last dimension of the input tensor. What
            produces 23040-dim features is not established in this file;
            confirm against the feature extractor.
        bottleneck_dim: Width of the shared representation fed to every head.
        hidden_dim: Width of the hidden layer inside each head.
        intermediate_dim: Width of the first hidden layer of the shared
            bottleneck. This was previously hard-coded to 1024.
        dropout: Dropout probability applied after the first bottleneck
            layer. This was previously hard-coded to 0.3.

    The submodule layout, and hence the ``state_dict`` key names, does not
    depend on any of these arguments. A checkpoint saved with the original
    defaults therefore loads into a model built with the same defaults.
    """

    def __init__(
        self,
        input_dim: int = 23040,
        bottleneck_dim: int = 256,
        hidden_dim: int = 64,
        intermediate_dim: int = 1024,
        dropout: float = 0.3,
    ):
        super().__init__()
        # This order defines both the head creation order and the key order
        # of the dict returned by forward().
        self.metrics = [
            "Coherence", "Musicality", "Memorability", "Clarity", "Naturalness",
        ]
        # Shared Bottleneck.
        # NOTE: keep the layer order unchanged. state_dict keys are derived
        # from Sequential indices (bottleneck.0 / .3 / .5), so reordering
        # layers would break loading of existing checkpoints.
        self.bottleneck = nn.Sequential(
            nn.Linear(input_dim, intermediate_dim),
            nn.ReLU(),
            nn.Dropout(dropout),  # only active in train() mode; call eval() for inference
            nn.Linear(intermediate_dim, bottleneck_dim),
            nn.ReLU(),
            nn.LayerNorm(bottleneck_dim),
        )
        # Multi-head Experts: one independent regression head per metric,
        # keyed by metric name.
        self.heads = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(bottleneck_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            )
            for name in self.metrics
        })

    def forward(self, x: torch.Tensor) -> "dict[str, torch.Tensor]":
        """Score ``x`` on every metric.

        Args:
            x: Tensor whose last dimension is ``input_dim``, for example
                ``(batch, input_dim)``.

        Returns:
            Dict mapping each metric name, in ``self.metrics`` order, to a
            tensor of shape ``(*x.shape[:-1], 1)``.
        """
        feat = self.bottleneck(x)
        # ModuleDict preserves insertion order, so keys follow self.metrics.
        return {name: head(feat) for name, head in self.heads.items()}