+import math
-from functools import wraps
+from functools import wraps, partial # partial is needed by the new contrastive loss dispatch below
 
 import torch
@@ -248,6 +249,76 @@ def forward(
 
         return x, torch.stack(layers[:-1])
 
+# contrastive losses
+
+class SoftmaxContrastiveLearning(nn.Module):
+    def __init__(
+        self,
+        *,
+        layers = 1,
+        decoupled_contrastive_learning = False,
+        init_temp = 10
+    ):
+        super().__init__()
+        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
+        self.decoupled_contrastive_learning = decoupled_contrastive_learning
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, sims):
+        batch = sims.shape[-1]
+
+        if sims.ndim == 2:
+            sims = rearrange(sims, 'i j -> 1 i j')
+
+        sims = sims * self.temperatures.exp()
+
+        cosine_sims_exp = sims.exp()
+
+        numerator = matrix_diag(cosine_sims_exp)
+
+        if self.decoupled_contrastive_learning:
+            # decoupled contrastive learning: mask the positive pairs out of the denominator
+            eye = torch.eye(batch, device = self.device, dtype = torch.bool)
+            cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)
+
+        denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
+        denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')
+
+        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
+
+        contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
+        return contrastive_loss.sum()
+
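For reference, a minimal self-contained sketch of what `SoftmaxContrastiveLearning` computes. The `log` and `matrix_diag` helpers below are stand-ins for the repo's utilities of the same names (assumed behavior, not shown in this diff), and the learned per-layer temperature is replaced by a fixed scalar:

```python
import torch
from einops import rearrange, reduce

def log(t, eps = 1e-20):
    # stand-in for the repo's clamped log helper
    return torch.log(t.clamp(min = eps))

def matrix_diag(t):
    # stand-in: diagonal over the last two dims, 'l i j -> l i'
    return torch.diagonal(t, dim1 = -2, dim2 = -1)

def softmax_contrastive(sims, temperature = 10., decoupled = True):
    # sims: (batch, batch) or (layers, batch, batch); diagonal entries are the positives
    if sims.ndim == 2:
        sims = rearrange(sims, 'i j -> 1 i j')

    exp_sims = (sims * temperature).exp()
    numerator = matrix_diag(exp_sims)

    if decoupled:
        # DCL (https://arxiv.org/abs/2110.06848): remove positives from the denominator
        eye = torch.eye(exp_sims.shape[-1], dtype = torch.bool)
        exp_sims = exp_sims.masked_fill(eye, 0.)

    denom_i = reduce(exp_sims, 'l i j -> l i', 'sum')  # audio -> text direction
    denom_j = reduce(exp_sims, 'l i j -> l j', 'sum')  # text -> audio direction

    # symmetric InfoNCE: average of both retrieval directions
    loss = -log(numerator) + 0.5 * (log(denom_i) + log(denom_j))
    return reduce(loss, 'l n -> l', 'mean').sum()

print(softmax_contrastive(torch.randn(4, 4)))  # toy batch of 4 pairs
```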
+class SigmoidContrastiveLearning(nn.Module):
+    """ https://arxiv.org/abs/2303.15343 """
+
+    def __init__(
+        self,
+        *,
+        layers = 1,
+        init_temp = 10,
+        init_bias = -10
+    ):
+        super().__init__()
+        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
+        self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, sims):
+        if sims.ndim == 2:
+            sims = rearrange(sims, 'i j -> 1 i j')
+
+        n = sims.shape[-1]
+        sims = sims * self.temperatures.exp() + self.bias
+
+        # labels on the module's device: +1 on the diagonal (matched pairs), -1 off it
+        labels = 2 * rearrange(torch.eye(n, device = self.device), 'i j -> 1 i j') - torch.ones_like(sims)
+
+        return -F.logsigmoid(labels * sims).sum() / n
+
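`SigmoidContrastiveLearning` follows the SigLIP paper linked in the docstring: each audio/text pair is an independent binary classification, so no batch-wide softmax normalization is needed. A minimal sketch of the same computation on toy inputs, with the learned temperature and bias replaced by their init values:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def sigmoid_contrastive(sims, temperature = 10., bias = -10.):
    # sims: (batch, batch) cosine similarities
    if sims.ndim == 2:
        sims = rearrange(sims, 'i j -> 1 i j')

    n = sims.shape[-1]
    logits = sims * temperature + bias  # negative bias counters the (n - 1) : 1 negative imbalance

    # +1 on the diagonal (matched pairs), -1 everywhere else
    labels = 2 * rearrange(torch.eye(n, device = sims.device), 'i j -> 1 i j') - 1.

    return -F.logsigmoid(labels * logits).sum() / n

print(sigmoid_contrastive(torch.randn(4, 4)))
```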
 # Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778
 
 def pair(t):
@@ -539,7 +610,8 @@ def __init__(
         text_dim,
         dim_latent,
         layers,
-        decoupled_contrastive_learning = False
+        decoupled_contrastive_learning = False,
+        sigmoid_contrastive_loss = False
     ):
         super().__init__()
         self.layers = layers
@@ -554,9 +626,8 @@ def __init__(
         self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent))
         self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))
 
-        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1))
-
-        self.decoupled_contrastive_learning = decoupled_contrastive_learning
+        klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
+        self.contrast = klass(layers = layers)
 
     def forward(self, *, audio_layers, text_layers):
         device, batch = audio_layers.device, audio_layers.shape[1]
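The `partial` here normalizes the two constructors to a single signature, so the call site only ever passes `layers`. The pattern in isolation, as a hypothetical `make_contrast` helper (assuming the two loss classes added above are in scope):

```python
from functools import partial

def make_contrast(layers, sigmoid_contrastive_loss = False, decoupled_contrastive_learning = False):
    # both branches produce a callable that accepts only `layers`
    klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(
        SoftmaxContrastiveLearning,
        decoupled_contrastive_learning = decoupled_contrastive_learning
    )
    return klass(layers = layers)
```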
@@ -571,23 +642,9 @@ def forward(self, *, audio_layers, text_layers):
         text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias
         text_latents = l2norm(text_latents)
 
-        cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp()
-
-        cosine_sims_exp = cosine_sims.exp()
+        cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
 
-        numerator = matrix_diag(cosine_sims_exp)
-
-        if self.decoupled_contrastive_learning:
-            eye = torch.eye(batch, device = device, dtype = torch.bool)
-            cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)
-
-        denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
-        denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')
-
-        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
-
-        contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
-        return contrastive_loss.sum()
+        return self.contrast(cosine_sims)
 
 # main classes
 
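After this refactor the multi-layer forward hands raw cosine similarities to the loss module, which owns the per-layer temperature (and bias, in the sigmoid case). A quick shape check on hypothetical dimensions:

```python
import torch

layers, batch, dim_latent = 3, 4, 128
audio_latents = torch.randn(layers, batch, dim_latent)
text_latents = torch.randn(layers, batch, dim_latent)

# one (batch x batch) similarity matrix per hidden layer
cosine_sims = torch.einsum('l i d, l j d -> l i j', audio_latents, text_latents)
assert cosine_sims.shape == (layers, batch, batch)

# either loss module consumes the 3d tensor directly, one temperature per layer:
# loss = SoftmaxContrastiveLearning(layers = layers)(cosine_sims)
# loss = SigmoidContrastiveLearning(layers = layers)(cosine_sims)
```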
@@ -600,20 +657,21 @@ def __init__(
         dim_latent = 128,                       # they use 128
         decoupled_contrastive_learning = True,  # think this was used, make it optional
         hierarchical_contrastive_loss = False,
-        hierarchical_contrastive_loss_layers = None
+        hierarchical_contrastive_loss_layers = None,
+        sigmoid_contrastive_loss = False
     ):
         super().__init__()
         self.dim_latent = dim_latent
 
         self.audio = audio_transformer
         self.text = text_transformer
 
-        self.temperature = nn.Parameter(torch.tensor(1.))
 
         self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
         self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)
 
-        self.decoupled_contrastive_learning = decoupled_contrastive_learning
+        klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
+        self.contrast = klass()
 
         self.multi_layer_contrastive_learning = None
 
@@ -629,7 +687,8 @@ def __init__(
                 text_dim = self.text.dim,
                 dim_latent = dim_latent,
                 layers = num_layers,
-                decoupled_contrastive_learning = decoupled_contrastive_learning
+                decoupled_contrastive_learning = decoupled_contrastive_learning,
+                sigmoid_contrastive_loss = sigmoid_contrastive_loss
             )
 
     def get_audio_latents(
@@ -688,21 +747,7 @@ def forward(
         if return_pairwise_similarities:
             return cosine_sim
 
-        cosine_sim = cosine_sim * self.temperature.exp()
-
-        cosine_sim_exp = cosine_sim.exp()
-
-        numerator = cosine_sim_exp.diag()
-
-        if self.decoupled_contrastive_learning:
-            eye = torch.eye(batch, device = device, dtype = torch.bool)
-            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)
-
-        denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum')
-        denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum')
-
-        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
-        cl_loss = contrastive_loss.mean()
+        cl_loss = self.contrast(cosine_sim)
 
         if not exists(self.multi_layer_contrastive_learning):
             return cl_loss
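From the outside, opting into the new loss only requires the added flag. A hypothetical usage sketch, assuming the enclosing class is this repo's `MuLaN` wrapper (the class name is not visible in this hunk) and eliding transformer construction:

```python
# hypothetical usage; audio_transformer / text_transformer construction elided
mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer,
    sigmoid_contrastive_loss = True  # opt into the SigLIP-style loss
)

# the same flag threads through to MultiLayerContrastiveLoss
# when hierarchical_contrastive_loss = True
```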