hail.model
import os
from datetime import datetime

import nibabel as nib
import numpy as np
import torch

from network import UNet, ThetaEncoder, EtaEncoder, Patchifier, AttentionModule
from utils import reparameterize_logit, divide_into_batches


class HAIL:
    """
    Harmonization Across Imaging Locations (HAIL) model.
    """

    def __init__(self, beta_dim, theta_dim, eta_dim, pretrained=None, pretrained_eta_encoder=None, gpu_id=0):
        self.beta_dim = beta_dim
        self.theta_dim = theta_dim
        self.eta_dim = eta_dim
        self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
        self.timestr = datetime.now().strftime("%Y%m%d-%H%M%S")

        # define networks
        self.beta_encoder = UNet(in_ch=1, out_ch=self.beta_dim, base_ch=8, final_act='none')
        self.theta_encoder = ThetaEncoder(in_ch=1, out_ch=self.theta_dim)
        self.eta_encoder = EtaEncoder(in_ch=1, out_ch=self.eta_dim)
        self.attention_module = AttentionModule(self.theta_dim + self.eta_dim, v_ch=self.beta_dim)
        self.decoder = UNet(in_ch=1 + self.theta_dim, out_ch=1, base_ch=16, final_act='relu')
        self.patchifier = Patchifier(in_ch=1, out_ch=128)

        # load pretrained weights, if provided
        if pretrained_eta_encoder is not None:
            checkpoint_eta_encoder = torch.load(pretrained_eta_encoder, map_location=self.device)
            self.eta_encoder.load_state_dict(checkpoint_eta_encoder['eta_encoder'])
        if pretrained is not None:
            self.checkpoint = torch.load(pretrained, map_location=self.device)
            self.beta_encoder.load_state_dict(self.checkpoint['beta_encoder'])
            self.theta_encoder.load_state_dict(self.checkpoint['theta_encoder'])
            self.eta_encoder.load_state_dict(self.checkpoint['eta_encoder'])
            self.decoder.load_state_dict(self.checkpoint['decoder'])
            self.attention_module.load_state_dict(self.checkpoint['attention_module'])
            self.patchifier.load_state_dict(self.checkpoint['patchifier'])

        self.beta_encoder.to(self.device)
        self.theta_encoder.to(self.device)
        self.eta_encoder.to(self.device)
        self.decoder.to(self.device)
        self.attention_module.to(self.device)
        self.patchifier.to(self.device)

    def channel_aggregation(self, beta_onehot_encode: torch.Tensor) -> torch.Tensor:
        """
        Combine multi-channel one-hot encoded beta into one channel (label-encoding).

        Args:
            beta_onehot_encode: torch.Tensor (batch_size, self.beta_dim, image_dim, image_dim)
                One-hot encoded beta variable. At each pixel location, only one channel takes the
                value 1 and all other channels are 0.

        Returns:
            torch.Tensor (batch_size, 1, image_dim, image_dim): label-encoded beta, where each
                pixel holds the index of its active channel divided by self.beta_dim.
        """
        batch_size, image_dim = beta_onehot_encode.shape[0], beta_onehot_encode.shape[3]
        value_tensor = (torch.arange(0, self.beta_dim) * 1.0).to(self.device)
        value_tensor = value_tensor.view(1, self.beta_dim, 1, 1).repeat(batch_size, 1, image_dim, image_dim)
        beta_label_encode = beta_onehot_encode * value_tensor.detach()
        return beta_label_encode.sum(1, keepdim=True) / self.beta_dim

    def harmonize(self, source_images, target_images, target_theta, target_eta, out_paths,
                  recon_orientation, norm_vals, header=None, num_batches=4) -> torch.Tensor | None:
        """
        The main harmonization function: harmonizes the source images to the target images.

        Args:
            source_images (List[torch.Tensor]): list of source images
            target_images (List[torch.Tensor]): list of target images
            target_theta (List[torch.Tensor]): list of target theta values; if None, theta and
                eta are estimated from target_images
            target_eta (List[torch.Tensor]): list of target eta values
            out_paths (List[Path]): list of output paths
            recon_orientation (str): orientation of the reconstructed image
            norm_vals (List[Tuple[float, float]]): list of (rescale factor, clipping bound)
                pairs for the reconstructed image
            header (nib.Nifti1Header): header of the input image; if None, the reconstruction
                is returned instead of saved
            num_batches (int): number of batches to divide the input tensor into

        Returns:
            torch.Tensor | None: the reconstructed harmonized image if header is None,
                otherwise None (the result is written to disk).
        """
        if out_paths is not None:
            for out_path in out_paths:
                os.makedirs(out_path.parent, exist_ok=True)

        # set everything to eval mode and turn off gradients
        with torch.set_grad_enabled(False):
            self.beta_encoder.eval()
            self.theta_encoder.eval()
            self.eta_encoder.eval()
            self.decoder.eval()
            self.attention_module.eval()

            # calculate the masks, logits, betas, and keys for the source images
            logits, betas, keys, masks = [], [], [], []
            for source_image in source_images:
                source_image = source_image.unsqueeze(1)
                source_image_batches = divide_into_batches(source_image, num_batches)
                mask_tmp, logit_tmp, beta_tmp, key_tmp = [], [], [], []
                for source_image_batch in source_image_batches:
                    batch_size = source_image_batch.shape[0]
                    source_image_batch = source_image_batch.to(self.device)
                    mask = (source_image_batch > 1e-6) * 1.0
                    logit = self.beta_encoder(source_image_batch)
                    beta = self.channel_aggregation(reparameterize_logit(logit))
                    theta_source, _ = self.theta_encoder(source_image_batch)
                    eta_source = self.eta_encoder(source_image_batch).view(batch_size, self.eta_dim, 1, 1)
                    mask_tmp.append(mask)
                    logit_tmp.append(logit)
                    beta_tmp.append(beta)
                    key_tmp.append(torch.cat([theta_source, eta_source], dim=1))
                masks.append(torch.cat(mask_tmp, dim=0))
                logits.append(torch.cat(logit_tmp, dim=0))
                betas.append(torch.cat(beta_tmp, dim=0))
                keys.append(torch.cat(key_tmp, dim=0))

            # calculate the harmonized theta and eta from the target images
            queries, thetas_target = [], []
            if target_theta is None:
                for target_image in target_images:
                    target_image = target_image.to(self.device).unsqueeze(1)
                    theta_target, _ = self.theta_encoder(target_image)
                    theta_target = theta_target.mean(dim=0, keepdim=True)
                    eta_target = self.eta_encoder(target_image).mean(dim=0, keepdim=True).view(1, self.eta_dim, 1, 1)
                    thetas_target.append(theta_target)
                    queries.append(
                        torch.cat([theta_target, eta_target], dim=1).view(1, self.theta_dim + self.eta_dim, 1))
            else:
                for target_theta_tmp, target_eta_tmp in zip(target_theta, target_eta):
                    thetas_target.append(target_theta_tmp.view(1, self.theta_dim, 1, 1).to(self.device))
                    queries.append(torch.cat([target_theta_tmp.view(1, self.theta_dim, 1).to(self.device),
                                              target_eta_tmp.view(1, self.eta_dim, 1).to(self.device)], dim=1))

            # decode the harmonized image for each target
            for tid, (theta_target, query, norm_val) in enumerate(zip(thetas_target, queries, norm_vals)):
                rec_image, beta_fusion, logit_fusion, attention = [], [], [], []
                for batch_id in range(num_batches):
                    keys_tmp = [divide_into_batches(ks, num_batches)[batch_id] for ks in keys]
                    logits_tmp = [divide_into_batches(ls, num_batches)[batch_id] for ls in logits]
                    masks_tmp = [divide_into_batches(ms, num_batches)[batch_id] for ms in masks]
                    batch_size = keys_tmp[0].shape[0]
                    query_tmp = query.view(1, self.theta_dim + self.eta_dim, 1).repeat(batch_size, 1, 1)
                    # keys/values gather all source images; slices are assumed to be 224x224
                    k = torch.cat(keys_tmp, dim=-1).view(batch_size, self.theta_dim + self.eta_dim, 1, len(source_images))
                    v = torch.stack(logits_tmp, dim=-1).view(batch_size, self.beta_dim, 224 * 224, len(source_images))
                    logit_fusion_tmp, attention_tmp = self.attention_module(query_tmp, k, v, None, 5.0)
                    beta_fusion_tmp = self.channel_aggregation(reparameterize_logit(logit_fusion_tmp))
                    combined_map = torch.cat([beta_fusion_tmp, theta_target.repeat(batch_size, 1, 224, 224)], dim=1)
                    rec_image_tmp = self.decoder(combined_map) * masks_tmp[0]

                    rec_image.append(rec_image_tmp)
                    beta_fusion.append(beta_fusion_tmp)
                    logit_fusion.append(logit_fusion_tmp)
                    attention.append(attention_tmp)

                rec_image = torch.cat(rec_image, dim=0)
                beta_fusion = torch.cat(beta_fusion, dim=0)
                logit_fusion = torch.cat(logit_fusion, dim=0)
                attention = torch.cat(attention, dim=0)

                # save the harmonized image
                if header is not None:
                    if recon_orientation == "axial":
                        # reorder the slice stack from (slices, H, W) to (W, H, slices)
                        img_save = np.array(rec_image.cpu().squeeze().permute(2, 1, 0))
                    else:
                        raise NotImplementedError('Only axial orientation is supported')

                    # crop back to the original shape, rescale to the original intensity
                    # range, and clip to the valid bound
                    img_save = self.clipper(img_save[112 - 96:112 + 96, :, 112 - 96:112 + 96] * norm_val[0],
                                            norm_val[1])
                    img_save = nib.Nifti1Image(img_save, None, header)
                    out_prefix = out_paths[tid].name.replace('.nii.gz', '')
                    file_name = out_paths[tid].parent / f'{out_prefix}_harmonized_{recon_orientation}.nii.gz'
                    nib.save(img_save, file_name)

        if header is None:
            return rec_image.cpu().squeeze()

    def clipper(self, rec_image: np.ndarray, norm_val: float) -> np.ndarray:
        """Clip the image to the valid intensity range.

        Args:
            rec_image (np.ndarray): harmonized image
            norm_val (float): normalization value to clip to

        Returns:
            np.ndarray: clipped image
        """
        return np.clip(rec_image, 0, norm_val)
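
To make the label-encoding in channel_aggregation concrete, here is a minimal sketch on a toy CPU tensor (the shapes, beta_dim=4, and pixel assignments are illustrative, not taken from the model): each pixel's active channel index is recovered and scaled into [0, 1) by dividing by beta_dim.

# Toy check of channel_aggregation's label-encoding (illustrative; CPU, beta_dim=4).
import torch

beta_dim = 4
onehot = torch.zeros(1, beta_dim, 2, 2)   # (batch, beta_dim, H, W)
onehot[0, 3, 0, 0] = 1.0                  # pixel (0, 0) activates channel 3
onehot[0, 1, 0, 1] = 1.0                  # pixel (0, 1) activates channel 1
onehot[0, 0, 1, 0] = 1.0                  # pixel (1, 0) activates channel 0
onehot[0, 2, 1, 1] = 1.0                  # pixel (1, 1) activates channel 2

# same arithmetic as HAIL.channel_aggregation, minus the device handling
values = torch.arange(beta_dim, dtype=torch.float32).view(1, beta_dim, 1, 1)
label = (onehot * values).sum(1, keepdim=True) / beta_dim
print(label.squeeze())  # [[0.7500, 0.2500], [0.0000, 0.5000]]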
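
A minimal end-to-end usage sketch follows. Everything in it is an assumption for illustration: the checkpoint path, the latent dimensions, the 224x224 slice geometry, and the norm_vals pair; only the call signature comes from harmonize above. With header=None, the harmonized volume is returned as a tensor rather than written as a NIfTI file.

# Usage sketch (hypothetical paths, dimensions, and shapes; not a verified configuration).
from pathlib import Path

import torch

from hail.model import HAIL

model = HAIL(beta_dim=5, theta_dim=2, eta_dim=2, pretrained='pretrained_hail.pt')

# Each list entry is a stack of 224x224 axial slices: (num_slices, 224, 224).
source_images = [torch.rand(224, 224, 224)]   # one source contrast
target_images = [torch.rand(20, 224, 224)]    # slices drawn from the target site

rec = model.harmonize(
    source_images=source_images,
    target_images=target_images,
    target_theta=None,                        # None: estimate theta/eta from target_images
    target_eta=None,
    out_paths=[Path('out/subject01.nii.gz')],
    recon_orientation='axial',
    norm_vals=[(1000.0, 4000.0)],             # (rescale factor, clipping bound)
    header=None,                              # None: return the tensor instead of saving
)
print(rec.shape)                              # (224, 224, 224) after the channel squeeze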