# Source code for trust.utils.utils

import torch
from torch.utils.data import Dataset

class SubsetWithTargets(Dataset):
    """
    Provides a convenience torch.utils.data.Dataset subclass that allows one to
    access a targets field while creating subsets of the dataset.

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        The dataset from which to pull a subset
    indices: sequence
        A sequence of indices of the passed dataset from which to select a subset
    labels: sequence
        A sequence of labels for each of the elements drawn for the subset. This
        sequence should be the same length as indices. Anything accepted by
        ``torch.as_tensor`` (list, numpy array, tensor) works; the labels are
        stored in ``self.targets`` as a ``torch.long`` tensor.
    """

    def __init__(self, dataset, indices, labels):
        self.dataset = torch.utils.data.Subset(dataset, indices)
        # as_tensor accepts plain sequences as well as tensors, and returns a
        # tensor that is already torch.long (on the same device) unchanged.
        self.targets = torch.as_tensor(labels, dtype=torch.long)

    def __getitem__(self, idx):
        # Keep only the data part (element 0) of the wrapped item; the target
        # comes from self.targets, not from the wrapped dataset.
        image = self.dataset[idx][0]
        target = self.targets[idx]
        return (image, target)

    def __len__(self):
        # Length is defined by the labels, which should match len(indices).
        return len(self.targets)
class SubsetWithTargetsSingleChannel(Dataset):
    """
    Provides a convenience torch.utils.data.Dataset subclass that allows one to
    access a targets field while creating subsets of the dataset. Additionally,
    the single-channel images from the wrapped dataset are expanded to three
    channels for compatibility with three-channel model input.

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        The dataset from which to pull a subset
    indices: sequence
        A sequence of indices of the passed dataset from which to select a subset
    labels: sequence
        A sequence of labels for each of the elements drawn for the subset. This
        sequence should be the same length as indices. Anything accepted by
        ``torch.as_tensor`` (list, numpy array, tensor) works; the labels are
        stored in ``self.targets`` as a ``torch.long`` tensor.
    """

    def __init__(self, dataset, indices, labels):
        self.dataset = torch.utils.data.Subset(dataset, indices)
        # as_tensor accepts plain sequences as well as tensors, and returns a
        # tensor that is already torch.long (on the same device) unchanged.
        self.targets = torch.as_tensor(labels, dtype=torch.long)

    def __getitem__(self, idx):
        image = self.dataset[idx][0]
        # Repeat each entry of dimension 0 three times. For a channel-first
        # image with a single channel, (1, H, W) becomes (3, H, W).
        image = torch.repeat_interleave(image, 3, 0)
        target = self.targets[idx]
        return (image, target)

    def __len__(self):
        # Length is defined by the labels, which should match len(indices).
        return len(self.targets)
class ConcatWithTargets(Dataset):
    """
    Provides a convenience torch.utils.data.Dataset subclass that allows one to
    access a targets field while creating concatenations of two datasets.

    Parameters
    ----------
    dataset1: torch.utils.data.Dataset
        The first dataset to concatenate. Must have a targets field.
    dataset2: torch.utils.data.Dataset
        The second dataset to concatenate. Must have a targets field.
    """

    def __init__(self, dataset1, dataset2):
        self.dataset = torch.utils.data.ConcatDataset([dataset1, dataset2])
        # Concatenate the label containers as tensors instead of building a
        # Python list and passing it to the legacy torch.Tensor constructor.
        # That constructor goes through float32, which cannot represent every
        # integer above 2**24 exactly, and it is slow for large label sets.
        targets1 = torch.as_tensor(dataset1.targets)
        targets2 = torch.as_tensor(dataset2.targets)
        self.targets = torch.cat((targets1, targets2)).to(torch.long)

    def __getitem__(self, idx):
        # Data comes from the concatenated datasets; the target comes from the
        # concatenated self.targets, whose order matches ConcatDataset's.
        image = self.dataset[idx][0]
        target = self.targets[idx]
        return (image, target)

    def __len__(self):
        return len(self.targets)
class LabeledToUnlabeledDataset(Dataset):
    """
    Wraps a labeled dataset so that indexing it yields only the data and
    discards the label, making the dataset look unlabeled.

    Parameters
    ----------
    wrapped_dataset: torch.utils.data.Dataset
        The labeled dataset to wrap. Each item must unpack into exactly two
        values, (data, label); only the data is returned.
    """

    def __init__(self, wrapped_dataset):
        self.wrapped_dataset = wrapped_dataset

    def __len__(self):
        return len(self.wrapped_dataset)

    def __getitem__(self, index):
        # Two-way unpacking: items that are not (data, label) pairs raise.
        sample, _ = self.wrapped_dataset[index]
        return sample
