@@ -603,24 +603,34 @@ def __init__(
603603 strides : List [int ],
604604 channels : List [int ],
605605 layers : List [int ],
606+ first_avg_pool = False ,
606607 ):
608+ if layers is None :
609+ layers = [1 , 2 , 3 , 4 ]
610+
607611 super ().__init__ (channels , strides , layers )
608612
613+ def except_pool (block : nn .Module ):
614+ del block .pool
615+ return block
616+
609617 self .layer0 = nn .Sequential (
610618 densenet .features .conv0 , densenet .features .norm0 , densenet .features .relu0
611619 )
612- self .pool0 = densenet .features .pool0
620+
621+ self .avg_pool = nn .AvgPool2d (kernel_size = 2 , stride = 2 )
622+ self .pool0 = self .avg_pool if first_avg_pool else densenet .features .pool0
613623
614624 self .layer1 = nn .Sequential (
615- densenet .features .denseblock1 , densenet .features .transition1
625+ densenet .features .denseblock1 , except_pool ( densenet .features .transition1 )
616626 )
617627
618628 self .layer2 = nn .Sequential (
619- densenet .features .denseblock2 , densenet .features .transition2
629+ densenet .features .denseblock2 , except_pool ( densenet .features .transition2 )
620630 )
621631
622632 self .layer3 = nn .Sequential (
623- densenet .features .denseblock3 , densenet .features .transition3
633+ densenet .features .denseblock3 , except_pool ( densenet .features .transition3 )
624634 )
625635
626636 self .layer4 = nn .Sequential (densenet .features .denseblock4 )
@@ -650,54 +660,60 @@ def forward(self, x):
650660 if layer == self .layer0 :
651661 # Fist maxpool operator is not a part of layer0 because we want that layer0 output to have stride of 2
652662 output = self .pool0 (output )
663+ else :
664+ output = self .avg_pool (output )
665+
653666 input = output
654667
655668 # Return only features that were requested
656669 return _take (output_features , self ._layers )
657670
658671
659672class DenseNet121Encoder (DenseNetEncoder ):
660- def __init__ (self , layers = None , pretrained = True , memory_efficient = False ):
661- if layers is None :
662- layers = [ 1 , 2 , 3 , 4 ]
673+ def __init__ (
674+ self , layers = None , pretrained = True , memory_efficient = False , first_avg_pool = False
675+ ):
663676 densenet = densenet121 (pretrained = pretrained , memory_efficient = memory_efficient )
664677 strides = [2 , 4 , 8 , 16 , 32 ]
665678 channels = [64 , 128 , 256 , 512 , 1024 ]
666- super ().__init__ (densenet , strides , channels , layers )
679+ super ().__init__ (densenet , strides , channels , layers , first_avg_pool )
667680
668681
669682class DenseNet161Encoder (DenseNetEncoder ):
670- def __init__ (self , layers = None , pretrained = True , memory_efficient = False ):
671- if layers is None :
672- layers = [ 1 , 2 , 3 , 4 ]
683+ def __init__ (
684+ self , layers = None , pretrained = True , memory_efficient = False , first_avg_pool = False
685+ ):
673686 densenet = densenet161 (pretrained = pretrained , memory_efficient = memory_efficient )
674687 strides = [2 , 4 , 8 , 16 , 32 ]
675688 channels = [96 , 192 , 384 , 1056 , 2208 ]
676- super ().__init__ (densenet , strides , channels , layers )
689+ super ().__init__ (densenet , strides , channels , layers , first_avg_pool )
677690
678691
679692class DenseNet169Encoder (DenseNetEncoder ):
680- def __init__ (self , layers = None , pretrained = True , memory_efficient = False ):
681- if layers is None :
682- layers = [ 1 , 2 , 3 , 4 ]
693+ def __init__ (
694+ self , layers = None , pretrained = True , memory_efficient = False , first_avg_pool = False
695+ ):
683696 densenet = densenet169 (pretrained = pretrained , memory_efficient = memory_efficient )
684697 strides = [2 , 4 , 8 , 16 , 32 ]
685698 channels = [64 , 128 , 256 , 640 , 1664 ]
686- super ().__init__ (densenet , strides , channels , layers )
699+ super ().__init__ (densenet , strides , channels , layers , first_avg_pool )
687700
688701
689702class DenseNet201Encoder (DenseNetEncoder ):
690- def __init__ (self , layers = None , pretrained = True , memory_efficient = False ):
691- if layers is None :
692- layers = [ 1 , 2 , 3 , 4 ]
703+ def __init__ (
704+ self , layers = None , pretrained = True , memory_efficient = False , first_avg_pool = False
705+ ):
693706 densenet = densenet201 (pretrained = pretrained , memory_efficient = memory_efficient )
694707 strides = [2 , 4 , 8 , 16 , 32 ]
695708 channels = [64 , 128 , 256 , 896 , 1920 ]
696- super ().__init__ (densenet , strides , channels , layers )
709+ super ().__init__ (densenet , strides , channels , layers , first_avg_pool )
697710
698711
699712class EfficientNetEncoder (EncoderModule ):
700713 def __init__ (self , efficientnet , filters , strides , layers ):
714+ if layers is None :
715+ layers = [1 , 2 , 4 , 6 ]
716+
701717 super ().__init__ (filters , strides , layers )
702718
703719 self .stem = efficientnet .stem
@@ -736,7 +752,7 @@ def forward(self, x):
736752
737753
738754class EfficientNetB0Encoder (EfficientNetEncoder ):
739- def __init__ (self , layers = [ 1 , 2 , 4 , 6 ] , ** kwargs ):
755+ def __init__ (self , layers = None , ** kwargs ):
740756 super ().__init__ (
741757 efficient_net_b0 (num_classes = 1 , ** kwargs ),
742758 [16 , 24 , 40 , 80 , 112 , 192 , 320 ],
@@ -746,7 +762,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
746762
747763
748764class EfficientNetB1Encoder (EfficientNetEncoder ):
749- def __init__ (self , layers = [ 1 , 2 , 4 , 6 ] , ** kwargs ):
765+ def __init__ (self , layers = None , ** kwargs ):
750766 super ().__init__ (
751767 efficient_net_b1 (num_classes = 1 , ** kwargs ),
752768 [16 , 24 , 40 , 80 , 112 , 192 , 320 ],
@@ -756,7 +772,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
756772
757773
758774class EfficientNetB2Encoder (EfficientNetEncoder ):
759- def __init__ (self , layers = [ 1 , 2 , 4 , 6 ] , ** kwargs ):
775+ def __init__ (self , layers = None , ** kwargs ):
760776 super ().__init__ (
761777 efficient_net_b2 (num_classes = 1 , ** kwargs ),
762778 [16 , 24 , 48 , 88 , 120 , 208 , 352 ],
@@ -766,7 +782,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
766782
767783
768784class EfficientNetB3Encoder (EfficientNetEncoder ):
769- def __init__ (self , layers = [ 1 , 2 , 4 , 6 ] , ** kwargs ):
785+ def __init__ (self , layers = None , ** kwargs ):
770786 super ().__init__ (
771787 efficient_net_b3 (num_classes = 1 , ** kwargs ),
772788 [24 , 32 , 48 , 96 , 136 , 232 , 384 ],
@@ -776,7 +792,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
776792
777793
778794class EfficientNetB4Encoder (EfficientNetEncoder ):
779- def __init__ (self , layers = [ 1 , 2 , 4 , 6 ] , ** kwargs ):
795+ def __init__ (self , layers = None , ** kwargs ):
780796 super ().__init__ (
781797 efficient_net_b4 (num_classes = 1 , ** kwargs ),
782798 [24 , 32 , 56 , 112 , 160 , 272 , 448 ],
@@ -786,7 +802,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
786802
787803
788804class EfficientNetB5Encoder (EfficientNetEncoder ):
789- def __init__ (self , layers = [ 1 , 2 , 4 , 6 ] , ** kwargs ):
805+ def __init__ (self , layers = None , ** kwargs ):
790806 super ().__init__ (
791807 efficient_net_b5 (num_classes = 1 , ** kwargs ),
792808 [24 , 40 , 64 , 128 , 176 , 304 , 512 ],
@@ -796,7 +812,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
796812
797813
798814class EfficientNetB6Encoder (EfficientNetEncoder ):
799- def __init__ (self , layers = [ 1 , 2 , 4 , 6 ] , ** kwargs ):
815+ def __init__ (self , layers = None , ** kwargs ):
800816 super ().__init__ (
801817 efficient_net_b6 (num_classes = 1 , ** kwargs ),
802818 [32 , 40 , 72 , 144 , 200 , 344 , 576 ],
@@ -806,7 +822,7 @@ def __init__(self, layers=[1, 2, 4, 6], **kwargs):
806822
807823
808824class EfficientNetB7Encoder (EfficientNetEncoder ):
809- def __init__ (self , layers = [ 1 , 2 , 4 , 6 ] , ** kwargs ):
825+ def __init__ (self , layers = None , ** kwargs ):
810826 super ().__init__ (
811827 efficient_net_b7 (num_classes = 1 , ** kwargs ),
812828 [32 , 48 , 80 , 160 , 224 , 384 , 640 ],
0 commit comments