hail.network
import math

import torch
from torch import nn
import torch.nn.functional as F


class UNet(nn.Module):
    """U-Net encoder/decoder with optional channel-wise conditioning.

    Each of the ``num_lvs`` encoder levels is a ``ConvBlock2d`` followed by
    2x2 max-pooling; the decoder mirrors it with ``Upsample`` + ``ConvBlock2d``
    and skip connections from the matching encoder level.

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of the output.
        conditional_ch: Channels of the optional ``condition`` tensor that is
            concatenated to the input of every encoder block (0 = unused).
        num_lvs: Number of down/up-sampling levels.
        base_ch: Width produced by the input convolution.
        final_act: 'sigmoid', 'relu' or 'tanh'; any other value (default
            'noact') leaves the output unactivated.
    """

    def __init__(self, in_ch, out_ch, conditional_ch=0, num_lvs=4, base_ch=16, final_act='noact'):
        super().__init__()
        self.final_act = final_act
        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, 1, 1)

        self.down_convs = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        self.up_convs = nn.ModuleList()
        for lv in range(num_lvs):
            ch = base_ch * (2 ** lv)
            self.down_convs.append(ConvBlock2d(ch + conditional_ch, ch * 2, ch * 2))
            self.down_samples.append(nn.MaxPool2d(kernel_size=2, stride=2))
            self.up_samples.append(Upsample(ch * 4))
            self.up_convs.append(ConvBlock2d(ch * 4, ch * 2, ch * 2))
        bottleneck_ch = base_ch * (2 ** num_lvs)
        self.bottleneck_conv = ConvBlock2d(bottleneck_ch, bottleneck_ch * 2, bottleneck_ch * 2)
        self.out_conv = nn.Sequential(nn.Conv2d(base_ch * 2, base_ch, 3, 1, 1),
                                      nn.LeakyReLU(0.1),
                                      nn.Conv2d(base_ch, out_ch, 3, 1, 1))

    def forward(self, in_tensor, condition=None):
        """Run the U-Net.

        Args:
            in_tensor: (batch, in_ch, H, W) input.
            condition: Optional tensor concatenated along channels to the input
                of every encoder block. Its two spatial dims must be 1 (or equal
                to the feature map's), e.g. (batch, conditional_ch, 1, 1); it is
                broadcast over each level's (H, W).

        Returns:
            (batch, out_ch, H, W) tensor after ``final_act``.
        """
        encoded_features = []
        x = self.in_conv(in_tensor)
        for down_conv, down_sample in zip(self.down_convs, self.down_samples):
            if condition is not None:
                # Broadcast over both H and W so non-square feature maps work.
                cond_map = condition.expand(-1, -1, x.shape[-2], x.shape[-1])
                down_conv_out = down_conv(torch.cat([x, cond_map], dim=1))
            else:
                down_conv_out = down_conv(x)
            x = down_sample(down_conv_out)
            encoded_features.append(down_conv_out)
        x = self.bottleneck_conv(x)
        # Decode from the deepest level upward, pairing each step with its skip.
        for encoded_feature, up_conv, up_sample in zip(reversed(encoded_features),
                                                       reversed(self.up_convs),
                                                       reversed(self.up_samples)):
            x = up_sample(x, encoded_feature)
            x = up_conv(x)
        x = self.out_conv(x)
        if self.final_act == 'sigmoid':
            x = torch.sigmoid(x)
        elif self.final_act == 'relu':
            x = torch.relu(x)
        elif self.final_act == 'tanh':
            x = torch.tanh(x)
        # Any other value (e.g. the default 'noact') leaves x unactivated.
        return x