def get_roc_auc(target, output, n_classes):
    """
    Compute, per class (one-vs-rest), the false positive rate (fpr), true
    positive rate (tpr) and area under the ROC (Receiver Operating
    Characteristic) curve for predicted outputs and ground-truth targets.

    The result is three dictionaries keyed by class number
    (0 .. n_classes - 1), holding that class's fpr values, tpr values and
    area under the ROC curve.

    Parameters
    ----------
    target: numpy.ndarray
        The ground truth label of the set (class ids 0 .. n_classes - 1)
    output: sequence
        Predicted output of the set, one score column per class, i.e. shape
        (n_samples, n_classes)
    n_classes: int
        The number of classes in the dataset

    Returns
    -------
    fpr: dict[int, list[float]]
    tpr: dict[int, list[float]]
    roc_auc: dict[int, float]
    """
    # These names were used but never imported in this module; import them
    # locally so the dataset classes above do not require sklearn.
    import numpy as np
    from sklearn.metrics import auc, roc_curve

    # One-vs-rest indicator matrix of shape (n_samples, n_classes). It matches
    # label_binarize(target, classes=range(n_classes)) for n_classes >= 3
    # (unknown labels give all-zero rows). It also keeps two columns when
    # n_classes == 2, where label_binarize returns a single column and the
    # target[:, 1] lookup would raise IndexError.
    target = (np.asarray(target).reshape(-1, 1) == np.arange(n_classes)).astype(int)
    output = np.asarray(output)

    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        class_fpr, class_tpr, _ = roc_curve(target[:, i], output[:, i])
        fpr[i] = class_fpr.tolist()
        tpr[i] = class_tpr.tolist()
        # float() works whether auc returns a NumPy scalar or a Python float;
        # .tolist() exists only on the former.
        roc_auc[i] = float(auc(class_fpr, class_tpr))
    return fpr, tpr, roc_auc
def get_pr_auc(target, output, n_classes):
    """
    Compute, per class (one-vs-rest), precision, recall and the area under
    the precision-recall curve for predicted outputs and ground-truth targets.

    The result is three dictionaries keyed by class number
    (0 .. n_classes - 1), holding that class's precision values, recall
    values and area under the precision-recall curve.

    Parameters
    ----------
    target: numpy.ndarray
        The ground truth label of the set (class ids 0 .. n_classes - 1)
    output: sequence
        Predicted output of the set, one score column per class, i.e. shape
        (n_samples, n_classes)
    n_classes: int
        The number of classes in the dataset

    Returns
    -------
    precision: dict[int, list[float]]
    recall: dict[int, list[float]]
    aupr: dict[int, float]
    """
    # These names were used but never imported in this module; import them
    # locally so the dataset classes above do not require sklearn.
    import numpy as np
    from sklearn.metrics import auc, precision_recall_curve

    # One-vs-rest indicator matrix of shape (n_samples, n_classes). It matches
    # label_binarize for n_classes >= 3 and also keeps two columns when
    # n_classes == 2, where label_binarize returns a single column.
    target = (np.asarray(target).reshape(-1, 1) == np.arange(n_classes)).astype(int)
    output = np.asarray(output)

    precision, recall, aupr = {}, {}, {}
    for i in range(n_classes):
        class_prec, class_rec, _ = precision_recall_curve(target[:, i], output[:, i])
        precision[i] = class_prec.tolist()
        recall[i] = class_rec.tolist()
        # Area under the curve with recall on the x-axis. float() works whether
        # auc returns a NumPy scalar or a Python float.
        aupr[i] = float(auc(class_rec, class_prec))
    return precision, recall, aupr
def get_macro_roc_auc(target, output, n_classes):
    """
    Compute the macro-averaged area under the ROC (Receiver Operating
    Characteristic) curve for predicted outputs and ground-truth targets.

    Each class is scored one-vs-rest, and the per-class areas are averaged
    without weighting.

    Parameters
    ----------
    target: numpy.ndarray
        The ground truth label of the set (class ids 0 .. n_classes - 1)
    output: sequence
        Predicted output of the set, one score column per class, i.e. shape
        (n_samples, n_classes)
    n_classes: int
        The number of classes in the dataset

    Returns
    -------
    float
        The macro-averaged ROC AUC.
    """
    # These names were used but never imported in this module; import them
    # locally so the dataset classes above do not require sklearn.
    import numpy as np
    from sklearn.metrics import roc_auc_score

    # One-vs-rest indicator matrix of shape (n_samples, n_classes), so it
    # always matches the shape of `output`. label_binarize would return a
    # single column when n_classes == 2.
    target = (np.asarray(target).reshape(-1, 1) == np.arange(n_classes)).astype(int)
    output = np.asarray(output)
    return float(roc_auc_score(target, output, average="macro"))