diff --git a/atomai/nets/blocks.py b/atomai/nets/blocks.py index 91872369..6f325089 100644 --- a/atomai/nets/blocks.py +++ b/atomai/nets/blocks.py @@ -20,7 +20,7 @@ class ConvBlock(nn.Module): Args: ndim: - Data dimensionality (1D or 2D) + Data dimensionality (1D, 2D, or 3D) nb_layers: Number of layers in the block input_channels: @@ -53,9 +53,9 @@ def __init__(self, Initializes module parameters """ super(ConvBlock, self).__init__() - if not 0 < ndim < 3: - raise AssertionError("ndim must be equal to 1 or 2") - conv = nn.Conv2d if ndim == 2 else nn.Conv1d + if not 0 < ndim < 4: + raise AssertionError("ndim must be equal to 1, 2, or 3") + conv = get_conv(ndim) block = [] for idx in range(nb_layers): input_channels = output_channels if idx > 0 else input_channels @@ -68,10 +68,7 @@ def __init__(self, block.append(nn.Dropout(dropout_)) block.append(nn.LeakyReLU(negative_slope=lrelu_a)) if batch_norm: - if ndim == 2: - block.append(nn.BatchNorm2d(output_channels)) - else: - block.append(nn.BatchNorm1d(output_channels)) + block.append(get_BatchNorm(ndim, output_channels)) self.block = nn.Sequential(*block) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -80,6 +77,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ output = self.block(x) return output + class UpsampleBlock(nn.Module): @@ -90,7 +88,7 @@ class UpsampleBlock(nn.Module): Args: ndim: - Data dimensionality (1D or 2D) + Data dimensionality (1D, 2D, or 3D) input_channels: Number of input channels for the block output_channels: @@ -98,7 +96,7 @@ class UpsampleBlock(nn.Module): scale_factor: Scale factor for upsampling mode: - Upsampling mode. Select between "bilinear" and "nearest" + Upsampling mode. 
Select between "bilinear", "nearest", and "trilinear" for 3D """ def __init__(self, ndim: int, @@ -110,14 +108,14 @@ def __init__(self, Initializes module parameters """ super(UpsampleBlock, self).__init__() - if not any([mode == 'bilinear', mode == 'nearest']): + if not any([mode == 'bilinear', mode == 'nearest', mode == 'trilinear']): raise NotImplementedError( - "use 'bilinear' or 'nearest' for upsampling mode") - if not 0 < ndim < 3: - raise AssertionError("ndim must be equal to 1 or 2") - conv = nn.Conv2d if ndim == 2 else nn.Conv1d + "use 'trilinear', 'bilinear', or 'nearest' for upsampling mode") + if not 0 < ndim < 4: + raise AssertionError("ndim must be equal to 1, 2, or 3") + conv = get_conv(ndim) self.scale_factor = scale_factor - self.mode = mode if ndim == 2 else "nearest" + self.mode = mode if mode == 'nearest' else get_interpolate_mode(ndim) self.conv = conv( input_channels, output_channels, kernel_size=1, stride=1, padding=0) @@ -137,7 +135,7 @@ class ResBlock(nn.Module): Args: ndim: - Data dimensionality (1D or 2D) + Data dimensionality (1D, 2D, or 3D) nb_layers: Number of layers in the block input_channels: @@ -170,9 +168,9 @@ def __init__(self, Initializes block's parameters """ super(ResBlock, self).__init__() - if not 0 < ndim < 3: - raise AssertionError("ndim must be equal to 1 or 2") - conv = nn.Conv2d if ndim == 2 else nn.Conv1d + if not 0 < ndim < 4: + raise AssertionError("ndim must be equal to 1, 2, or 3") + conv = get_conv(ndim) self.lrelu_a = lrelu_a self.batch_norm = batch_norm self.c0 = conv(input_channels, @@ -191,9 +189,8 @@ def __init__(self, stride=1, padding=1) if batch_norm: - bn = nn.BatchNorm2d if ndim == 2 else nn.BatchNorm1d - self.bn1 = bn(output_channels) - self.bn2 = bn(output_channels) + self.bn1 = get_BatchNorm(ndim, output_channels) + self.bn2 = get_BatchNorm(ndim, output_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -218,7 +215,7 @@ class ResModule(nn.Module): Stitches multiple convolutional blocks with residual connections 
together Args: - ndim: Data dimensionality (1D or 2D) + ndim: Data dimensionality (1D, 2D, or 3D) res_depth: Number of residual blocks in a residual module input_channels: Number of filters in the input layer output_channels: Number of channels in the output layer @@ -260,7 +257,7 @@ class DilatedBlock(nn.Module): Args: ndim: - Data dimensionality (1D or 2D) + Data dimensionality (1D, 2D, or 3D) input_channels: Number of input channels for the block output_channels: @@ -268,7 +265,7 @@ class DilatedBlock(nn.Module): dilation_values: List of dilation rates for each convolution layer in the block (for example, dilation_values = [2, 4, 6] means that the dilated - block will 3 layers with dilation values of 2, 4, and 6). + block will have 3 layers with dilation values of 2, 4, and 6). padding_values: Edge padding for each dilated layer. The number of elements in this list should be equal to that in the dilated_values list and @@ -294,9 +291,9 @@ def __init__(self, ndim: int, input_channels: int, output_channels: int, Initializes module parameters """ super(DilatedBlock, self).__init__() - if not 0 < ndim < 3: - raise AssertionError("ndim must be equal to 1 or 2") - conv = nn.Conv2d if ndim == 2 else nn.Conv1d + if not 0 < ndim < 4: + raise AssertionError("ndim must be equal to 1, 2, or 3") + conv = get_conv(ndim) atrous_module = [] for idx, (dil, pad) in enumerate(zip(dilation_values, padding_values)): input_channels = output_channels if idx > 0 else input_channels @@ -311,10 +308,7 @@ def __init__(self, ndim: int, input_channels: int, output_channels: int, atrous_module.append(nn.Dropout(dropout_)) atrous_module.append(nn.LeakyReLU(negative_slope=lrelu_a)) if batch_norm: - if ndim == 2: - atrous_module.append(nn.BatchNorm2d(output_channels)) - else: - atrous_module.append(nn.BatchNorm1d(output_channels)) + atrous_module.append(get_BatchNorm(ndim, output_channels)) self.atrous_module = nn.Sequential(*atrous_module) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ 
-326,3 +320,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = conv_layer(x) atrous_layers.append(x.unsqueeze(-1)) return torch.sum(torch.cat(atrous_layers, dim=-1), dim=-1) + + +def get_conv(ndim: int): + """ + Selects conv block based on dimensions + """ + conv_dict = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d} + return conv_dict[ndim] + + +def get_BatchNorm(ndim: int, output_channels: int): + """ + Selects BatchNorm block based on dimensions + """ + BatchNorm_dict = {1: nn.BatchNorm1d(output_channels), + 2: nn.BatchNorm2d(output_channels), + 3: nn.BatchNorm3d(output_channels)} + return BatchNorm_dict[ndim] + + +def get_interpolate_mode(ndim: int): + """ + Selects interpolation mode based on dimensions + """ + interpolate_mode_dict = {1: 'nearest', 2: 'bilinear', 3: 'trilinear'} + return interpolate_mode_dict[ndim] \ No newline at end of file diff --git a/atomai/nets/fcnn.py b/atomai/nets/fcnn.py index 107416f9..b2eff73c 100644 --- a/atomai/nets/fcnn.py +++ b/atomai/nets/fcnn.py @@ -32,10 +32,11 @@ class Unet(nn.Module): Use batch normalization after each convolutional layer (Default: True) upsampling_mode: - Select between "bilinear" or "nearest" upsampling method. - Bilinear is usually more accurate,but adds additional (small) - randomness. For full reproducibility, consider using 'nearest' - (this assumes that all other sources of randomness are fixed) + Select between "bilinear", "nearest", or "trilinear" upsampling + method. Bilinear is usually more accurate, but adds additional + (small) randomness. 
Trilinear is used for 3D data. For full + reproducibility, consider using 'nearest' (this assumes that all + other sources of randomness are fixed) with_dilation: Use dilated convolutions instead of regular ones in the bottleneck layers (Default: False) @@ -158,10 +159,11 @@ class dilnet(nn.Module): batch_norm: Add batch normalization for each convolutional layer (Default: True) upsampling_mode: - Select between "bilinear" or "nearest" upsampling method. - Bilinear is usually more accurate,but adds additional (small) - randomness. For full reproducibility, consider using 'nearest' - (this assumes that all other sources of randomness are fixed) + Select between "bilinear", "nearest", or "trilinear" upsampling + method. Bilinear is usually more accurate, but adds additional + (small) randomness. Trilinear is used for 3D data. For full + reproducibility, consider using 'nearest' (this assumes that all + other sources of randomness are fixed) **layers (list): List with a number of layers for each block (Default: [3, 3, 3, 3]) """ @@ -237,10 +239,11 @@ class ResHedNet(nn.Module): Number of filters in 1st residual block (gets multiplied by 2 in each next block) upsampling_mode: - Select between "bilinear" or "nearest" upsampling method. - Bilinear is usually more accurate,but adds additional (small) - randomness. For full reproducibility, consider using 'nearest' - (this assumes that all other sources of randomness are fixed) + Select between "bilinear", "nearest", or "trilinear" upsampling + method. Bilinear is usually more accurate, but adds additional + (small) randomness. 
Trilinear is used for 3D data. For full + reproducibility, consider using 'nearest' (this assumes that all + other sources of randomness are fixed) **layers (list): 3-element list with a number of residual blocks in each segment (Default: [3, 4, 5]) @@ -311,10 +314,11 @@ class SegResNet(nn.Module): Use batch normalization after each convolutional layer (Default: True) upsampling_mode: - Select between "bilinear" or "nearest" upsampling method. - Bilinear is usually more accurate,but adds additional (small) - randomness. For full reproducibility, consider using 'nearest' - (this assumes that all other sources of randomness are fixed) + Select between "bilinear", "nearest", or "trilinear" upsampling + method. Bilinear is usually more accurate, but adds additional + (small) randomness. Trilinear is used for 3D data. For full + reproducibility, consider using 'nearest' (this assumes that all + other sources of randomness are fixed) **layers (list): 3-element list with a number of residual blocks in each residual segment (Default: [2, 2]) @@ -377,7 +381,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def init_fcnn_model(model: Union[Type[nn.Module], str], - nb_classes: int, **kwargs: [bool, int, List] + nb_classes: int, **kwargs: Union[bool, int, List] ) -> Type[nn.Module]: """ Initializes a fully convolutional neural network