hail.model

import os
from datetime import datetime

import nibabel as nib
import numpy as np
import torch

from network import UNet, ThetaEncoder, EtaEncoder, Patchifier, AttentionModule
from utils import reparameterize_logit, divide_into_batches

class HAIL:
    """
    Harmonization Across Imaging Locations (HAIL) model.
    """
    def __init__(self, beta_dim, theta_dim, eta_dim, pretrained=None, pretrained_eta_encoder=None, gpu_id=0):
        self.beta_dim = beta_dim
        self.theta_dim = theta_dim
        self.eta_dim = eta_dim
        self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
        self.timestr = datetime.now().strftime("%Y%m%d-%H%M%S")

        # define networks
        self.beta_encoder = UNet(in_ch=1, out_ch=self.beta_dim, base_ch=8, final_act='none')
        self.theta_encoder = ThetaEncoder(in_ch=1, out_ch=self.theta_dim)
        self.eta_encoder = EtaEncoder(in_ch=1, out_ch=self.eta_dim)
        self.attention_module = AttentionModule(self.theta_dim + self.eta_dim, v_ch=self.beta_dim)
        self.decoder = UNet(in_ch=1 + self.theta_dim, out_ch=1, base_ch=16, final_act='relu')
        self.patchifier = Patchifier(in_ch=1, out_ch=128)

        # load pretrained weights if provided
        if pretrained_eta_encoder is not None:
            checkpoint_eta_encoder = torch.load(pretrained_eta_encoder, map_location=self.device)
            self.eta_encoder.load_state_dict(checkpoint_eta_encoder['eta_encoder'])
        if pretrained is not None:
            self.checkpoint = torch.load(pretrained, map_location=self.device)
            self.beta_encoder.load_state_dict(self.checkpoint['beta_encoder'])
            self.theta_encoder.load_state_dict(self.checkpoint['theta_encoder'])
            self.eta_encoder.load_state_dict(self.checkpoint['eta_encoder'])
            self.decoder.load_state_dict(self.checkpoint['decoder'])
            self.attention_module.load_state_dict(self.checkpoint['attention_module'])
            self.patchifier.load_state_dict(self.checkpoint['patchifier'])

        # move networks to the target device
        self.beta_encoder.to(self.device)
        self.theta_encoder.to(self.device)
        self.eta_encoder.to(self.device)
        self.decoder.to(self.device)
        self.attention_module.to(self.device)
        self.patchifier.to(self.device)

    def channel_aggregation(self, beta_onehot_encode: torch.Tensor) -> torch.Tensor:
        """
        Combine multi-channel one-hot encoded beta into one channel (label-encoding).

        Args:
            beta_onehot_encode: torch.Tensor (batch_size, self.beta_dim, image_dim, image_dim)
                One-hot encoded beta variable. At each pixel location, only one channel takes
                the value 1, and all other channels are 0.

        Returns:
            torch.Tensor (batch_size, 1, image_dim, image_dim): label-encoded beta, where each
            pixel holds the index of its active channel divided by self.beta_dim, i.e. a value
            in [0, 1).
        """
        batch_size, image_dim = beta_onehot_encode.shape[0], beta_onehot_encode.shape[3]
        # one value per channel: 0, 1, ..., beta_dim - 1
        value_tensor = (torch.arange(0, self.beta_dim) * 1.0).to(self.device)
        value_tensor = value_tensor.view(1, self.beta_dim, 1, 1).repeat(batch_size, 1, image_dim, image_dim)
        beta_label_encode = beta_onehot_encode * value_tensor.detach()
        return beta_label_encode.sum(1, keepdim=True) / self.beta_dim

    def harmonize(self, source_images, target_images, target_theta, target_eta, out_paths,
                  recon_orientation, norm_vals, header=None, num_batches=4) -> torch.Tensor | None:
        """
        The main harmonization function: harmonizes the source images to the target images.

        Args:
            source_images (List[torch.Tensor]): list of source images
            target_images (List[torch.Tensor]): list of target images
            target_theta (List[torch.Tensor]): list of target theta values; if None, theta and eta
                are estimated from target_images
            target_eta (List[torch.Tensor]): list of target eta values
            out_paths (List[Path]): list of output paths
            recon_orientation (str): orientation of the reconstructed image (only "axial" is supported)
            norm_vals (List[Tuple[float, float]]): per-target (scale, ceiling) pairs used to restore
                the original intensity range of the reconstructed image
            header (nib.Nifti1Header): header of the input image; if None, the harmonized image is
                returned instead of saved
            num_batches (int): number of batches to divide the input tensors into

        Returns:
            torch.Tensor | None: the reconstructed harmonized image when header is None; otherwise
            the harmonized images are saved and nothing is returned.
        """
        if out_paths is not None:
            for out_path in out_paths:
                os.makedirs(out_path.parent, exist_ok=True)

        # set everything to eval mode and turn off gradients
        with torch.set_grad_enabled(False):
            self.beta_encoder.eval()
            self.theta_encoder.eval()
            self.eta_encoder.eval()
            self.attention_module.eval()
            self.decoder.eval()

            # compute the masks, logits, betas, and keys for the source images
            logits, betas, keys, masks = [], [], [], []
            for source_image in source_images:
                source_image = source_image.unsqueeze(1)
                source_image_batches = divide_into_batches(source_image, num_batches)
                mask_tmp, logit_tmp, beta_tmp, key_tmp = [], [], [], []
                for source_image_batch in source_image_batches:
                    batch_size = source_image_batch.shape[0]
                    source_image_batch = source_image_batch.to(self.device)
                    # foreground mask: pixels with non-negligible intensity
                    mask = (source_image_batch > 1e-6) * 1.0
                    logit = self.beta_encoder(source_image_batch)
                    beta = self.channel_aggregation(reparameterize_logit(logit))
                    theta_source, _ = self.theta_encoder(source_image_batch)
                    eta_source = self.eta_encoder(source_image_batch).view(batch_size, self.eta_dim, 1, 1)
                    mask_tmp.append(mask)
                    logit_tmp.append(logit)
                    beta_tmp.append(beta)
                    # the attention key is the concatenated (theta, eta) code of the source
                    key_tmp.append(torch.cat([theta_source, eta_source], dim=1))
                masks.append(torch.cat(mask_tmp, dim=0))
                logits.append(torch.cat(logit_tmp, dim=0))
                betas.append(torch.cat(beta_tmp, dim=0))
                keys.append(torch.cat(key_tmp, dim=0))

            # compute the target theta and eta (the attention query) from the target images
            if target_theta is None:
                queries, thetas_target = [], []
                for target_image in target_images:
                    target_image = target_image.to(self.device).unsqueeze(1)
                    theta_target, _ = self.theta_encoder(target_image)
                    # average the codes over all target slices
                    theta_target = theta_target.mean(dim=0, keepdim=True)
                    eta_target = self.eta_encoder(target_image).mean(dim=0, keepdim=True).view(1, self.eta_dim, 1, 1)
                    thetas_target.append(theta_target)
                    queries.append(
                        torch.cat([theta_target, eta_target], dim=1).view(1, self.theta_dim + self.eta_dim, 1))
            else:
                # theta and eta were provided directly
                queries, thetas_target = [], []
                for target_theta_tmp, target_eta_tmp in zip(target_theta, target_eta):
                    thetas_target.append(target_theta_tmp.view(1, self.theta_dim, 1, 1).to(self.device))
                    queries.append(torch.cat([target_theta_tmp.view(1, self.theta_dim, 1).to(self.device),
                                              target_eta_tmp.view(1, self.eta_dim, 1).to(self.device)], dim=1))

            # decode the harmonized image for each target
            for tid, (theta_target, query, norm_val) in enumerate(zip(thetas_target, queries, norm_vals)):
                if out_paths is not None:
                    out_prefix = out_paths[tid].name.replace('.nii.gz', '')
                rec_image, beta_fusion, logit_fusion, attention = [], [], [], []
                for batch_id in range(num_batches):
                    keys_tmp = [divide_into_batches(ks, num_batches)[batch_id] for ks in keys]
                    logits_tmp = [divide_into_batches(ls, num_batches)[batch_id] for ls in logits]
                    masks_tmp = [divide_into_batches(ms, num_batches)[batch_id] for ms in masks]
                    batch_size = keys_tmp[0].shape[0]
                    query_tmp = query.view(1, self.theta_dim + self.eta_dim, 1).repeat(batch_size, 1, 1)
                    k = torch.cat(keys_tmp, dim=-1).view(batch_size, self.theta_dim + self.eta_dim, 1, len(source_images))
                    v = torch.stack(logits_tmp, dim=-1).view(batch_size, self.beta_dim, 224 * 224, len(source_images))
                    # fuse the source logits, weighting each source by its attention score
                    logit_fusion_tmp, attention_tmp = self.attention_module(query_tmp, k, v, None, 5.0)
                    beta_fusion_tmp = self.channel_aggregation(reparameterize_logit(logit_fusion_tmp))
                    combined_map = torch.cat([beta_fusion_tmp, theta_target.repeat(batch_size, 1, 224, 224)], dim=1)
                    rec_image_tmp = self.decoder(combined_map) * masks_tmp[0]

                    rec_image.append(rec_image_tmp)
                    beta_fusion.append(beta_fusion_tmp)
                    logit_fusion.append(logit_fusion_tmp)
                    attention.append(attention_tmp)

                rec_image = torch.cat(rec_image, dim=0)
                beta_fusion = torch.cat(beta_fusion, dim=0)
                logit_fusion = torch.cat(logit_fusion, dim=0)
                attention = torch.cat(attention, dim=0)

                # save the harmonized image
                if header is not None:
                    if recon_orientation == "axial":
                        img_save = np.array(rec_image.cpu().squeeze().permute(1, 2, 0).permute(1, 0, 2))
                    else:
                        raise NotImplementedError('Only axial orientation is supported')

                    # crop back to the original shape and restore the original intensity range
                    img_save = self.clipper(img_save[112 - 96:112 + 96, :, 112 - 96:112 + 96] * norm_val[0], norm_val[1])
                    img_save = nib.Nifti1Image(img_save, None, header)
                    file_name = out_paths[tid].parent / f'{out_prefix}_harmonized_{recon_orientation}.nii.gz'
                    nib.save(img_save, file_name)

        if header is None:
            return rec_image.cpu().squeeze()

    def clipper(self, rec_image: np.ndarray, norm_val: float) -> np.ndarray:
        """ Clip the image to the valid intensity range.

        Args:
            rec_image (np.ndarray): harmonized image
            norm_val (float): upper intensity bound to clip to

        Returns:
            np.ndarray: clipped image, with values in [0, norm_val]
        """
        return np.clip(rec_image, 0, norm_val)
class HAIL:

Harmonization Across Imaging Locations (HAIL) model.

HAIL(beta_dim, theta_dim, eta_dim, pretrained=None, pretrained_eta_encoder=None, gpu_id=0)
beta_dim
theta_dim
eta_dim
device
timestr
beta_encoder
theta_encoder
eta_encoder
attention_module
decoder
patchifier
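
A minimal construction sketch. The latent dimensions and checkpoint path below are hypothetical; a checkpoint passed as pretrained is expected to contain the state-dict keys loaded in __init__ ('beta_encoder', 'theta_encoder', 'eta_encoder', 'decoder', 'attention_module', 'patchifier'):

    from hail.model import HAIL

    # Hypothetical latent dimensions and checkpoint path.
    model = HAIL(beta_dim=5, theta_dim=2, eta_dim=2,
                 pretrained='checkpoints/hail.pth',
                 gpu_id=0)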
def channel_aggregation(self, beta_onehot_encode: torch.Tensor) -> torch.Tensor:

Combine multi-channel one-hot encoded beta into one channel (label-encoding).

Args: beta_onehot_encode (torch.Tensor of shape (batch_size, self.beta_dim, image_dim, image_dim)): one-hot encoded beta variable; at each pixel location exactly one channel takes the value 1 and all other channels are 0.

Returns: torch.Tensor of shape (batch_size, 1, image_dim, image_dim): label-encoded beta with values in [0, 1).
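
For intuition, a standalone sketch of the same label-encoding (beta_dim = 4 is arbitrary):

    import torch

    beta_dim = 4
    onehot = torch.zeros(1, beta_dim, 2, 2)
    onehot[:, 2] = 1.0                  # every pixel carries label 2
    values = torch.arange(beta_dim, dtype=torch.float32).view(1, beta_dim, 1, 1)
    label = (onehot * values).sum(1, keepdim=True) / beta_dim
    # label is a (1, 1, 2, 2) tensor filled with 2 / 4 = 0.5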

def harmonize(self, source_images, target_images, target_theta, target_eta, out_paths, recon_orientation, norm_vals, header=None, num_batches=4) -> torch.Tensor | None:

The main harmonization function: harmonizes the source images to the target images.

Args:
    source_images (List[torch.Tensor]): list of source images
    target_images (List[torch.Tensor]): list of target images
    target_theta (List[torch.Tensor]): list of target theta values; if None, theta and eta are estimated from target_images
    target_eta (List[torch.Tensor]): list of target eta values
    out_paths (List[Path]): list of output paths
    recon_orientation (str): orientation of the reconstructed image (only "axial" is supported)
    norm_vals (List[Tuple[float, float]]): per-target (scale, ceiling) pairs used to restore the original intensity range
    header (nib.Nifti1Header): header of the input image; if None, the harmonized image is returned instead of saved
    num_batches (int): number of batches to divide the input tensors into

Returns: torch.Tensor | None: the reconstructed harmonized image when header is None; otherwise the harmonized images are saved to out_paths and nothing is returned.
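
A hedged usage sketch. File names, slice counts, and normalization values are hypothetical; volumes are assumed to be slice-major stacks of 224 x 224 axial slices, matching the dimensions hard-coded in the decoding loop:

    from pathlib import Path
    import torch
    from hail.model import HAIL

    model = HAIL(beta_dim=5, theta_dim=2, eta_dim=2, pretrained='checkpoints/hail.pth')

    source_images = [torch.rand(192, 224, 224)]   # one source volume
    target_images = [torch.rand(192, 224, 224)]   # one target volume defining the style
    rec = model.harmonize(source_images, target_images,
                          target_theta=None, target_eta=None,   # estimate theta/eta from target_images
                          out_paths=[Path('out/subj1_to_target.nii.gz')],
                          recon_orientation='axial',
                          norm_vals=[(1000.0, 1000.0)],         # hypothetical (scale, ceiling)
                          header=None)                          # header=None: the tensor is returned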

def clipper(self, rec_image: numpy.ndarray, norm_val: float) -> numpy.ndarray:

Clip the image to the valid intensity range.

Args: rec_image (np.ndarray): harmonized image. norm_val (float): upper intensity bound to clip to.

Returns: np.ndarray: clipped image, with values in [0, norm_val].
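
For example, the clipping restores a reconstruction to a bounded range:

    import numpy as np

    rec = np.array([-0.5, 80.0, 1500.0])
    np.clip(rec, 0, 1000.0)   # what clipper(rec, 1000.0) computes: array([0., 80., 1000.])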