hail.utils

  1from typing import List, Tuple
  2
  3import torch
  4from torch.nn import functional as F
  5import numpy as np
  6
  7def reparameterize_logit(logit: torch.Tensor) -> torch.Tensor:
  8    """ reparameterize the logit tensor
  9
 10    Args:
 11        logit (torch.Tensor): logit tensor
 12
 13    Returns:
 14        torch.Tensor: reparameterized logit
 15    """
 16    import warnings
 17    warnings.filterwarnings('ignore', message='.*Mixed memory format inputs detected.*')
 18    beta = F.gumbel_softmax(logit, tau=1.0, dim=1, hard=True)
 19    return beta
 20
 21def divide_into_batches(in_tensor: torch.Tensor, num_batches:int ) -> List[torch.Tensor]:
 22    """ divide the input tensor into multiple batches
 23
 24    Args:
 25        in_tensor (torch.Tensor): tensor to be divided into batches
 26        num_batches (int): the number of batches to divide the tensor into
 27
 28    Returns:
 29        List[torch.Tensor]: list of tensors divided into batches
 30    """
 31    batch_size = in_tensor.shape[0] // num_batches
 32    remainder = in_tensor.shape[0] % num_batches
 33    batches = []
 34
 35    current_start = 0
 36    # divide the tensor into batches
 37    for i in range(num_batches):
 38        current_end = current_start + batch_size
 39        if remainder:
 40            current_end += 1
 41            remainder -= 1
 42        batches.append(in_tensor[current_start:current_end, ...])
 43        current_start = current_end
 44    return batches
 45
 46
 47def normalize_intensity(image: np.ndarray) -> Tuple[np.ndarray, int]:
 48    """ Normalize the intensity of the image
 49
 50    Args:
 51        image (np.ndarray): input image
 52
 53    Returns:
 54        Tuple[np.ndarray, int]: normalized image and the threshold value
 55    """
 56
 57    thresh = np.percentile(image.flatten(), 95)
 58    image = image / (thresh + 1e-5)
 59    image = np.clip(image, a_min=0.0, a_max=5.0)
 60    return image, thresh
 61
 62
 63def zero_pad(image: np.ndarray, image_dim: int=256) -> np.ndarray:
 64    """ pad a 3D image with zeros with given image dimension
 65
 66    Args:
 67        image (np.ndarray): input 3D image
 68        image_dim (int, optional): image dim. Defaults to 256.
 69
 70    Returns:
 71        np.ndarray: padded image
 72    """
 73    [n_row, n_col, n_slc] = image.shape
 74    image_padded = np.zeros((image_dim, image_dim, image_dim))
 75    center_loc = image_dim // 2
 76    image_padded[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
 77                 center_loc - n_col // 2: center_loc + n_col - n_col // 2,
 78                 center_loc - n_slc // 2: center_loc + n_slc - n_slc // 2] = image
 79    return image_padded
 80
 81def zero_pad2d(image: np.ndarray, image_dim: int=256) -> np.ndarray:
 82    """ pad a 2D image with zeros with given image dimension
 83
 84    Args:
 85        image (np.ndarray): input 2D image
 86        image_dim (int, optional): image dim. Defaults to 256.
 87
 88    Returns:
 89        np.ndarray: padded image
 90    """
 91    [n_row, n_col] = image.shape
 92    image_padded = np.zeros((image_dim, image_dim))
 93    center_loc = image_dim // 2
 94    image_padded[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
 95                 center_loc - n_col // 2: center_loc + n_col - n_col // 2] = image
 96    return image_padded
 97
 98
 99def crop(image: np.ndarray, n_row: int, n_col: int, n_slc: int) -> np.ndarray:
100    """ crop a 3D image to the given dimensions
101
102    Args:
103        image (np.ndarray): input 3D image
104        n_row (int): number of rows
105        n_col (int): number of columns
106        n_slc (int): number of slices
107
108    Returns:
109        np.ndarray: cropped image
110    """
111    image_dim = image.shape[0]
112    center_loc = image_dim // 2
113    return image[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
114                 center_loc - n_col // 2: center_loc + n_col - n_col // 2,
115                 center_loc - n_slc // 2: center_loc + n_slc - n_slc // 2]
116
def crop2d(image: np.ndarray, n_row: int, n_col: int) -> np.ndarray:
    """ Crop the central n_row x n_col region of a 2D image.

    Uses the same centering rule as ``zero_pad2d``: on each axis the region
    starts ``size // 2`` before that axis' center ``dim // 2``. The center is
    computed per axis, so non-square images are cropped correctly.

    Args:
        image (np.ndarray): input 2D image
        n_row (int): number of rows
        n_col (int): number of columns

    Returns:
        np.ndarray: cropped image (a view into ``image``)

    Raises:
        ValueError: if a requested size is negative or exceeds the
            corresponding image dimension (which would otherwise wrap around
            via negative slice indices and return the wrong region).
    """
    region = []
    for axis, size in enumerate((n_row, n_col)):
        dim = image.shape[axis]
        if not 0 <= size <= dim:
            raise ValueError(f"cannot crop axis {axis} of length {dim} to {size}")
        center_loc = dim // 2
        region.append(slice(center_loc - size // 2, center_loc + size - size // 2))
    return image[tuple(region)]
def reparameterize_logit(logit: torch.Tensor) -> torch.Tensor:
 9def reparameterize_logit(logit: torch.Tensor) -> torch.Tensor:
10    """ reparameterize the logit tensor
11
12    Args:
13        logit (torch.Tensor): logit tensor
14
15    Returns:
16        torch.Tensor: reparameterized logit
17    """
18    import warnings
19    warnings.filterwarnings('ignore', message='.*Mixed memory format inputs detected.*')
20    beta = F.gumbel_softmax(logit, tau=1.0, dim=1, hard=True)
21    return beta

Reparameterize the logit tensor by drawing a hard (one-hot) Gumbel-softmax sample along dimension 1.

Args: logit (torch.Tensor): logit tensor

Returns: torch.Tensor: reparameterized logit

def divide_into_batches(in_tensor: torch.Tensor, num_batches: int) -> List[torch.Tensor]:
def divide_into_batches(in_tensor: torch.Tensor, num_batches: int) -> List[torch.Tensor]:
    """ Split a tensor along its first axis into ``num_batches`` consecutive chunks.

    Chunk lengths differ by at most one: when the first dimension does not
    divide evenly, each of the leading chunks receives one extra row. When
    ``num_batches`` exceeds the first dimension, the trailing chunks are empty.

    Args:
        in_tensor (torch.Tensor): tensor to split along dimension 0
        num_batches (int): how many chunks to produce

    Returns:
        List[torch.Tensor]: the chunks in order, each a slice of ``in_tensor``
    """
    base, leftover = divmod(in_tensor.shape[0], num_batches)
    # The first `leftover` chunks absorb the remainder, one extra row each.
    sizes = [base + 1 if idx < leftover else base for idx in range(num_batches)]

    chunks = []
    offset = 0
    for size in sizes:
        chunks.append(in_tensor[offset:offset + size, ...])
        offset += size
    return chunks

Divide the input tensor into multiple batches along its first dimension.

Args: in_tensor (torch.Tensor): tensor to be divided into batches; num_batches (int): the number of batches to divide the tensor into

Returns: List[torch.Tensor]: list of tensors divided into batches

def normalize_intensity(image: numpy.ndarray) -> Tuple[numpy.ndarray, int]:
def normalize_intensity(image: np.ndarray, percentile: float = 95.0,
                        max_value: float = 5.0) -> Tuple[np.ndarray, float]:
    """ Normalize the intensity of the image by a robust upper percentile.

    The image is divided by its ``percentile``-th percentile (computed over
    all elements) and the result is clipped to ``[0, max_value]``.

    Args:
        image (np.ndarray): input image
        percentile (float, optional): percentile used as the threshold.
            Defaults to 95.0.
        max_value (float, optional): upper clipping bound of the normalized
            image. Defaults to 5.0.

    Returns:
        Tuple[np.ndarray, float]: normalized image (same shape as ``image``)
        and the threshold value (the percentile, a float)
    """
    # np.percentile reduces over all elements when axis is None, so no
    # explicit flatten (which would copy the array) is needed.
    thresh = np.percentile(image, percentile)
    # The small epsilon keeps the division finite when the percentile is 0.
    normalized = np.clip(image / (thresh + 1e-5), a_min=0.0, a_max=max_value)
    return normalized, thresh

Normalize the intensity of the image by dividing it by its 95th percentile and clipping the result to [0, 5].

Args: image (np.ndarray): input image

Returns: Tuple[np.ndarray, float]: normalized image and the threshold value (the 95th percentile)

def zero_pad(image: numpy.ndarray, image_dim: int = 256) -> numpy.ndarray:
def zero_pad(image: np.ndarray, image_dim: int = 256) -> np.ndarray:
    """ Zero-pad a 3D image into the center of a cube of side ``image_dim``.

    On each axis the image starts ``size // 2`` voxels before the cube
    center ``image_dim // 2``, so for odd sizes the extra voxel lies after
    the center.

    Args:
        image (np.ndarray): input 3D image
        image_dim (int, optional): side length of the output cube.
            Defaults to 256.

    Returns:
        np.ndarray: padded float64 image of shape (image_dim, image_dim, image_dim)

    Raises:
        ValueError: if ``image`` is not 3D or any of its dimensions exceeds
            ``image_dim``.
    """
    if len(image.shape) != 3:
        raise ValueError(f"expected a 3D image, got shape {image.shape}")
    if any(size > image_dim for size in image.shape):
        raise ValueError(
            f"image of shape {image.shape} does not fit in a cube of side {image_dim}")

    center_loc = image_dim // 2
    region = tuple(slice(center_loc - size // 2, center_loc + size - size // 2)
                   for size in image.shape)
    image_padded = np.zeros((image_dim, image_dim, image_dim))
    image_padded[region] = image
    return image_padded

Zero-pad a 3D image, centered, into a cube whose side is the given image dimension.

Args: image (np.ndarray): input 3D image; image_dim (int, optional): side length of the padded cube. Defaults to 256.

Returns: np.ndarray: padded image

def zero_pad2d(image: numpy.ndarray, image_dim: int = 256) -> numpy.ndarray:
def zero_pad2d(image: np.ndarray, image_dim: int = 256) -> np.ndarray:
    """ Zero-pad a 2D image into the center of a square of side ``image_dim``.

    On each axis the image starts ``size // 2`` pixels before the square
    center ``image_dim // 2``, so for odd sizes the extra pixel lies after
    the center.

    Args:
        image (np.ndarray): input 2D image
        image_dim (int, optional): side length of the output square.
            Defaults to 256.

    Returns:
        np.ndarray: padded float64 image of shape (image_dim, image_dim)

    Raises:
        ValueError: if ``image`` is not 2D or any of its dimensions exceeds
            ``image_dim``.
    """
    if len(image.shape) != 2:
        raise ValueError(f"expected a 2D image, got shape {image.shape}")
    if any(size > image_dim for size in image.shape):
        raise ValueError(
            f"image of shape {image.shape} does not fit in a square of side {image_dim}")

    center_loc = image_dim // 2
    region = tuple(slice(center_loc - size // 2, center_loc + size - size // 2)
                   for size in image.shape)
    image_padded = np.zeros((image_dim, image_dim))
    image_padded[region] = image
    return image_padded

Zero-pad a 2D image, centered, into a square whose side is the given image dimension.

Args: image (np.ndarray): input 2D image; image_dim (int, optional): side length of the padded square. Defaults to 256.

Returns: np.ndarray: padded image

def crop( image: numpy.ndarray, n_row: int, n_col: int, n_slc: int) -> numpy.ndarray:
def crop(image: np.ndarray, n_row: int, n_col: int, n_slc: int) -> np.ndarray:
    """ Crop the central n_row x n_col x n_slc region of a 3D image.

    Uses the same centering rule as ``zero_pad``: on each axis the region
    starts ``size // 2`` before that axis' center ``dim // 2``. The center is
    computed per axis, so non-cubic images are cropped correctly.

    Args:
        image (np.ndarray): input 3D image
        n_row (int): number of rows
        n_col (int): number of columns
        n_slc (int): number of slices

    Returns:
        np.ndarray: cropped image (a view into ``image``)

    Raises:
        ValueError: if a requested size is negative or exceeds the
            corresponding image dimension (which would otherwise wrap around
            via negative slice indices and return the wrong region).
    """
    region = []
    for axis, size in enumerate((n_row, n_col, n_slc)):
        dim = image.shape[axis]
        if not 0 <= size <= dim:
            raise ValueError(f"cannot crop axis {axis} of length {dim} to {size}")
        center_loc = dim // 2
        region.append(slice(center_loc - size // 2, center_loc + size - size // 2))
    return image[tuple(region)]

Crop the central region of a 3D image to the given dimensions.

Args: image (np.ndarray): input 3D image; n_row (int): number of rows; n_col (int): number of columns; n_slc (int): number of slices

Returns: np.ndarray: cropped image

def crop2d(image: numpy.ndarray, n_row: int, n_col: int) -> numpy.ndarray:
def crop2d(image: np.ndarray, n_row: int, n_col: int) -> np.ndarray:
    """ Crop the central n_row x n_col region of a 2D image.

    Uses the same centering rule as ``zero_pad2d``: on each axis the region
    starts ``size // 2`` before that axis' center ``dim // 2``. The center is
    computed per axis, so non-square images are cropped correctly.

    Args:
        image (np.ndarray): input 2D image
        n_row (int): number of rows
        n_col (int): number of columns

    Returns:
        np.ndarray: cropped image (a view into ``image``)

    Raises:
        ValueError: if a requested size is negative or exceeds the
            corresponding image dimension (which would otherwise wrap around
            via negative slice indices and return the wrong region).
    """
    region = []
    for axis, size in enumerate((n_row, n_col)):
        dim = image.shape[axis]
        if not 0 <= size <= dim:
            raise ValueError(f"cannot crop axis {axis} of length {dim} to {size}")
        center_loc = dim // 2
        region.append(slice(center_loc - size // 2, center_loc + size - size // 2))
    return image[tuple(region)]

Crop the central region of a 2D image to the given dimensions.

Args: image (np.ndarray): input 2D image; n_row (int): number of rows; n_col (int): number of columns

Returns: np.ndarray: cropped image