class ConvBlock2d(nn.Module):
    """Two stacked 3x3 conv -> InstanceNorm -> LeakyReLU(0.1) stages.

    Padding 1 keeps the spatial size; channels go in_ch -> mid_ch -> out_ch.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, 3, 1, 1),
            nn.InstanceNorm2d(mid_ch),
            nn.LeakyReLU(0.1),
            nn.Conv2d(mid_ch, out_ch, 3, 1, 1),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(0.1)
        )

    def forward(self, in_tensor):
        return self.conv(in_tensor)


class Upsample(nn.Module):
    """Decoder step: bilinear upsample, 3x3 conv halving channels, concat skip.

    Args:
        in_ch: Channels of the tensor being upsampled; the conv outputs
            ``in_ch // 2`` channels.
    """

    def __init__(self, in_ch):
        super().__init__()
        out_ch = in_ch // 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(0.1)
        )

    def forward(self, in_tensor, encoded_feature):
        """Upsample ``in_tensor`` to ``encoded_feature``'s size and concat.

        Returns ``cat([encoded_feature, conv(upsampled)], dim=1)``.
        """
        # Resizing to the skip's exact (H, W) equals scale_factor=2 for even
        # sizes and avoids a torch.cat mismatch when MaxPool2d floored an odd size.
        up_sampled_tensor = F.interpolate(in_tensor, size=tuple(encoded_feature.shape[-2:]),
                                          mode='bilinear', align_corners=False)
        up_sampled_tensor = self.conv(up_sampled_tensor)
        return torch.cat([encoded_feature, up_sampled_tensor], dim=1)


class EtaEncoder(nn.Module):
    """Encode an image into a low-resolution ``out_ch``-channel code.

    Full-resolution features are concatenated with the raw input, then reduced
    by non-overlapping 32x32 and 7x7 strided convolutions; a 224x224 input
    becomes 7x7 and then 1x1.
    """

    def __init__(self, in_ch=1, out_ch=2):
        super().__init__()
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_ch, 16, 5, 1, 2),  # (*, 16, H, W)
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 64, 3, 1, 1),  # (*, 64, H, W)
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.1)
        )
        self.seq = nn.Sequential(
            nn.Conv2d(64 + in_ch, 32, 32, 32, 0),  # (*, 32, 7, 7) for a 224x224 input
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, out_ch, 7, 7, 0))

    def forward(self, x):
        return self.seq(torch.cat([self.in_conv(x), x], dim=1))


class Patchifier(nn.Module):
    """Map an image to a grid of per-patch outputs (32x32 non-overlapping patches)."""

    def __init__(self, in_ch, out_ch=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 64, 32, 32, 0),  # (*, in_ch, 224, 224) --> (*, 64, 7, 7)
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, out_ch, 1, 1, 0))

    def forward(self, x):
        return self.conv(x)


class ThetaEncoder(nn.Module):
    """Encode an image into mean and log-variance maps of ``out_ch`` channels.

    A shared strided-conv trunk feeds two identically shaped heads. For a
    224x224 input the trunk yields 24x24 -> 12x12 -> 6x6 maps (28 -> 14 -> 7
    for 256x256); each head reduces that to 1x1.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 32, 17, 9, 4),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),  # (*, 32, 24, 24) for 224 input; 28x28 for 256
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.1),  # (*, 64, 12, 12); 14x14 for 256
            nn.Conv2d(64, 64, 4, 2, 1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.1))  # (*, 64, 6, 6); 7x7 for 256
        self.mean_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, out_ch, 6, 6, 0))
        self.logvar_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, out_ch, 6, 6, 0))

    def forward(self, x):
        """Return ``(mu, logvar)`` computed from the shared trunk features."""
        M = self.conv(x)
        mu = self.mean_conv(M)
        logvar = self.logvar_conv(M)
        return mu, logvar


