@@ -11,7 +11,7 @@
 from audiolm_pytorch.utils import AudioConditionerBase
 
 import torch.distributed as dist
-from musiclm_pytorch.distributed import all_gather
+from musiclm_pytorch.distributed import all_gather, all_gather_all_reduce_grads
 
 from x_clip.tokenizer import tokenizer
 from vector_quantize_pytorch import ResidualVQ
@@ -320,23 +320,37 @@ def __init__(
         self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
         self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)
 
+        self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1
+
     @property
     def device(self):
         return next(self.parameters()).device
 
     def forward(self, audio_latents, text_latents):
+        device = self.device
+
         if audio_latents.ndim == 2:
             audio_latents = rearrange(audio_latents, '... -> 1 ...')
 
         if text_latents.ndim == 2:
             text_latents = rearrange(text_latents, '... -> 1 ...')
 
-        n = audio_latents.shape[1]
+        if self.needs_all_gather:
+            text_latents, batch_sizes = all_gather_all_reduce_grads(text_latents, 1, None)
+
+        n = text_latents.shape[1]
 
         sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
 
         sims = sims * self.temperatures.exp() + self.bias
-        labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims)
+
+        labels = torch.eye(n, device = device)
+
+        if self.needs_all_gather:
+            labels_by_ranks = labels.split(batch_sizes.tolist(), dim = 0)
+            labels = labels_by_ranks[dist.get_rank()]
+
+        labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)
 
         return -F.logsigmoid(labels * sims).sum() / n
 
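The commit imports all_gather_all_reduce_grads but its body is not part of this diff. The sketch below is an assumption about its semantics, inferred only from the call site all_gather_all_reduce_grads(text_latents, 1, None) and the returned batch_sizes: all-gather a tensor along a dimension in the forward pass, and in the backward pass sum the gradients from every rank before handing back the local slice. The class name, the wrapper signature, and the equal-per-rank-batch simplification are illustrative only, not the repository's actual implementation.

# Hedged sketch only: NOT the actual musiclm_pytorch.distributed implementation.
# Assumes torch.distributed is already initialized and, for brevity, that every
# rank contributes the same batch size along `dim`.

import torch
import torch.distributed as dist

class _AllGatherAllReduceGrads(torch.autograd.Function):
    @staticmethod
    def forward(ctx, t, dim):
        ctx.dim = dim
        world_size = dist.get_world_size()

        # forward: collect every rank's tensor and concatenate along `dim`
        gathered = [torch.empty_like(t) for _ in range(world_size)]
        dist.all_gather(gathered, t)
        out = torch.cat(gathered, dim = dim)

        # per-rank sizes, mirroring the `batch_sizes` the diff unpacks
        sizes = torch.tensor([t.shape[dim]] * world_size)
        ctx.sizes = sizes
        return out, sizes

    @staticmethod
    def backward(ctx, grad, _grad_sizes):
        # backward: each rank scored its local batch against the full gathered
        # tensor, so sum the gradient contributions from all ranks, then hand
        # back only the slice that belongs to this rank's input
        grad = grad.clone()
        dist.all_reduce(grad)
        grads_by_rank = grad.split(ctx.sizes.tolist(), dim = ctx.dim)
        return grads_by_rank[dist.get_rank()], None

def all_gather_all_reduce_grads(t, dim, sizes = None):
    # hypothetical wrapper matching the call site in the diff; `sizes` is ignored here
    return _AllGatherAllReduceGrads.apply(t, dim)

Under this reading, only the text latents travel across ranks while each rank's audio latents stay local, which is why the labels also need per-rank slicing, shown next.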
|
|
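Because only the text latents are gathered, each rank ends up with a (local audio) x (global text) similarity matrix, and the identity labels over the global batch must be cut down to the rows this rank owns. The toy single-process example below walks through that slicing with made-up shapes (two ranks with batches of 2 and 3, pretending to be rank 0); the temperature and bias scaling from the real forward pass is omitted for brevity.

# Toy, single-process illustration of the label slicing in the diff.
# `batch_sizes`, `rank`, and the random latents are made-up values.

import torch
import torch.nn.functional as F
from torch import einsum
from einops import rearrange

layers, dim = 1, 4
batch_sizes = torch.tensor([2, 3])      # per-rank text batch sizes after gathering
rank = 0                                # pretend we are rank 0 (local batch of 2)
local = int(batch_sizes[rank])
n = int(batch_sizes.sum())              # 5 text latents in the global batch

audio_latents = torch.randn(layers, local, dim)   # this rank's audio only
text_latents  = torch.randn(layers, n, dim)       # gathered text latents from all ranks

sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)  # (1, 2, 5)

labels = torch.eye(n)                                       # identity over the global batch
labels = labels.split(batch_sizes.tolist(), dim = 0)[rank]  # keep this rank's rows -> (2, 5)
labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)  # +1 / -1 targets

loss = -F.logsigmoid(labels * sims).sum() / n
print(loss.item())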