hail.network

import torch
from torch import nn
import torch.nn.functional as F
import math


class UNet(nn.Module):
    def __init__(self, in_ch, out_ch, conditional_ch=0, num_lvs=4, base_ch=16, final_act='noact'):
        super().__init__()
        self.final_act = final_act
        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, 1, 1)

        # One encoder/decoder pair per resolution level; channel width doubles at each level.
        self.down_convs = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        self.up_convs = nn.ModuleList()
        for lv in range(num_lvs):
            ch = base_ch * (2 ** lv)
            self.down_convs.append(ConvBlock2d(ch + conditional_ch, ch * 2, ch * 2))
            self.down_samples.append(nn.MaxPool2d(kernel_size=2, stride=2))
            self.up_samples.append(Upsample(ch * 4))
            self.up_convs.append(ConvBlock2d(ch * 4, ch * 2, ch * 2))
        bottleneck_ch = base_ch * (2 ** num_lvs)
        self.bottleneck_conv = ConvBlock2d(bottleneck_ch, bottleneck_ch * 2, bottleneck_ch * 2)
        self.out_conv = nn.Sequential(nn.Conv2d(base_ch * 2, base_ch, 3, 1, 1),
                                      nn.LeakyReLU(0.1),
                                      nn.Conv2d(base_ch, out_ch, 3, 1, 1))

    def forward(self, in_tensor, condition=None):
        encoded_features = []
        x = self.in_conv(in_tensor)
        # Encoder: optionally concatenate the conditioning code (tiled to the current
        # spatial size) before each conv block, and keep the pre-pooling features
        # for the skip connections.
        for down_conv, down_sample in zip(self.down_convs, self.down_samples):
            if condition is not None:
                feature_dim = x.shape[-1]
                down_conv_out = down_conv(torch.cat([x, condition.repeat(1, 1, feature_dim, feature_dim)], dim=1))
            else:
                down_conv_out = down_conv(x)
            x = down_sample(down_conv_out)
            encoded_features.append(down_conv_out)
        x = self.bottleneck_conv(x)
        # Decoder: upsample, concatenate the matching skip feature, then convolve.
        for encoded_feature, up_conv, up_sample in zip(reversed(encoded_features),
                                                       reversed(self.up_convs),
                                                       reversed(self.up_samples)):
            x = up_sample(x, encoded_feature)
            x = up_conv(x)
        x = self.out_conv(x)
        if self.final_act == 'sigmoid':
            x = torch.sigmoid(x)
        elif self.final_act == 'relu':
            x = torch.relu(x)
        elif self.final_act == 'tanh':
            x = torch.tanh(x)
        return x


class ConvBlock2d(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, 3, 1, 1),
            nn.InstanceNorm2d(mid_ch),
            nn.LeakyReLU(0.1),
            nn.Conv2d(mid_ch, out_ch, 3, 1, 1),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(0.1)
        )

    def forward(self, in_tensor):
        return self.conv(in_tensor)


class Upsample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        out_ch = in_ch // 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(0.1)
        )

    def forward(self, in_tensor, encoded_feature):
        up_sampled_tensor = F.interpolate(in_tensor, size=None, scale_factor=2, mode='bilinear', align_corners=False)
        up_sampled_tensor = self.conv(up_sampled_tensor)
        return torch.cat([encoded_feature, up_sampled_tensor], dim=1)


class EtaEncoder(nn.Module):
    def __init__(self, in_ch=1, out_ch=2):
        super().__init__()
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_ch, 16, 5, 1, 2),  # (*, 16, 224, 224)
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 64, 3, 1, 1),  # (*, 64, 224, 224)
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.1)
        )
        self.seq = nn.Sequential(
            nn.Conv2d(64 + in_ch, 32, 32, 32, 0),  # (*, 32, 7, 7)
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, out_ch, 7, 7, 0))

    def forward(self, x):
        return self.seq(torch.cat([self.in_conv(x), x], dim=1))


class Patchifier(nn.Module):
    def __init__(self, in_ch, out_ch=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 64, 32, 32, 0),  # (*, in_ch, 224, 224) --> (*, 64, 7, 7)
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, out_ch, 1, 1, 0))

    def forward(self, x):
        return self.conv(x)


class ThetaEncoder(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 32, 17, 9, 4),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),  # (*, 32, 28, 28)
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.1),  # (*, 64, 14, 14)
            nn.Conv2d(64, 64, 4, 2, 1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.1))  # (*, 64, 7, 7)
        self.mean_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, out_ch, 6, 6, 0))
        self.logvar_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, out_ch, 6, 6, 0))

    def forward(self, x):
        M = self.conv(x)
        mu = self.mean_conv(M)
        logvar = self.logvar_conv(M)
        return mu, logvar


class AttentionModule(nn.Module):
    def __init__(self, dim, v_ch=5):
        super().__init__()
        self.dim = dim
        self.v_ch = v_ch
        self.q_fc = nn.Sequential(
            nn.Linear(dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 16),
            nn.LayerNorm(16))
        self.k_fc = nn.Sequential(
            nn.Linear(dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 16),
            nn.LayerNorm(16))

        self.scale = self.dim ** (-0.5)

    def forward(self, q, k, v, modality_dropout=None, temperature=10.0):
        r"""
        Attention module for optimal anatomy fusion.

        ===INPUTS===
        * q: torch.Tensor (batch_size, feature_dim_q, num_q_patches=1)
            Query variable. In HAIL, the query is the concatenation of the target \theta and target \eta.
        * k: torch.Tensor (batch_size, feature_dim_k, num_k_patches=1, num_contrasts=4)
            Key variable. In HAIL, the keys are the \theta's and \eta's of the source images.
        * v: torch.Tensor (batch_size, self.v_ch=5, num_v_patches=224*224, num_contrasts=4)
            Value variable. In HAIL, the values are the multi-channel logits of the source images;
            self.v_ch is the number of \beta channels.
        * modality_dropout: torch.Tensor (batch_size, num_contrasts=4)
            Indicates which contrasts have been dropped out: 1 if dropped, 0 if present.
        """
        batch_size, feature_dim_q, num_q_patches = q.shape
        _, feature_dim_k, _, num_contrasts = k.shape
        num_v_patches = v.shape[2]
        assert (
                feature_dim_k == feature_dim_q or feature_dim_q == self.dim
        ), 'Feature dimensions do not match.'

        # q.shape: (batch_size, num_q_patches=1, 1, feature_dim_q)
        q = q.reshape(batch_size, feature_dim_q, num_q_patches, 1).permute(0, 2, 3, 1)
        # k.shape: (batch_size, num_k_patches=1, num_contrasts=4, feature_dim_k)
        k = k.permute(0, 2, 3, 1)
        # v.shape: (batch_size, num_v_patches=224*224, num_contrasts=4, v_ch=5)
        v = v.permute(0, 2, 3, 1)
        q = self.q_fc(q)
        # k.shape: (batch_size, num_k_patches=1, feature_dim_k, num_contrasts=4)
        k = self.k_fc(k).permute(0, 1, 3, 2)

        # dot_prod.shape: (batch_size, num_q_patches=1, 1, num_contrasts=4)
        dot_prod = (q @ k) * self.scale
        interpolation_factor = int(math.sqrt(num_v_patches // num_q_patches))

        q_spatial_dim = int(math.sqrt(num_q_patches))
        dot_prod = dot_prod.view(batch_size, q_spatial_dim, q_spatial_dim, num_contrasts)

        image_dim = int(math.sqrt(num_v_patches))
        # dot_prod_interp.shape: (batch_size, image_dim, image_dim, num_contrasts)
        dot_prod_interp = dot_prod.repeat(1, interpolation_factor, interpolation_factor, 1)
        if modality_dropout is not None:
            # Push dropped contrasts to a large negative logit so they receive ~zero attention.
            modality_dropout = modality_dropout.view(batch_size, num_contrasts, 1, 1).permute(0, 2, 3, 1)
            dot_prod_interp = dot_prod_interp - (modality_dropout.repeat(1, image_dim, image_dim, 1).detach() * 1e5)

        attention = (dot_prod_interp / temperature).softmax(dim=-1)
        v = attention.view(batch_size, num_v_patches, 1, num_contrasts) @ v
        v = v.view(batch_size, image_dim, image_dim, self.v_ch).permute(0, 3, 1, 2)
        attention = attention.view(batch_size, image_dim, image_dim, num_contrasts).permute(0, 3, 1, 2)
        return v, attention
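
For orientation, here is a minimal sketch of how the modules above fit together shape-wise, assuming 224x224 single-channel slices, four source contrasts, and a 5-channel value (\beta) representation, matching the shape comments in the code. The batch size, the variable names (code_ch, beta_net, etc.), and the way q, k, and v are assembled here are illustrative assumptions for this sketch, not the HAIL training pipeline itself.

# Illustrative only: hypothetical wiring of the modules above on random data.
B, num_contrasts, code_ch, v_ch = 2, 4, 2, 5

theta_enc = ThetaEncoder(in_ch=1, out_ch=code_ch)
eta_enc = EtaEncoder(in_ch=1, out_ch=code_ch)
beta_net = UNet(in_ch=1, out_ch=v_ch)                       # stand-in producer of per-contrast value maps
attn = AttentionModule(dim=2 * code_ch, v_ch=v_ch)

sources = torch.randn(B, num_contrasts, 224, 224)           # source-contrast slices
target = torch.randn(B, 1, 224, 224)                        # target-contrast slice

keys, values = [], []
for c in range(num_contrasts):
    img = sources[:, c:c + 1]                                # (B, 1, 224, 224)
    mu, _ = theta_enc(img)                                   # (B, code_ch, 1, 1)
    eta = eta_enc(img)                                       # (B, code_ch, 1, 1)
    keys.append(torch.cat([mu, eta], dim=1).flatten(2))      # (B, 2*code_ch, 1)
    values.append(beta_net(img).flatten(2))                  # (B, v_ch, 224*224)

k = torch.stack(keys, dim=-1)                                # (B, 2*code_ch, 1, num_contrasts)
v = torch.stack(values, dim=-1)                              # (B, v_ch, 224*224, num_contrasts)

mu_t, _ = theta_enc(target)
q = torch.cat([mu_t, eta_enc(target)], dim=1).flatten(2)     # (B, 2*code_ch, 1)

fused, attention = attn(q, k, v)                             # (B, v_ch, 224, 224), (B, num_contrasts, 224, 224)

# Optionally mask out missing contrasts (1 = dropped, 0 = present):
# fused, attention = attn(q, k, v, modality_dropout=torch.zeros(B, num_contrasts))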