
Commit b5ca8ae

complete distributed logic for sigmoid contrastive loss

1 parent: 0876027

3 files changed: +36 −4 lines

musiclm_pytorch/distributed.py

Lines changed: 18 additions & 0 deletions
@@ -49,3 +49,21 @@ def backward(ctx, grads, _):
         return grads_by_rank[rank], None, None
 
 all_gather = AllGather.apply
+
+class AllGatherAllReduceGrads(Function):
+    @staticmethod
+    def forward(ctx, x, dim, sizes):
+        assert distributed.is_initialized() and distributed.get_world_size() > 1
+        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
+        ctx.batch_sizes = batch_sizes.tolist()
+        ctx.dim = dim
+        return x, batch_sizes
+
+    @staticmethod
+    def backward(ctx, grads, _):
+        distributed.all_reduce(grads)
+        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
+        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
+        return grads_by_rank[rank], None, None
+
+all_gather_all_reduce_grads = AllGatherAllReduceGrads.apply
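
Note on the new autograd function: it is identical to the existing AllGather except in the backward pass, where the incoming gradients are first summed across ranks with distributed.all_reduce before this rank's slice is returned. That matters for the sigmoid contrastive loss below, because each rank's text latents appear in every other rank's loss term, and those remote contributions must flow back to the rank that owns the latents. A minimal usage sketch, assuming a torchrun launch and an nccl backend (the shapes here are illustrative, not from the commit):

import torch
import torch.distributed as dist
from musiclm_pytorch.distributed import all_gather_all_reduce_grads

# assumes `torchrun --nproc_per_node=N script.py`, which sets the rendezvous env vars
dist.init_process_group('nccl')
rank = dist.get_rank()
torch.cuda.set_device(rank)

# per-rank batch of 4 latent vectors
x = torch.randn(4, 8, device = 'cuda', requires_grad = True)

# forward: gather along dim 0 from all ranks; sizes = None lets the helper
# exchange the per-rank batch sizes itself via all_gather_variable_dim
gathered, batch_sizes = all_gather_all_reduce_grads(x, 0, None)

loss = gathered.square().sum()
loss.backward()  # grads are all-reduced across ranks, then x.grad receives this rank's slice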

musiclm_pytorch/musiclm_pytorch.py

Lines changed: 17 additions & 3 deletions
@@ -11,7 +11,7 @@
 from audiolm_pytorch.utils import AudioConditionerBase
 
 import torch.distributed as dist
-from musiclm_pytorch.distributed import all_gather
+from musiclm_pytorch.distributed import all_gather, all_gather_all_reduce_grads
 
 from x_clip.tokenizer import tokenizer
 from vector_quantize_pytorch import ResidualVQ
@@ -320,23 +320,37 @@ def __init__(
         self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
         self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)
 
+        self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1
+
     @property
     def device(self):
         return next(self.parameters()).device
 
     def forward(self, audio_latents, text_latents):
+        device = self.device
+
         if audio_latents.ndim == 2:
             audio_latents = rearrange(audio_latents, '... -> 1 ...')
 
         if text_latents.ndim == 2:
             text_latents = rearrange(text_latents, '... -> 1 ...')
 
-        n = audio_latents.shape[1]
+        if self.needs_all_gather:
+            text_latents, batch_sizes = all_gather_all_reduce_grads(text_latents, 1, None)
+
+        n = text_latents.shape[1]
 
         sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
 
         sims = sims * self.temperatures.exp() + self.bias
-        labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims)
+
+        labels = torch.eye(n, device = device)
+
+        if self.needs_all_gather:
+            labels_by_ranks = labels.split(batch_sizes.tolist(), dim = 0)
+            labels = labels_by_ranks[dist.get_rank()]
+
+        labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)
 
         return -F.logsigmoid(labels * sims).sum() / n
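
Why the labels are sliced: each rank keeps only its local audio latents (rows i of the similarity matrix) but scores them against text latents gathered from all ranks (columns j), so the positives for this rank are the rows of the global n × n identity that belong to its own batch; n is now the global batch size, which also becomes the loss normalizer. This matches the chunked formulation of the sigmoid (SigLIP-style) loss, where each device holds its own diagonal block of the global pairing matrix. A small self-contained sketch of the label construction, using a hypothetical world size of 2 with per-rank batch sizes [2, 3] (rank 0 shown):

import torch
import torch.nn.functional as F

batch_sizes = [2, 3]        # per-rank batch sizes reported by the all-gather
n = sum(batch_sizes)        # global number of text latents = 5
rank = 0                    # pretend we are rank 0

labels = torch.eye(n)                              # (5, 5) global identity
labels = labels.split(batch_sizes, dim = 0)[rank]  # (2, 5): the rows this rank owns

sims = torch.randn(1, batch_sizes[rank], n)        # (layers, i, j) local-audio x global-text
labels = 2 * labels.unsqueeze(0) - torch.ones_like(sims)  # map {0, 1} -> {+1, -1}

loss = -F.logsigmoid(labels * sims).sum() / n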

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'musiclm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.3',
+  version = '0.2.4',
   license='MIT',
   description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
   author = 'Phil Wang',