class AttentionModule(nn.Module):
    """Contrast-wise attention that fuses per-contrast values.

    Queries and keys are projected by two separate MLPs (dim -> 128 -> 16,
    LayerNorm). The scores are softmax-normalised over contrasts and used to
    average the values across contrasts at every pixel.

    Args:
        dim: Feature dimension of both queries and keys.
        v_ch: Number of value channels.
    """

    def __init__(self, dim, v_ch=5):
        super().__init__()
        self.dim = dim
        self.v_ch = v_ch
        self.q_fc = nn.Sequential(
            nn.Linear(dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 16),
            nn.LayerNorm(16))
        self.k_fc = nn.Sequential(
            nn.Linear(dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 16),
            nn.LayerNorm(16))

        # NOTE(review): scaled by the input dim, not the 16-d projected dim;
        # kept as-is to preserve trained behaviour.
        self.scale = self.dim ** (-0.5)

    def forward(self, q, k, v, modality_dropout=None, temperature=10.0):
        r"""
        Attention module for optimal anatomy fusion.

        ===INPUTS===
        * q: torch.Tensor (batch_size, feature_dim_q, num_q_patches=1)
            Query variable. In HAIL, query is the concatenation of target \theta and target \eta.
            num_q_patches may be a square number (an s x s patch grid); each patch's scores
            are then upsampled to cover their own block of pixels.
        * k: torch.Tensor (batch_size, feature_dim_k, num_k_patches=1, num_contrasts=4)
            Key variable. In HAIL, keys are \theta and \eta's of source images.
        * v: torch.Tensor (batch_size, self.v_ch=5, num_v_patches=224*224, num_contrasts=4)
            Value variable. In HAIL, values are multi-channel logits of source images.
            self.v_ch is the number of \beta channels.
        * modality_dropout: torch.Tensor (batch_size, num_contrasts=4)
            Indicates which contrast indexes have been dropped out. 1: if dropped out, 0: if exists.
        * temperature: float
            Softmax temperature; larger values give flatter attention.

        ===OUTPUTS===
        * v: torch.Tensor (batch_size, self.v_ch, image_dim, image_dim)
            Attention-weighted combination of values over contrasts.
        * attention: torch.Tensor (batch_size, num_contrasts, image_dim, image_dim)
            Softmax weights over contrasts.

        ===RAISES===
        * ValueError if the feature dim of q or k differs from self.dim.
        """
        batch_size, feature_dim_q, num_q_patches = q.shape
        _, feature_dim_k, _, num_contrasts = k.shape
        num_v_patches = v.shape[2]
        if feature_dim_q != self.dim or feature_dim_k != self.dim:
            raise ValueError(
                f'Feature dimensions do not match: q has {feature_dim_q}, '
                f'k has {feature_dim_k}, expected {self.dim}.')

        # q.shape: (batch_size, num_q_patches, 1, feature_dim_q)
        q = q.reshape(batch_size, feature_dim_q, num_q_patches, 1).permute(0, 2, 3, 1)
        # k.shape: (batch_size, num_k_patches, num_contrasts, feature_dim_k)
        k = k.permute(0, 2, 3, 1)
        # v.shape: (batch_size, num_v_patches, num_contrasts, v_ch)
        v = v.permute(0, 2, 3, 1)
        q = self.q_fc(q)
        # k.shape: (batch_size, num_k_patches, 16, num_contrasts)
        k = self.k_fc(k).permute(0, 1, 3, 2)

        # dot_prod.shape: (batch_size, num_q_patches, 1, num_contrasts)
        dot_prod = (q @ k) * self.scale
        interpolation_factor = int(math.sqrt(num_v_patches // num_q_patches))

        q_spatial_dim = int(math.sqrt(num_q_patches))
        dot_prod = dot_prod.view(batch_size, q_spatial_dim, q_spatial_dim, num_contrasts)

        image_dim = int(math.sqrt(num_v_patches))
        # Nearest-neighbour upsampling: every patch score fills its own
        # interpolation_factor x interpolation_factor pixel block (repeat()
        # would tile the whole grid instead when num_q_patches > 1).
        dot_prod_interp = (dot_prod.repeat_interleave(interpolation_factor, dim=1)
                           .repeat_interleave(interpolation_factor, dim=2))
        if modality_dropout is not None:
            # (batch_size, 1, 1, num_contrasts): broadcasts over every pixel.
            dropout_mask = modality_dropout.view(batch_size, 1, 1, num_contrasts).detach()
            dot_prod_interp = dot_prod_interp - dropout_mask * 1e5

        attention = (dot_prod_interp / temperature).softmax(dim=-1)
        v = attention.view(batch_size, num_v_patches, 1, num_contrasts) @ v
        v = v.view(batch_size, image_dim, image_dim, self.v_ch).permute(0, 3, 1, 2)
        attention = attention.view(batch_size, image_dim, image_dim, num_contrasts).permute(0, 3, 1, 2)
        return v, attention
class UNet(nn.Module):
    """U-Net encoder/decoder with optional channel-wise conditioning.

    Each of the ``num_lvs`` encoder levels is a ``ConvBlock2d`` followed by
    2x2 max-pooling; the decoder mirrors it with ``Upsample`` + ``ConvBlock2d``
    and skip connections from the matching encoder level.

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of the output.
        conditional_ch: Channels of the optional ``condition`` tensor that is
            concatenated to the input of every encoder block (0 = unused).
        num_lvs: Number of down/up-sampling levels.
        base_ch: Width produced by the input convolution.
        final_act: 'sigmoid', 'relu' or 'tanh'; any other value (default
            'noact') leaves the output unactivated.
    """

    def __init__(self, in_ch, out_ch, conditional_ch=0, num_lvs=4, base_ch=16, final_act='noact'):
        super().__init__()
        self.final_act = final_act
        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, 1, 1)

        self.down_convs = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        self.up_convs = nn.ModuleList()
        for lv in range(num_lvs):
            ch = base_ch * (2 ** lv)
            self.down_convs.append(ConvBlock2d(ch + conditional_ch, ch * 2, ch * 2))
            self.down_samples.append(nn.MaxPool2d(kernel_size=2, stride=2))
            self.up_samples.append(Upsample(ch * 4))
            self.up_convs.append(ConvBlock2d(ch * 4, ch * 2, ch * 2))
        bottleneck_ch = base_ch * (2 ** num_lvs)
        self.bottleneck_conv = ConvBlock2d(bottleneck_ch, bottleneck_ch * 2, bottleneck_ch * 2)
        self.out_conv = nn.Sequential(nn.Conv2d(base_ch * 2, base_ch, 3, 1, 1),
                                      nn.LeakyReLU(0.1),
                                      nn.Conv2d(base_ch, out_ch, 3, 1, 1))

    def forward(self, in_tensor, condition=None):
        """Run the U-Net.

        Args:
            in_tensor: (batch, in_ch, H, W) input.
            condition: Optional tensor concatenated along channels to the input
                of every encoder block. Its two spatial dims must be 1 (or equal
                to the feature map's), e.g. (batch, conditional_ch, 1, 1); it is
                broadcast over each level's (H, W).

        Returns:
            Output of ``out_conv`` after ``final_act``.
        """
        encoded_features = []
        x = self.in_conv(in_tensor)
        for down_conv, down_sample in zip(self.down_convs, self.down_samples):
            if condition is not None:
                # Broadcast over both H and W so non-square feature maps work.
                cond_map = condition.expand(-1, -1, x.shape[-2], x.shape[-1])
                down_conv_out = down_conv(torch.cat([x, cond_map], dim=1))
            else:
                down_conv_out = down_conv(x)
            x = down_sample(down_conv_out)
            encoded_features.append(down_conv_out)
        x = self.bottleneck_conv(x)
        # Decode from the deepest level upward, pairing each step with its skip.
        for encoded_feature, up_conv, up_sample in zip(reversed(encoded_features),
                                                       reversed(self.up_convs),
                                                       reversed(self.up_samples)):
            x = up_sample(x, encoded_feature)
            x = up_conv(x)
        x = self.out_conv(x)
        if self.final_act == 'sigmoid':
            x = torch.sigmoid(x)
        elif self.final_act == 'relu':
            x = torch.relu(x)
        elif self.final_act == 'tanh':
            x = torch.tanh(x)
        # Any other value (e.g. the default 'noact') leaves x unactivated.
        return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, in_ch, out_ch, conditional_ch=0, num_lvs=4, base_ch=16, final_act='noact'):
    """Build the U-Net's input conv, encoder/decoder levels, bottleneck and head.

    Level ``lv`` operates at width ``base_ch * 2**lv``; every encoder block
    additionally receives ``conditional_ch`` conditioning channels.
    """
    super().__init__()
    self.final_act = final_act
    self.in_conv = nn.Conv2d(in_ch, base_ch, 3, 1, 1)

    self.down_convs = nn.ModuleList()
    self.down_samples = nn.ModuleList()
    self.up_samples = nn.ModuleList()
    self.up_convs = nn.ModuleList()

    widths = [base_ch * 2 ** lv for lv in range(num_lvs)]
    for width in widths:
        doubled = 2 * width
        self.down_convs.append(ConvBlock2d(width + conditional_ch, doubled, doubled))
        self.down_samples.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.up_samples.append(Upsample(2 * doubled))
        self.up_convs.append(ConvBlock2d(2 * doubled, doubled, doubled))

    deepest = base_ch * 2 ** num_lvs
    self.bottleneck_conv = ConvBlock2d(deepest, 2 * deepest, 2 * deepest)
    self.out_conv = nn.Sequential(
        nn.Conv2d(2 * base_ch, base_ch, 3, 1, 1),
        nn.LeakyReLU(0.1),
        nn.Conv2d(base_ch, out_ch, 3, 1, 1),
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, in_tensor, condition=None):
    """Run the U-Net.

    Args:
        in_tensor: Input tensor passed to ``self.in_conv``.
        condition: Optional tensor concatenated along channels to the input of
            every encoder block. Its two spatial dims must be 1 (or equal to the
            feature map's), e.g. (batch, conditional_ch, 1, 1); it is broadcast
            over each level's (H, W).

    Returns:
        Output of ``self.out_conv`` after ``self.final_act``.
    """
    encoded_features = []
    x = self.in_conv(in_tensor)
    for down_conv, down_sample in zip(self.down_convs, self.down_samples):
        if condition is not None:
            # Broadcast over both H and W so non-square feature maps work.
            cond_map = condition.expand(-1, -1, x.shape[-2], x.shape[-1])
            down_conv_out = down_conv(torch.cat([x, cond_map], dim=1))
        else:
            down_conv_out = down_conv(x)
        x = down_sample(down_conv_out)
        encoded_features.append(down_conv_out)
    x = self.bottleneck_conv(x)
    # Decode from the deepest level upward, pairing each step with its skip.
    for encoded_feature, up_conv, up_sample in zip(reversed(encoded_features),
                                                   reversed(self.up_convs),
                                                   reversed(self.up_samples)):
        x = up_sample(x, encoded_feature)
        x = up_conv(x)
    x = self.out_conv(x)
    if self.final_act == 'sigmoid':
        x = torch.sigmoid(x)
    elif self.final_act == 'relu':
        x = torch.relu(x)
    elif self.final_act == 'tanh':
        x = torch.tanh(x)
    # Any other value (e.g. the default 'noact') leaves x unactivated.
    return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class ConvBlock2d(nn.Module):
    """Size-preserving double convolution block.

    Applies two (3x3 conv, padding 1) -> InstanceNorm2d -> LeakyReLU(0.1)
    stages, mapping channels in_ch -> mid_ch -> out_ch.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        layers = []
        for c_in, c_out in ((in_ch, mid_ch), (mid_ch, out_ch)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(negative_slope=0.1),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, in_tensor):
        block = self.conv
        return block(in_tensor)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, in_ch, mid_ch, out_ch):
    """Stack two 3x3 conv -> InstanceNorm -> LeakyReLU(0.1) stages (in -> mid -> out)."""
    super().__init__()
    layers = []
    for c_in, c_out in ((in_ch, mid_ch), (mid_ch, out_ch)):
        layers += [
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(c_out),
            nn.LeakyReLU(negative_slope=0.1),
        ]
    self.conv = nn.Sequential(*layers)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Upsample(nn.Module):
    """Decoder step: bilinear upsample, 3x3 conv halving channels, concat skip.

    Args:
        in_ch: Channels of the tensor being upsampled; the conv outputs
            ``in_ch // 2`` channels.
    """

    def __init__(self, in_ch):
        super().__init__()
        out_ch = in_ch // 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(0.1)
        )

    def forward(self, in_tensor, encoded_feature):
        """Upsample ``in_tensor`` to ``encoded_feature``'s size and concat.

        Returns ``cat([encoded_feature, conv(upsampled)], dim=1)``.
        """
        # Resizing to the skip's exact (H, W) equals scale_factor=2 for even
        # sizes and avoids a torch.cat mismatch when MaxPool2d floored an odd size.
        up_sampled_tensor = F.interpolate(in_tensor, size=tuple(encoded_feature.shape[-2:]),
                                          mode='bilinear', align_corners=False)
        up_sampled_tensor = self.conv(up_sampled_tensor)
        return torch.cat([encoded_feature, up_sampled_tensor], dim=1)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, in_ch):
    """Create the conv stage that maps ``in_ch`` to ``in_ch // 2`` channels after upsampling."""
    super().__init__()
    half = in_ch // 2
    stages = (
        nn.Conv2d(in_ch, half, kernel_size=3, stride=1, padding=1),
        nn.InstanceNorm2d(half),
        nn.LeakyReLU(0.1),
    )
    self.conv = nn.Sequential(*stages)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, in_tensor, encoded_feature):
    """Upsample ``in_tensor`` to ``encoded_feature``'s size and concat.

    Returns ``cat([encoded_feature, self.conv(upsampled)], dim=1)``.
    """
    # Resizing to the skip's exact (H, W) equals scale_factor=2 for even sizes
    # and avoids a torch.cat mismatch when MaxPool2d floored an odd size.
    up_sampled_tensor = F.interpolate(in_tensor, size=tuple(encoded_feature.shape[-2:]),
                                      mode='bilinear', align_corners=False)
    up_sampled_tensor = self.conv(up_sampled_tensor)
    return torch.cat([encoded_feature, up_sampled_tensor], dim=1)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class EtaEncoder(nn.Module):
    """Encode an image into a coarse ``out_ch``-channel code.

    Full-resolution features (padding keeps H, W) are stacked with the raw
    input along channels, then collapsed by non-overlapping 32x32 and 7x7
    strided convolutions; a 224x224 input becomes 7x7 and then 1x1.
    """

    def __init__(self, in_ch=1, out_ch=2):
        super().__init__()
        feat_ch = 64
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_ch, 16, kernel_size=5, stride=1, padding=2),
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, feat_ch, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(feat_ch),
            nn.LeakyReLU(0.1),
        )
        self.seq = nn.Sequential(
            nn.Conv2d(feat_ch + in_ch, 32, kernel_size=32, stride=32, padding=0),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, out_ch, kernel_size=7, stride=7, padding=0),
        )

    def forward(self, x):
        features = self.in_conv(x)
        stacked = torch.cat((features, x), dim=1)
        return self.seq(stacked)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, in_ch=1, out_ch=2):
    """Build the full-resolution feature extractor and the strided reduction stack."""
    super().__init__()
    feat_ch = 64
    self.in_conv = nn.Sequential(
        nn.Conv2d(in_ch, 16, kernel_size=5, stride=1, padding=2),
        nn.InstanceNorm2d(16),
        nn.LeakyReLU(0.1),
        nn.Conv2d(16, feat_ch, kernel_size=3, stride=1, padding=1),
        nn.InstanceNorm2d(feat_ch),
        nn.LeakyReLU(0.1),
    )
    # Input channels = extracted features + raw image (concatenated in forward).
    self.seq = nn.Sequential(
        nn.Conv2d(feat_ch + in_ch, 32, kernel_size=32, stride=32, padding=0),
        nn.InstanceNorm2d(32),
        nn.LeakyReLU(0.1),
        nn.Conv2d(32, out_ch, kernel_size=7, stride=7, padding=0),
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Patchifier(nn.Module):
    """Map an image to a grid of per-patch outputs.

    A 32x32 convolution with stride 32 embeds each non-overlapping patch into
    64 channels; a 1x1 convolution then projects to ``out_ch``. A 224x224
    input yields a 7x7 grid.
    """

    def __init__(self, in_ch, out_ch=1):
        super().__init__()
        patch = 32
        layers = [
            nn.Conv2d(in_ch, 64, kernel_size=patch, stride=patch, padding=0),
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, out_ch, kernel_size=1, stride=1, padding=0),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        net = self.conv
        return net(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, in_ch, out_ch=1):
    """Build the 32x32/stride-32 patch embedding followed by a 1x1 projection."""
    super().__init__()
    patch = 32
    layers = [
        nn.Conv2d(in_ch, 64, kernel_size=patch, stride=patch, padding=0),
        nn.LeakyReLU(0.1),
        nn.Conv2d(64, out_ch, kernel_size=1, stride=1, padding=0),
    ]
    self.conv = nn.Sequential(*layers)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class ThetaEncoder(nn.Module):
    """Encode an image into mean and log-variance maps of ``out_ch`` channels.

    A shared three-stage strided-conv trunk feeds two identically shaped heads.
    For a 224x224 input the trunk yields 24x24 -> 12x12 -> 6x6 maps (28 -> 14
    -> 7 for 256x256); each head reduces that to 1x1.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        trunk = []
        # (in channels, out channels, kernel, stride, padding) per trunk stage.
        for c_in, c_out, kern, strd, pad in ((in_ch, 32, 17, 9, 4),
                                             (32, 64, 4, 2, 1),
                                             (64, 64, 4, 2, 1)):
            trunk += [nn.Conv2d(c_in, c_out, kern, strd, pad),
                      nn.InstanceNorm2d(c_out),
                      nn.LeakyReLU(0.1)]
        self.conv = nn.Sequential(*trunk)

        def head():
            return nn.Sequential(
                nn.Conv2d(64, 32, 3, 1, 1),
                nn.InstanceNorm2d(32),
                nn.LeakyReLU(0.1),
                nn.Conv2d(32, out_ch, 6, 6, 0))

        self.mean_conv = head()
        self.logvar_conv = head()

    def forward(self, x):
        shared = self.conv(x)
        return self.mean_conv(shared), self.logvar_conv(shared)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, in_ch, out_ch):
    """Build the shared strided trunk and the mean / log-variance heads."""
    super().__init__()
    trunk = []
    # (in channels, out channels, kernel, stride, padding) per trunk stage.
    for c_in, c_out, kern, strd, pad in ((in_ch, 32, 17, 9, 4),
                                         (32, 64, 4, 2, 1),
                                         (64, 64, 4, 2, 1)):
        trunk += [nn.Conv2d(c_in, c_out, kern, strd, pad),
                  nn.InstanceNorm2d(c_out),
                  nn.LeakyReLU(0.1)]
    self.conv = nn.Sequential(*trunk)

    def head():
        return nn.Sequential(
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, out_ch, 6, 6, 0))

    self.mean_conv = head()
    self.logvar_conv = head()
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Return ``(mean, logvar)``, both computed from the shared trunk output."""
    shared = self.conv(x)
    return self.mean_conv(shared), self.logvar_conv(shared)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class AttentionModule(nn.Module):
    """Contrast-wise attention that fuses per-contrast values.

    Queries and keys are projected by two separate MLPs (dim -> 128 -> 16,
    LayerNorm). The scores are softmax-normalised over contrasts and used to
    average the values across contrasts at every pixel.

    Args:
        dim: Feature dimension of both queries and keys.
        v_ch: Number of value channels.
    """

    def __init__(self, dim, v_ch=5):
        super().__init__()
        self.dim = dim
        self.v_ch = v_ch
        self.q_fc = nn.Sequential(
            nn.Linear(dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 16),
            nn.LayerNorm(16))
        self.k_fc = nn.Sequential(
            nn.Linear(dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 16),
            nn.LayerNorm(16))

        # NOTE(review): scaled by the input dim, not the 16-d projected dim;
        # kept as-is to preserve trained behaviour.
        self.scale = self.dim ** (-0.5)

    def forward(self, q, k, v, modality_dropout=None, temperature=10.0):
        r"""
        Attention module for optimal anatomy fusion.

        ===INPUTS===
        * q: torch.Tensor (batch_size, feature_dim_q, num_q_patches=1)
            Query variable. In HAIL, query is the concatenation of target \theta and target \eta.
            num_q_patches may be a square number (an s x s patch grid); each patch's scores
            are then upsampled to cover their own block of pixels.
        * k: torch.Tensor (batch_size, feature_dim_k, num_k_patches=1, num_contrasts=4)
            Key variable. In HAIL, keys are \theta and \eta's of source images.
        * v: torch.Tensor (batch_size, self.v_ch=5, num_v_patches=224*224, num_contrasts=4)
            Value variable. In HAIL, values are multi-channel logits of source images.
            self.v_ch is the number of \beta channels.
        * modality_dropout: torch.Tensor (batch_size, num_contrasts=4)
            Indicates which contrast indexes have been dropped out. 1: if dropped out, 0: if exists.
        * temperature: float
            Softmax temperature; larger values give flatter attention.

        ===OUTPUTS===
        * v: torch.Tensor (batch_size, self.v_ch, image_dim, image_dim)
            Attention-weighted combination of values over contrasts.
        * attention: torch.Tensor (batch_size, num_contrasts, image_dim, image_dim)
            Softmax weights over contrasts.

        ===RAISES===
        * ValueError if the feature dim of q or k differs from self.dim.
        """
        batch_size, feature_dim_q, num_q_patches = q.shape
        _, feature_dim_k, _, num_contrasts = k.shape
        num_v_patches = v.shape[2]
        if feature_dim_q != self.dim or feature_dim_k != self.dim:
            raise ValueError(
                f'Feature dimensions do not match: q has {feature_dim_q}, '
                f'k has {feature_dim_k}, expected {self.dim}.')

        # q.shape: (batch_size, num_q_patches, 1, feature_dim_q)
        q = q.reshape(batch_size, feature_dim_q, num_q_patches, 1).permute(0, 2, 3, 1)
        # k.shape: (batch_size, num_k_patches, num_contrasts, feature_dim_k)
        k = k.permute(0, 2, 3, 1)
        # v.shape: (batch_size, num_v_patches, num_contrasts, v_ch)
        v = v.permute(0, 2, 3, 1)
        q = self.q_fc(q)
        # k.shape: (batch_size, num_k_patches, 16, num_contrasts)
        k = self.k_fc(k).permute(0, 1, 3, 2)

        # dot_prod.shape: (batch_size, num_q_patches, 1, num_contrasts)
        dot_prod = (q @ k) * self.scale
        interpolation_factor = int(math.sqrt(num_v_patches // num_q_patches))

        q_spatial_dim = int(math.sqrt(num_q_patches))
        dot_prod = dot_prod.view(batch_size, q_spatial_dim, q_spatial_dim, num_contrasts)

        image_dim = int(math.sqrt(num_v_patches))
        # Nearest-neighbour upsampling: every patch score fills its own
        # interpolation_factor x interpolation_factor pixel block (repeat()
        # would tile the whole grid instead when num_q_patches > 1).
        dot_prod_interp = (dot_prod.repeat_interleave(interpolation_factor, dim=1)
                           .repeat_interleave(interpolation_factor, dim=2))
        if modality_dropout is not None:
            # (batch_size, 1, 1, num_contrasts): broadcasts over every pixel.
            dropout_mask = modality_dropout.view(batch_size, 1, 1, num_contrasts).detach()
            dot_prod_interp = dot_prod_interp - dropout_mask * 1e5

        attention = (dot_prod_interp / temperature).softmax(dim=-1)
        v = attention.view(batch_size, num_v_patches, 1, num_contrasts) @ v
        v = v.view(batch_size, image_dim, image_dim, self.v_ch).permute(0, 3, 1, 2)
        attention = attention.view(batch_size, image_dim, image_dim, num_contrasts).permute(0, 3, 1, 2)
        return v, attention
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, dim, v_ch=5):
    """Create separate query/key projection MLPs (dim -> 128 -> 16, LayerNorm)."""
    super().__init__()
    self.dim = dim
    self.v_ch = v_ch

    def projection():
        return nn.Sequential(
            nn.Linear(dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 16),
            nn.LayerNorm(16))

    self.q_fc = projection()
    self.k_fc = projection()

    # Dot-product scale derived from the input feature dim.
    self.scale = dim ** -0.5
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, q, k, v, modality_dropout=None, temperature=10.0):
    r"""
    Attention module for optimal anatomy fusion.

    ===INPUTS===
    * q: torch.Tensor (batch_size, feature_dim_q, num_q_patches=1)
        Query variable. In HAIL, query is the concatenation of target \theta and target \eta.
        num_q_patches may be a square number (an s x s patch grid); each patch's scores
        are then upsampled to cover their own block of pixels.
    * k: torch.Tensor (batch_size, feature_dim_k, num_k_patches=1, num_contrasts=4)
        Key variable. In HAIL, keys are \theta and \eta's of source images.
    * v: torch.Tensor (batch_size, self.v_ch=5, num_v_patches=224*224, num_contrasts=4)
        Value variable. In HAIL, values are multi-channel logits of source images.
        self.v_ch is the number of \beta channels.
    * modality_dropout: torch.Tensor (batch_size, num_contrasts=4)
        Indicates which contrast indexes have been dropped out. 1: if dropped out, 0: if exists.
    * temperature: float
        Softmax temperature; larger values give flatter attention.

    ===OUTPUTS===
    * v: torch.Tensor (batch_size, self.v_ch, image_dim, image_dim)
        Attention-weighted combination of values over contrasts.
    * attention: torch.Tensor (batch_size, num_contrasts, image_dim, image_dim)
        Softmax weights over contrasts.

    ===RAISES===
    * ValueError if the feature dim of q or k differs from self.dim.
    """
    batch_size, feature_dim_q, num_q_patches = q.shape
    _, feature_dim_k, _, num_contrasts = k.shape
    num_v_patches = v.shape[2]
    if feature_dim_q != self.dim or feature_dim_k != self.dim:
        raise ValueError(
            f'Feature dimensions do not match: q has {feature_dim_q}, '
            f'k has {feature_dim_k}, expected {self.dim}.')

    # q.shape: (batch_size, num_q_patches, 1, feature_dim_q)
    q = q.reshape(batch_size, feature_dim_q, num_q_patches, 1).permute(0, 2, 3, 1)
    # k.shape: (batch_size, num_k_patches, num_contrasts, feature_dim_k)
    k = k.permute(0, 2, 3, 1)
    # v.shape: (batch_size, num_v_patches, num_contrasts, v_ch)
    v = v.permute(0, 2, 3, 1)
    q = self.q_fc(q)
    # k.shape: (batch_size, num_k_patches, projected_dim, num_contrasts)
    k = self.k_fc(k).permute(0, 1, 3, 2)

    # dot_prod.shape: (batch_size, num_q_patches, 1, num_contrasts)
    dot_prod = (q @ k) * self.scale
    interpolation_factor = int(math.sqrt(num_v_patches // num_q_patches))

    q_spatial_dim = int(math.sqrt(num_q_patches))
    dot_prod = dot_prod.view(batch_size, q_spatial_dim, q_spatial_dim, num_contrasts)

    image_dim = int(math.sqrt(num_v_patches))
    # Nearest-neighbour upsampling: every patch score fills its own
    # interpolation_factor x interpolation_factor pixel block (repeat()
    # would tile the whole grid instead when num_q_patches > 1).
    dot_prod_interp = (dot_prod.repeat_interleave(interpolation_factor, dim=1)
                       .repeat_interleave(interpolation_factor, dim=2))
    if modality_dropout is not None:
        # (batch_size, 1, 1, num_contrasts): broadcasts over every pixel.
        dropout_mask = modality_dropout.view(batch_size, 1, 1, num_contrasts).detach()
        dot_prod_interp = dot_prod_interp - dropout_mask * 1e5

    attention = (dot_prod_interp / temperature).softmax(dim=-1)
    v = attention.view(batch_size, num_v_patches, 1, num_contrasts) @ v
    v = v.view(batch_size, image_dim, image_dim, self.v_ch).permute(0, 3, 1, 2)
    attention = attention.view(batch_size, image_dim, image_dim, num_contrasts).permute(0, 3, 1, 2)
    return v, attention
Attention module for optimal anatomy fusion.
===INPUTS===
- q: torch.Tensor (batch_size, feature_dim_q, num_q_patches=1) Query variable. In HAIL, query is the concatenation of target θ and target η.
- k: torch.Tensor (batch_size, feature_dim_k, num_k_patches=1, num_contrasts=4) Key variable. In HAIL, keys are the θ and η of the source images.
- v: torch.Tensor (batch_size, self.v_ch=5, num_v_patches=224*224, num_contrasts=4) Value variable. In HAIL, values are multi-channel logits of source images. self.v_ch is the number of β channels.
- modality_dropout: torch.Tensor (batch_size, num_contrasts=4) Indicates which contrast indexes have been dropped out. 1: if dropped out, 0: if exists.