
Commit ec2bfbd

Merge pull request #29 from BloodAxe/develop
PyTorch Toolbelt 0.2.1
2 parents cc5e997 + def5c59

File tree

6 files changed (+106, -80 lines)


README.md

Lines changed: 16 additions & 1 deletion
@@ -125,4 +125,19 @@ merged_mask = tiler.crop_to_orignal_size(merged_mask)
 ## Advanced examples
 
 1. [Inria Sattelite Segmentation](https://github.com/BloodAxe/Catalyst-Inria-Segmentation-Example)
-1. [CamVid Semantic Segmentation](https://github.com/BloodAxe/Catalyst-CamVid-Segmentation-Example)
+1. [CamVid Semantic Segmentation](https://github.com/BloodAxe/Catalyst-CamVid-Segmentation-Example)
+
+
+## Citation
+
+```
+@misc{Khvedchenya_Eugene_2019_PyTorch_Toolbelt,
+  author = {Khvedchenya, Eugene},
+  title = {PyTorch Toolbelt},
+  year = {2019},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/BloodAxe/pytorch-toolbelt}},
+  commit = {cc5e9973cdb0dcbf1c6b6e1401bf44b9c69e13f3}
+}
+```

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from __future__ import absolute_import
 
-__version__ = "0.2.0"
+__version__ = "0.2.1"

pytorch_toolbelt/losses/focal.py

Lines changed: 5 additions & 4 deletions
@@ -2,7 +2,7 @@
 
 from torch.nn.modules.loss import _Loss
 
-from .functional import sigmoid_focal_loss, reduced_focal_loss
+from .functional import focal_loss_with_logits
 
 __all__ = ["BinaryFocalLoss", "FocalLoss"]
 
@@ -31,14 +31,15 @@ def __init__(
         self.ignore_index = ignore_index
         if reduced:
             self.focal_loss = partial(
-                reduced_focal_loss,
+                focal_loss_with_logits,
+                alpha=None,
                 gamma=gamma,
                 threshold=threshold,
                 reduction=reduction,
             )
         else:
             self.focal_loss = partial(
-                sigmoid_focal_loss, gamma=gamma, alpha=alpha, reduction=reduction
+                focal_loss_with_logits, alpha=alpha, gamma=gamma, reduction=reduction
             )
 
     def forward(self, label_input, label_target):
@@ -87,7 +88,7 @@ def forward(self, label_input, label_target):
            cls_label_target = cls_label_target[not_ignored]
            cls_label_input = cls_label_input[not_ignored]
 
-            loss += sigmoid_focal_loss(
+            loss += focal_loss_with_logits(
                cls_label_input, cls_label_target, gamma=self.gamma, alpha=self.alpha
            )
        return loss
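
After this change, both branches of `BinaryFocalLoss` delegate to the single `focal_loss_with_logits` function: the reduced variant is obtained by passing `alpha=None` plus a `threshold`, the standard variant by passing `alpha` and `gamma`. A minimal usage sketch (the tensors below are illustrative, not from the repository):

```python
import torch
from pytorch_toolbelt.losses.functional import focal_loss_with_logits

logits = torch.tensor([2.0, -1.5, 0.3])   # hypothetical raw model outputs
targets = torch.tensor([1.0, 0.0, 1.0])   # binary ground-truth labels

# Standard alpha-balanced focal loss, matching the reduced=False branch above
loss = focal_loss_with_logits(logits, targets, gamma=2.0, alpha=0.25)

# Reduced focal loss, matching the reduced=True branch: alpha disabled,
# threshold controls where the down-weighting kicks in
loss_reduced = focal_loss_with_logits(
    logits, targets, alpha=None, gamma=2.0, threshold=0.5
)
```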

pytorch_toolbelt/losses/functional.py

Lines changed: 38 additions & 44 deletions
@@ -1,14 +1,27 @@
 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
 
-__all__ = ["sigmoid_focal_loss", "soft_jaccard_score", "soft_dice_score", "wing_loss"]
+__all__ = [
+    "focal_loss_with_logits",
+    "sigmoid_focal_loss",
+    "soft_jaccard_score",
+    "soft_dice_score",
+    "wing_loss",
+]
 
 
-def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma=2.0, alpha=0.25, reduction="mean"
-):
+def focal_loss_with_logits(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    gamma=2.0,
+    alpha: Optional[float] = 0.25,
+    reduction="mean",
+    normalized=False,
+    threshold: Optional[float] = None,
+) -> torch.Tensor:
     """Compute binary focal loss between target and output logits.
 
     See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
@@ -23,7 +36,8 @@ def sigmoid_focal_loss(
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`.
            'batchwise_mean' computes mean loss per sample in batch. Default: 'mean'
-
+        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
+        threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
     References::
 
        https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
@@ -34,11 +48,21 @@ def sigmoid_focal_loss(
     pt = torch.exp(logpt)
 
     # compute the loss
-    loss = -((1 - pt).pow(gamma)) * logpt
+    if threshold is None:
+        focal_term = (1 - pt).pow(gamma)
+    else:
+        focal_term = ((1.0 - pt) / threshold).pow(gamma)
+        focal_term[pt < threshold] = 1
+
+    loss = -focal_term * logpt
 
     if alpha is not None:
         loss = loss * (alpha * target + (1 - alpha) * (1 - target))
 
+    if normalized:
+        norm_factor = focal_term.sum()
+        loss = loss / norm_factor
+
     if reduction == "mean":
         loss = loss.mean()
     if reduction == "sum":
@@ -49,51 +73,21 @@ def sigmoid_focal_loss(
     return loss
 
 
+# TODO: Mark as deprecated and emit warning
+sigmoid_focal_loss = focal_loss_with_logits
+
+
+# TODO: Mark as deprecated and emit warning
 def reduced_focal_loss(
     input: torch.Tensor,
     target: torch.Tensor,
     threshold=0.5,
     gamma=2.0,
     reduction="mean",
 ):
-    """Compute reduced focal loss between target and output logits.
-
-    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
-
-    Args:
-        input: Tensor of arbitrary shape
-        target: Tensor of the same shape as input
-        reduction (string, optional): Specifies the reduction to apply to the output:
-            'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied,
-            'mean': the sum of the output will be divided by the number of
-            elements in the output, 'sum': the output will be summed. Note: :attr:`size_average`
-            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
-            specifying either of those two args will override :attr:`reduction`.
-            'batchwise_mean' computes mean loss per sample in batch. Default: 'mean'
-
-    References::
-
-        https://arxiv.org/abs/1903.01347
-    """
-    target = target.type(input.type())
-
-    logpt = -F.binary_cross_entropy_with_logits(input, target, reduction="none")
-    pt = torch.exp(logpt)
-
-    # compute the loss
-    focal_reduction = ((1.0 - pt) / threshold).pow(gamma)
-    focal_reduction[pt < threshold] = 1
-
-    loss = -focal_reduction * logpt
-
-    if reduction == "mean":
-        loss = loss.mean()
-    if reduction == "sum":
-        loss = loss.sum()
-    if reduction == "batchwise_mean":
-        loss = loss.sum(0)
-
-    return loss
+    return focal_loss_with_logits(
+        input, target, alpha=None, gamma=gamma, reduction=reduction, threshold=threshold
+    )
 
 
 def soft_jaccard_score(
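
The unified implementation keeps the old behaviors reachable: with `threshold=None` the focal term is `(1 - pt) ** gamma`; with a `threshold` it becomes `((1 - pt) / threshold) ** gamma`, pinned to 1 for hard examples with `pt < threshold` (reduced focal loss); and `normalized=True` divides by the sum of focal terms (normalized focal loss). A sanity-check sketch against the deprecated wrapper, using illustrative tensors:

```python
import torch
from pytorch_toolbelt.losses.functional import (
    focal_loss_with_logits,
    reduced_focal_loss,
)

logits = torch.tensor([3.0, -3.0, 0.1])
targets = torch.tensor([1.0, 0.0, 1.0])

# The deprecated wrapper simply forwards to the new code path, so both
# calls must produce the same scalar.
old = reduced_focal_loss(logits, targets, threshold=0.5, gamma=2.0)
new = focal_loss_with_logits(logits, targets, alpha=None, gamma=2.0, threshold=0.5)
assert torch.allclose(old, new)

# Normalized focal loss; reduction="none" keeps the per-element values.
per_element = focal_loss_with_logits(
    logits, targets, normalized=True, reduction="none"
)
```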

pytorch_toolbelt/modules/encoders.py

Lines changed: 44 additions & 28 deletions
@@ -603,24 +603,34 @@ def __init__(
         strides: List[int],
         channels: List[int],
         layers: List[int],
+        first_avg_pool=False,
     ):
+        if layers is None:
+            layers = [1, 2, 3, 4]
+
         super().__init__(channels, strides, layers)
 
+        def except_pool(block: nn.Module):
+            del block.pool
+            return block
+
         self.layer0 = nn.Sequential(
             densenet.features.conv0, densenet.features.norm0, densenet.features.relu0
         )
-        self.pool0 = densenet.features.pool0
+
+        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
+        self.pool0 = self.avg_pool if first_avg_pool else densenet.features.pool0
 
         self.layer1 = nn.Sequential(
-            densenet.features.denseblock1, densenet.features.transition1
+            densenet.features.denseblock1, except_pool(densenet.features.transition1)
         )
 
         self.layer2 = nn.Sequential(
-            densenet.features.denseblock2, densenet.features.transition2
+            densenet.features.denseblock2, except_pool(densenet.features.transition2)
         )
 
         self.layer3 = nn.Sequential(
-            densenet.features.denseblock3, densenet.features.transition3
+            densenet.features.denseblock3, except_pool(densenet.features.transition3)
         )
 
         self.layer4 = nn.Sequential(densenet.features.denseblock4)
@@ -650,54 +660,60 @@ def forward(self, x):
            if layer == self.layer0:
                # Fist maxpool operator is not a part of layer0 because we want that layer0 output to have stride of 2
                output = self.pool0(output)
+            else:
+                output = self.avg_pool(output)
+
            input = output
 
        # Return only features that were requested
        return _take(output_features, self._layers)
 
 
 class DenseNet121Encoder(DenseNetEncoder):
-    def __init__(self, layers=None, pretrained=True, memory_efficient=False):
-        if layers is None:
-            layers = [1, 2, 3, 4]
+    def __init__(
+        self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False
+    ):
        densenet = densenet121(pretrained=pretrained, memory_efficient=memory_efficient)
        strides = [2, 4, 8, 16, 32]
        channels = [64, 128, 256, 512, 1024]
-        super().__init__(densenet, strides, channels, layers)
+        super().__init__(densenet, strides, channels, layers, first_avg_pool)
 
 
 class DenseNet161Encoder(DenseNetEncoder):
-    def __init__(self, layers=None, pretrained=True, memory_efficient=False):
-        if layers is None:
-            layers = [1, 2, 3, 4]
+    def __init__(
+        self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False
+    ):
        densenet = densenet161(pretrained=pretrained, memory_efficient=memory_efficient)
        strides = [2, 4, 8, 16, 32]
        channels = [96, 192, 384, 1056, 2208]
-        super().__init__(densenet, strides, channels, layers)
+        super().__init__(densenet, strides, channels, layers, first_avg_pool)
 
 
 class DenseNet169Encoder(DenseNetEncoder):
-    def __init__(self, layers=None, pretrained=True, memory_efficient=False):
-        if layers is None:
-            layers = [1, 2, 3, 4]
+    def __init__(
+        self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False
+    ):
        densenet = densenet169(pretrained=pretrained, memory_efficient=memory_efficient)
        strides = [2, 4, 8, 16, 32]
        channels = [64, 128, 256, 640, 1664]
-        super().__init__(densenet, strides, channels, layers)
+        super().__init__(densenet, strides, channels, layers, first_avg_pool)
 
 
 class DenseNet201Encoder(DenseNetEncoder):
-    def __init__(self, layers=None, pretrained=True, memory_efficient=False):
-        if layers is None:
-            layers = [1, 2, 3, 4]
+    def __init__(
+        self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False
+    ):
        densenet = densenet201(pretrained=pretrained, memory_efficient=memory_efficient)
        strides = [2, 4, 8, 16, 32]
        channels = [64, 128, 256, 896, 1920]
-        super().__init__(densenet, strides, channels, layers)
+        super().__init__(densenet, strides, channels, layers, first_avg_pool)
 
 
 class EfficientNetEncoder(EncoderModule):
     def __init__(self, efficientnet, filters, strides, layers):
+        if layers is None:
+            layers = [1, 2, 4, 6]
+
         super().__init__(filters, strides, layers)
 
         self.stem = efficientnet.stem
@@ -736,7 +752,7 @@ def forward(self, x):
 
 
 class EfficientNetB0Encoder(EfficientNetEncoder):
-    def __init__(self, layers=[1, 2, 4, 6], **kwargs):
+    def __init__(self, layers=None, **kwargs):
         super().__init__(
             efficient_net_b0(num_classes=1, **kwargs),
             [16, 24, 40, 80, 112, 192, 320],
@@ -746,7 +762,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
 
 
 class EfficientNetB1Encoder(EfficientNetEncoder):
-    def __init__(self, layers=[1, 2, 4, 6], **kwargs):
+    def __init__(self, layers=None, **kwargs):
         super().__init__(
             efficient_net_b1(num_classes=1, **kwargs),
             [16, 24, 40, 80, 112, 192, 320],
@@ -756,7 +772,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
 
 
 class EfficientNetB2Encoder(EfficientNetEncoder):
-    def __init__(self, layers=[1, 2, 4, 6], **kwargs):
+    def __init__(self, layers=None, **kwargs):
         super().__init__(
             efficient_net_b2(num_classes=1, **kwargs),
             [16, 24, 48, 88, 120, 208, 352],
@@ -766,7 +782,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
 
 
 class EfficientNetB3Encoder(EfficientNetEncoder):
-    def __init__(self, layers=[1, 2, 4, 6], **kwargs):
+    def __init__(self, layers=None, **kwargs):
         super().__init__(
             efficient_net_b3(num_classes=1, **kwargs),
             [24, 32, 48, 96, 136, 232, 384],
@@ -776,7 +792,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
 
 
 class EfficientNetB4Encoder(EfficientNetEncoder):
-    def __init__(self, layers=[1, 2, 4, 6], **kwargs):
+    def __init__(self, layers=None, **kwargs):
         super().__init__(
             efficient_net_b4(num_classes=1, **kwargs),
             [24, 32, 56, 112, 160, 272, 448],
@@ -786,7 +802,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
 
 
 class EfficientNetB5Encoder(EfficientNetEncoder):
-    def __init__(self, layers=[1, 2, 4, 6], **kwargs):
+    def __init__(self, layers=None, **kwargs):
         super().__init__(
             efficient_net_b5(num_classes=1, **kwargs),
             [24, 40, 64, 128, 176, 304, 512],
@@ -796,7 +812,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
 
 
 class EfficientNetB6Encoder(EfficientNetEncoder):
-    def __init__(self, layers=[1, 2, 4, 6], **kwargs):
+    def __init__(self, layers=None, **kwargs):
         super().__init__(
             efficient_net_b6(num_classes=1, **kwargs),
             [32, 40, 72, 144, 200, 344, 576],
@@ -806,7 +822,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
 
 
 class EfficientNetB7Encoder(EfficientNetEncoder):
-    def __init__(self, layers=[1, 2, 4, 6], **kwargs):
+    def __init__(self, layers=None, **kwargs):
         super().__init__(
             efficient_net_b7(num_classes=1, **kwargs),
             [32, 48, 80, 160, 224, 384, 640],
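
Two refactorings run through this file. First, the `layers=[1, 2, 4, 6]` defaults are replaced with `layers=None` plus an in-constructor default: a mutable default argument is created once at function definition time and shared across every call, so mutating it in one call silently leaks into the next. A standalone illustration of the hazard (not repository code):

```python
def bad(layers=[1, 2, 4, 6]):
    layers.append(0)        # mutates the single shared default list
    return layers

def good(layers=None):
    if layers is None:      # the pattern this commit adopts
        layers = [1, 2, 4, 6]
    layers.append(0)        # mutates a fresh per-call list
    return layers

assert bad() == [1, 2, 4, 6, 0]
assert bad() == [1, 2, 4, 6, 0, 0]      # the default grew between calls
assert good() == [1, 2, 4, 6, 0]
assert good() == [1, 2, 4, 6, 0]        # stable across calls
```

Second, the DenseNet encoders now strip the pooling step out of each transition (`except_pool`) and apply `self.avg_pool` inside `forward` instead, so each `layerN` output is captured at its pre-pool resolution; `first_avg_pool=True` additionally swaps the stem max-pool for average pooling.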

tests/test_losses.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@ def test_sigmoid_focal_loss():
     input_bad = torch.Tensor([-1, 2, 0]).float()
     target = torch.Tensor([1, 0, 1])
 
-    loss_good = F.sigmoid_focal_loss(input_good, target)
-    loss_bad = F.sigmoid_focal_loss(input_bad, target)
+    loss_good = F.focal_loss_with_logits(input_good, target)
+    loss_bad = F.focal_loss_with_logits(input_bad, target)
     assert loss_good < loss_bad
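
The test only needed a rename because functional.py aliases the old name to the new implementation; a one-line sanity check (a sketch, assuming pytorch-toolbelt 0.2.1 is installed):

```python
from pytorch_toolbelt.losses import functional as F

# The deprecated alias still resolves to the new implementation.
assert F.sigmoid_focal_loss is F.focal_loss_with_logits
```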
