hail.utils
Module source:

```python
from typing import List, Tuple

import torch
from torch.nn import functional as F
import numpy as np


def reparameterize_logit(logit: torch.Tensor) -> torch.Tensor:
    """Reparameterize the logit tensor.

    Args:
        logit (torch.Tensor): logit tensor

    Returns:
        torch.Tensor: reparameterized logit
    """
    import warnings
    # suppress the mixed-memory-format warning that gumbel_softmax can emit
    warnings.filterwarnings('ignore', message='.*Mixed memory format inputs detected.*')
    beta = F.gumbel_softmax(logit, tau=1.0, dim=1, hard=True)
    return beta


def divide_into_batches(in_tensor: torch.Tensor, num_batches: int) -> List[torch.Tensor]:
    """Divide the input tensor into multiple batches.

    Args:
        in_tensor (torch.Tensor): tensor to be divided into batches
        num_batches (int): the number of batches to divide the tensor into

    Returns:
        List[torch.Tensor]: list of tensors divided into batches
    """
    batch_size = in_tensor.shape[0] // num_batches
    remainder = in_tensor.shape[0] % num_batches
    batches = []

    current_start = 0
    # divide the tensor into batches along the first dimension;
    # the first `remainder` batches each receive one extra element
    for i in range(num_batches):
        current_end = current_start + batch_size
        if remainder:
            current_end += 1
            remainder -= 1
        batches.append(in_tensor[current_start:current_end, ...])
        current_start = current_end
    return batches


def normalize_intensity(image: np.ndarray) -> Tuple[np.ndarray, float]:
    """Normalize the intensity of the image.

    Args:
        image (np.ndarray): input image

    Returns:
        Tuple[np.ndarray, float]: normalized image and the threshold value
    """
    # scale by the 95th-percentile intensity, then clip to [0, 5]
    thresh = np.percentile(image.flatten(), 95)
    image = image / (thresh + 1e-5)
    image = np.clip(image, a_min=0.0, a_max=5.0)
    return image, thresh


def zero_pad(image: np.ndarray, image_dim: int = 256) -> np.ndarray:
    """Pad a 3D image with zeros to the given image dimension.

    Args:
        image (np.ndarray): input 3D image
        image_dim (int, optional): target dimension of the padded cube. Defaults to 256.

    Returns:
        np.ndarray: padded image
    """
    [n_row, n_col, n_slc] = image.shape
    image_padded = np.zeros((image_dim, image_dim, image_dim))
    center_loc = image_dim // 2
    # place the input at the center of the zero-filled cube
    image_padded[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
                 center_loc - n_col // 2: center_loc + n_col - n_col // 2,
                 center_loc - n_slc // 2: center_loc + n_slc - n_slc // 2] = image
    return image_padded


def zero_pad2d(image: np.ndarray, image_dim: int = 256) -> np.ndarray:
    """Pad a 2D image with zeros to the given image dimension.

    Args:
        image (np.ndarray): input 2D image
        image_dim (int, optional): target dimension of the padded square. Defaults to 256.

    Returns:
        np.ndarray: padded image
    """
    [n_row, n_col] = image.shape
    image_padded = np.zeros((image_dim, image_dim))
    center_loc = image_dim // 2
    # place the input at the center of the zero-filled square
    image_padded[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
                 center_loc - n_col // 2: center_loc + n_col - n_col // 2] = image
    return image_padded


def crop(image: np.ndarray, n_row: int, n_col: int, n_slc: int) -> np.ndarray:
    """Crop a 3D image to the given dimensions.

    Args:
        image (np.ndarray): input 3D image
        n_row (int): number of rows
        n_col (int): number of columns
        n_slc (int): number of slices

    Returns:
        np.ndarray: cropped image
    """
    # the center is computed from the first axis, matching the cubes produced by zero_pad
    image_dim = image.shape[0]
    center_loc = image_dim // 2
    return image[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
                 center_loc - n_col // 2: center_loc + n_col - n_col // 2,
                 center_loc - n_slc // 2: center_loc + n_slc - n_slc // 2]


def crop2d(image: np.ndarray, n_row: int, n_col: int) -> np.ndarray:
    """Crop a 2D image to the given dimensions.

    Args:
        image (np.ndarray): input 2D image
        n_row (int): number of rows
        n_col (int): number of columns

    Returns:
        np.ndarray: cropped image
    """
    # the center is computed from the first axis, matching the squares produced by zero_pad2d
    image_dim = image.shape[0]
    center_loc = image_dim // 2
    return image[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
                 center_loc - n_col // 2: center_loc + n_col - n_col // 2]
```
reparameterize_logit(logit: torch.Tensor) -> torch.Tensor

Reparameterize the logit tensor.

Args:
    logit (torch.Tensor): logit tensor

Returns:
    torch.Tensor: reparameterized logit
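A minimal usage sketch; the input shape below is an assumption. Because the function samples with `hard=True` along `dim=1`, each position along that axis comes back one-hot while gradients flow through the underlying soft Gumbel-softmax sample.

```python
import torch
from hail.utils import reparameterize_logit

logit = torch.randn(4, 3, 8, 8)      # hypothetical (batch, classes, H, W) logits
beta = reparameterize_logit(logit)
print(beta.shape)                    # torch.Size([4, 3, 8, 8])
print(beta.sum(dim=1).unique())      # tensor([1.]) -- one-hot along dim=1
```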
divide_into_batches(in_tensor: torch.Tensor, num_batches: int) -> List[torch.Tensor]

Divide the input tensor into multiple batches along the first dimension. When the size is not evenly divisible, the earliest batches receive one extra element.

Args:
    in_tensor (torch.Tensor): tensor to be divided into batches
    num_batches (int): the number of batches to divide the tensor into

Returns:
    List[torch.Tensor]: list of tensors divided into batches
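A quick sketch of the uneven split; the tensor below is arbitrary.

```python
import torch
from hail.utils import divide_into_batches

x = torch.arange(10).reshape(10, 1)       # 10 samples along the first dimension
batches = divide_into_batches(x, num_batches=3)
print([b.shape[0] for b in batches])      # [4, 3, 3] -- the remainder goes to the earliest batches
```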
normalize_intensity(image: np.ndarray) -> Tuple[np.ndarray, float]

Normalize the intensity of the image by its 95th-percentile value and clip the result to [0, 5].

Args:
    image (np.ndarray): input image

Returns:
    Tuple[np.ndarray, float]: normalized image and the threshold value
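A small sketch with synthetic data; the array below is an arbitrary stand-in for a real image volume.

```python
import numpy as np
from hail.utils import normalize_intensity

image = np.random.rand(64, 64, 64) * 1000.0    # arbitrary intensities
normalized, thresh = normalize_intensity(image)
print(float(thresh))           # the 95th-percentile intensity used as the scale
print(normalized.max() <= 5.0) # True -- values are clipped to [0, 5]
```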
zero_pad(image: np.ndarray, image_dim: int = 256) -> np.ndarray

Pad a 3D image with zeros to a cube of the given dimension, keeping the input centered.

Args:
    image (np.ndarray): input 3D image
    image_dim (int, optional): target dimension of the padded cube. Defaults to 256.

Returns:
    np.ndarray: padded image of shape (image_dim, image_dim, image_dim)
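A sketch with an arbitrary non-cubic volume:

```python
import numpy as np
from hail.utils import zero_pad

volume = np.ones((181, 217, 181))     # arbitrary non-cubic volume
padded = zero_pad(volume, image_dim=256)
print(padded.shape)                   # (256, 256, 256), input centered in the cube
```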
zero_pad2d(image: np.ndarray, image_dim: int = 256) -> np.ndarray

Pad a 2D image with zeros to a square of the given dimension, keeping the input centered.

Args:
    image (np.ndarray): input 2D image
    image_dim (int, optional): target dimension of the padded square. Defaults to 256.

Returns:
    np.ndarray: padded image of shape (image_dim, image_dim)
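The 2D counterpart, sketched with an arbitrary slice:

```python
import numpy as np
from hail.utils import zero_pad2d

slice_2d = np.ones((217, 181))        # arbitrary 2D slice
padded = zero_pad2d(slice_2d, image_dim=256)
print(padded.shape)                   # (256, 256)
```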
crop(image: np.ndarray, n_row: int, n_col: int, n_slc: int) -> np.ndarray

Crop a 3D image to the given dimensions about its center. The center is computed from the first axis, so the function is intended for the cubic volumes produced by zero_pad.

Args:
    image (np.ndarray): input 3D image
    n_row (int): number of rows
    n_col (int): number of columns
    n_slc (int): number of slices

Returns:
    np.ndarray: cropped image of shape (n_row, n_col, n_slc)
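Because crop uses the same centered indexing as zero_pad, the two functions round-trip on padded volumes; a sketch with arbitrary dimensions:

```python
import numpy as np
from hail.utils import zero_pad, crop

volume = np.random.rand(181, 217, 181)    # arbitrary volume
padded = zero_pad(volume, image_dim=256)
restored = crop(padded, 181, 217, 181)    # undoes the centered padding
print(np.array_equal(restored, volume))   # True
```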
crop2d(image: np.ndarray, n_row: int, n_col: int) -> np.ndarray

Crop a 2D image to the given dimensions about its center. As with crop, the center is computed from the first axis, so this is the inverse of zero_pad2d.

Args:
    image (np.ndarray): input 2D image
    n_row (int): number of rows
    n_col (int): number of columns

Returns:
    np.ndarray: cropped image of shape (n_row, n_col)
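And the 2D round trip with zero_pad2d, again on arbitrary data:

```python
import numpy as np
from hail.utils import zero_pad2d, crop2d

slice_2d = np.random.rand(217, 181)       # arbitrary 2D slice
restored = crop2d(zero_pad2d(slice_2d, image_dim=256), 217, 181)
print(np.array_equal(restored, slice_2d)) # True
```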