Skip to content

Commit 80ad8cc

Browse files
committed
add sigmoid contrastive loss
1 parent b95e54b commit 80ad8cc

File tree

3 files changed

+93
-40
lines changed

3 files changed

+93
-40
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,14 @@ music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples
220220
}
221221
```
222222

223+
```bibtex
224+
@inproceedings{Zhai2023SigmoidLF,
225+
title = {Sigmoid Loss for Language Image Pre-Training},
226+
author = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
227+
year = {2023}
228+
}
229+
```
230+
223231
*The only truth is music.* - Jack Kerouac
224232

225233
*Music is the universal language of mankind.* - Henry Wadsworth Longfellow

musiclm_pytorch/musiclm_pytorch.py

Lines changed: 84 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from functools import wraps
23

34
import torch
@@ -248,6 +249,76 @@ def forward(
248249

249250
return x, torch.stack(layers[:-1])
250251

252+
# contrastive losses
253+
254+
class SoftmaxContrastiveLearning(nn.Module):
255+
def __init__(
256+
self,
257+
*,
258+
layers = 1,
259+
decoupled_contrastive_learning = False,
260+
init_temp = 10
261+
):
262+
super().__init__()
263+
self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
264+
self.decoupled_contrastive_learning = decoupled_contrastive_learning
265+
266+
@property
267+
def device(self):
268+
return next(self.parameters()).device
269+
270+
def forward(self, sims):
271+
batch = sims.shape[-1]
272+
273+
if sims.ndim == 2:
274+
sims = rearrange(sims, 'i j -> 1 i j')
275+
276+
sims = sims * self.temperatures.exp()
277+
278+
cosine_sims_exp = sims.exp()
279+
280+
numerator = matrix_diag(cosine_sims_exp)
281+
282+
if self.decoupled_contrastive_learning:
283+
eye = torch.eye(batch, device = self.device, dtype = torch.bool)
284+
cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)
285+
286+
denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
287+
denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')
288+
289+
contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
290+
291+
contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
292+
return contrastive_loss.sum()
293+
294+
class SigmoidContrastiveLearning(nn.Module):
295+
""" https://arxiv.org/abs/2303.15343 """
296+
297+
def __init__(
298+
self,
299+
*,
300+
layers = 1,
301+
init_temp = 10,
302+
init_bias = -10
303+
):
304+
super().__init__()
305+
self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
306+
self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)
307+
308+
@property
309+
def device(self):
310+
return next(self.parameters()).device
311+
312+
def forward(self, sims):
313+
if sims.ndim == 2:
314+
sims = rearrange(sims, 'i j -> 1 i j')
315+
316+
n = sims.shape[-1]
317+
sims = sims * self.temperatures.exp() + self.bias
318+
labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims)
319+
320+
return -F.logsigmoid(labels * sims).sum() / n
321+
251322
# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778
252323

253324
def pair(t):
@@ -539,7 +610,8 @@ def __init__(
539610
text_dim,
540611
dim_latent,
541612
layers,
542-
decoupled_contrastive_learning = False
613+
decoupled_contrastive_learning = False,
614+
sigmoid_contrastive_loss = False
543615
):
544616
super().__init__()
545617
self.layers = layers
@@ -554,9 +626,8 @@ def __init__(
554626
self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent))
555627
self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))
556628

557-
self.temperatures = nn.Parameter(torch.ones(layers, 1, 1))
558-
559-
self.decoupled_contrastive_learning = decoupled_contrastive_learning
629+
klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
630+
self.contrast = klass(layers = layers)
560631

561632
def forward(self, *, audio_layers, text_layers):
562633
device, batch = audio_layers.device, audio_layers.shape[1]
@@ -571,23 +642,9 @@ def forward(self, *, audio_layers, text_layers):
571642
text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias
572643
text_latents = l2norm(text_latents)
573644

574-
cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp()
575-
576-
cosine_sims_exp = cosine_sims.exp()
645+
cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
577646

578-
numerator = matrix_diag(cosine_sims_exp)
579-
580-
if self.decoupled_contrastive_learning:
581-
eye = torch.eye(batch, device = device, dtype = torch.bool)
582-
cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)
583-
584-
denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
585-
denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')
586-
587-
contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
588-
589-
contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
590-
return contrastive_loss.sum()
647+
return self.contrast(cosine_sims)
591648

592649
# main classes
593650

@@ -600,20 +657,21 @@ def __init__(
600657
dim_latent = 128, # they use 128
601658
decoupled_contrastive_learning = True, # think this was used, make it optional
602659
hierarchical_contrastive_loss = False,
603-
hierarchical_contrastive_loss_layers = None
660+
hierarchical_contrastive_loss_layers = None,
661+
sigmoid_contrastive_loss = False
604662
):
605663
super().__init__()
606664
self.dim_latent = dim_latent
607665

608666
self.audio = audio_transformer
609667
self.text = text_transformer
610668

611-
self.temperature = nn.Parameter(torch.tensor(1.))
612669

613670
self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
614671
self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)
615672

616-
self.decoupled_contrastive_learning = decoupled_contrastive_learning
673+
klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
674+
self.contrast = klass()
617675

618676
self.multi_layer_contrastive_learning = None
619677

@@ -629,7 +687,8 @@ def __init__(
629687
text_dim = self.text.dim,
630688
dim_latent = dim_latent,
631689
layers = num_layers,
632-
decoupled_contrastive_learning = decoupled_contrastive_learning
690+
decoupled_contrastive_learning = decoupled_contrastive_learning,
691+
sigmoid_contrastive_loss = sigmoid_contrastive_loss
633692
)
634693

635694
def get_audio_latents(
@@ -688,21 +747,7 @@ def forward(
688747
if return_pairwise_similarities:
689748
return cosine_sim
690749

691-
cosine_sim = cosine_sim * self.temperature.exp()
692-
693-
cosine_sim_exp = cosine_sim.exp()
694-
695-
numerator = cosine_sim_exp.diag()
696-
697-
if self.decoupled_contrastive_learning:
698-
eye = torch.eye(batch, device = device, dtype = torch.bool)
699-
cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)
700-
701-
denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum')
702-
denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum')
703-
704-
contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
705-
cl_loss = contrastive_loss.mean()
750+
cl_loss = self.contrast(cosine_sim)
706751

707752
if not exists(self.multi_layer_contrastive_learning):
708753
return cl_loss

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'musiclm-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.2',
6+
version = '0.2.0',
77
license='MIT',
88
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)