Utilities
Custom Datasets
- class trust.utils.custom_dataset.DataHandler_CIFAR10(X, Y=None, select=True, use_test_transform=False)[source]
Data Handler to load the CIFAR10 dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels.
- Parameters
X (numpy array) – Data to be loaded
Y (numpy array, optional) – Labels to be loaded (default: None)
select (bool) – True if loading data without labels, False otherwise
- class trust.utils.custom_dataset.DataHandler_MNIST(X, Y=None, select=True, use_test_transform=False)[source]
Data Handler to load the MNIST dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels.
- Parameters
X (numpy array) – Data to be loaded
Y (numpy array, optional) – Labels to be loaded (default: None)
select (bool) – True if loading data without labels, False otherwise
- class trust.utils.custom_dataset.DataHandler_SVHN(X, Y=None, select=True, use_test_transform=False)[source]
Data Handler to load the SVHN dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels.
- Parameters
X (numpy array) – Data to be loaded
Y (numpy array, optional) – Labels to be loaded (default: None)
select (bool) – True if loading data without labels, False otherwise
- class trust.utils.custom_dataset.DataHandler_UTKFace(X, Y=None, select=True, use_test_transform=False)[source]
Data Handler to load the UTKFace dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels.
- Parameters
X (numpy array) – Data to be loaded
Y (numpy array, optional) – Labels to be loaded (default: None)
select (bool) – True if loading data without labels, False otherwise
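The following is a minimal sketch of how a handler with the (X, Y=None, select=True) signature might behave. It is not the TRUST implementation. The class name OptionalLabelHandler is hypothetical, and the choice to return only the sample when select is True and a (sample, label) pair otherwise is an assumption. The base class falls back to object so the sketch also runs where PyTorch is not installed.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # fallback so the sketch also runs without PyTorch
    Dataset = object


class OptionalLabelHandler(Dataset):
    """Hypothetical handler mirroring the (X, Y=None, select=True) signature."""

    def __init__(self, X, Y=None, select=True, transform=None):
        if not select and Y is None:
            raise ValueError("Y is required when select=False")
        self.X = X
        self.Y = Y
        self.select = select
        self.transform = transform  # optional callable applied to each sample

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        x = self.X[index]
        if self.transform is not None:
            x = self.transform(x)
        if self.select:
            # Unlabeled mode (assumed return shape): the sample only.
            return x
        # Labeled mode (assumed return shape): sample and label.
        return x, self.Y[index]


# Usage with toy data in place of real image arrays.
unlabeled = OptionalLabelHandler([[1, 2], [3, 4]])
labeled = OptionalLabelHandler([[1, 2], [3, 4]], Y=[0, 1], select=False)
```

Raising early when select is False but no labels are given keeps the failure close to the construction site rather than deferring it to the first batch.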
- trust.utils.custom_dataset.load_dataset_custom(datadir, dset_name, feature, split_cfg, augVal=False, dataAug=True)[source]
Loads a common dataset with additional options to create class imbalances, out-of-distribution classes, and redundancies.
- Parameters
datadir (string) – The root directory in which the data is stored (or should be downloaded)
dset_name (string) – The name of the dataset. This should be one of ‘cifar10’, ‘mnist’, ‘svhn’, ‘cifar100’, ‘breast-density’.
feature (string) – The modification that should be applied to the dataset. This should be one of ‘classimb’, ‘ood’, ‘duplicate’, ‘vanilla’
split_cfg (dict) –
- Contains information relating to the dataset splits that should be created. Some of the keys for this dictionary are as follows:
- ’per_imbclass_train’: int
The number of examples in the train set for each imbalanced class (classimb)
- ’per_imbclass_val’: int
The number of examples in the validation set for each imbalanced class (classimb)
- ’per_imbclass_lake’: int
The number of examples in the lake set for each imbalanced class (classimb)
- ’per_class_train’: int
The number of examples in the train set for each balanced class (classimb)
- ’per_class_val’: int
The number of examples in the validation set for each balanced class (classimb)
- ’per_class_lake’: int
The number of examples in the lake set for each balanced class (classimb)
- ’sel_cls_idx’: list
A list of classes that are affected by class imbalance. (classimb)
- ’train_size’: int
The size of the train set (vanilla, duplicate)
- ’val_size’: int
The size of the validation set (vanilla, duplicate)
- ’lake_size’: int
The size of the lake set (vanilla, duplicate)
- ’num_rep’: int
The number of times to repeat a selection in the lake set (duplicate)
- ’lake_subset_repeat_size’: int
The size of the repeated selection in the lake set (duplicate)
- ’num_cls_imbalance’: int
The number of classes to randomly affect by class imbalance. (classimb)
- ’num_cls_idc’: int
The number of in-distribution classes to keep (ood)
- ’per_idc_train’: int
The number of in-distribution examples to keep in the train set per class (ood)
- ’per_idc_val’: int
The number of in-distribution examples to keep in the validation set per class (ood)
- ’per_idc_lake’: int
The number of in-distribution examples to keep in the lake set per class (ood)
- ’per_ood_train’: int
The number of OOD examples to keep in the train set per class (ood)
- ’per_ood_val’: int
The number of OOD examples to keep in the validation set per class (ood)
- ’per_ood_lake’: int
The number of OOD examples to keep in the lake set per class (ood)
augVal (bool, optional) – If True, the train set will also contain affected classes from the validation set. The default is False.
dataAug (bool, optional) – If True, all but the test set will be affected by random cropping and random horizontal flips. The default is True.
- Returns
Returns a train set, validation set, test set, lake set, and number of classes. The number of returned items depends on the specific configuration. Each set is an instance of torch.utils.data.Dataset.
- Return type
tuple
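As an illustration of the split_cfg dictionary, the sketch below builds one configuration for the 'classimb' feature and one for 'ood', using only the keys listed above. All numeric values are arbitrary placeholders, and the missing_keys helper is hypothetical (it is not part of TRUST). Whether every documented key must be present at once, for example both 'sel_cls_idx' and 'num_cls_imbalance', is not stated above, so treating them all as expected is an assumption.

```python
# Keys documented for each value of `feature`, copied from the parameter list above.
DOCUMENTED_KEYS = {
    "classimb": [
        "per_imbclass_train", "per_imbclass_val", "per_imbclass_lake",
        "per_class_train", "per_class_val", "per_class_lake",
        "sel_cls_idx", "num_cls_imbalance",
    ],
    "ood": [
        "num_cls_idc",
        "per_idc_train", "per_idc_val", "per_idc_lake",
        "per_ood_train", "per_ood_val", "per_ood_lake",
    ],
    "duplicate": [
        "train_size", "val_size", "lake_size",
        "num_rep", "lake_subset_repeat_size",
    ],
    "vanilla": ["train_size", "val_size", "lake_size"],
}


def missing_keys(feature, split_cfg):
    """Hypothetical helper: documented keys for `feature` that split_cfg lacks."""
    return [key for key in DOCUMENTED_KEYS[feature] if key not in split_cfg]


# Placeholder values only; choose sizes that fit the dataset in use.
classimb_cfg = {
    "num_cls_imbalance": 2,
    "sel_cls_idx": [0, 1],
    "per_imbclass_train": 10,
    "per_imbclass_val": 5,
    "per_imbclass_lake": 50,
    "per_class_train": 200,
    "per_class_val": 5,
    "per_class_lake": 3000,
}

ood_cfg = {
    "num_cls_idc": 8,
    "per_idc_train": 200,
    "per_idc_val": 10,
    "per_idc_lake": 500,
    "per_ood_train": 0,
    "per_ood_val": 0,
    "per_ood_lake": 5000,
}

# With TRUST installed, the call would look like this (left commented out here):
# from trust.utils.custom_dataset import load_dataset_custom
# result = load_dataset_custom("data/", "cifar10", "classimb", classimb_cfg)
```

The result is kept as a single tuple rather than unpacked, because the number of returned items depends on the configuration.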
Other Utilities
- class trust.utils.utils.ConcatWithTargets(dataset1, dataset2)[source]
Provides a convenience torch.utils.data.Dataset subclass that allows one to access a targets field while creating concatenations of two datasets.
- Parameters
dataset1 (torch.utils.data.Dataset) – The first dataset to concatenate. Must have a targets field.
dataset2 (torch.utils.data.Dataset) – The second dataset to concatenate. Must have a targets field.
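A rough sketch of the idea behind such a concatenation follows. The names ConcatWithTargetsSketch and ToyDataset are hypothetical and this is not the TRUST source; it only assumes that both inputs expose a targets sequence and return (data, label) pairs. The base class falls back to object so the sketch also runs without PyTorch.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # fallback so the sketch also runs without PyTorch
    Dataset = object


class ToyDataset(Dataset):
    """Hypothetical labeled dataset exposing a `targets` field."""

    def __init__(self, data, targets):
        self.data = list(data)
        self.targets = list(targets)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]


class ConcatWithTargetsSketch(Dataset):
    """Concatenate two datasets and keep a combined `targets` list."""

    def __init__(self, dataset1, dataset2):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        # Targets of the second dataset follow those of the first.
        self.targets = list(dataset1.targets) + list(dataset2.targets)

    def __len__(self):
        return len(self.dataset1) + len(self.dataset2)

    def __getitem__(self, index):
        if index < 0:
            index += len(self)  # support negative indexing
        if not 0 <= index < len(self):
            raise IndexError("index out of range")
        first_len = len(self.dataset1)
        if index < first_len:
            return self.dataset1[index]
        return self.dataset2[index - first_len]


first = ToyDataset(["a", "b"], [0, 1])
second = ToyDataset(["c"], [2])
combined = ConcatWithTargetsSketch(first, second)
```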
- class trust.utils.utils.LabeledToUnlabeledDataset(wrapped_dataset)[source]
Provides a convenience torch.utils.data.Dataset subclass that allows one to ignore the labels in a labeled dataset, thereby making it unlabeled.
- Parameters
wrapped_dataset (torch.utils.data.Dataset) – The labeled dataset in which only the data will be returned.
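The wrapping described here can be sketched in a few lines. The class name UnlabeledView is hypothetical, and the sketch assumes each item of the wrapped dataset is a (data, label) pair; the actual TRUST class may handle other item shapes. The base class falls back to object so the sketch also runs without PyTorch.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # fallback so the sketch also runs without PyTorch
    Dataset = object


class UnlabeledView(Dataset):
    """Return only the data part of each (data, label) item."""

    def __init__(self, wrapped_dataset):
        self.wrapped_dataset = wrapped_dataset

    def __len__(self):
        return len(self.wrapped_dataset)

    def __getitem__(self, index):
        data, _label = self.wrapped_dataset[index]  # label is discarded
        return data


# A plain list of pairs stands in for a labeled map-style dataset.
labeled_pairs = [("img0", 3), ("img1", 7)]
view = UnlabeledView(labeled_pairs)
```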
- class trust.utils.utils.SubsetWithTargets(dataset, indices, labels)[source]
Provides a convenience torch.utils.data.Dataset subclass that allows one to access a targets field while creating subsets of the dataset.
- Parameters
dataset (torch.utils.data.Dataset) – The dataset from which to pull a subset
indices (sequence) – A sequence of indices of the passed dataset from which to select a subset
labels (sequence) – A sequence of labels for each of the elements drawn for the subset. This sequence should be the same length as indices.
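The sketch below shows one way a subset with a targets field could be written. SubsetWithTargetsSketch is a hypothetical name, not the TRUST class. It assumes the wrapped dataset yields (data, label) pairs and that the label returned for each element comes from the labels argument rather than from the wrapped dataset. The base class falls back to object so the sketch also runs without PyTorch.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # fallback so the sketch also runs without PyTorch
    Dataset = object


class SubsetWithTargetsSketch(Dataset):
    """Subset of a dataset that also exposes a `targets` list."""

    def __init__(self, dataset, indices, labels):
        if len(indices) != len(labels):
            raise ValueError("indices and labels must have the same length")
        self.dataset = dataset
        self.indices = list(indices)
        self.targets = list(labels)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        # Map the subset position to a position in the wrapped dataset.
        data, _original_label = self.dataset[self.indices[index]]
        # Assumption: the supplied label takes precedence over the wrapped one.
        return data, self.targets[index]


# A plain list of pairs stands in for a labeled map-style dataset.
full = [("a", 0), ("b", 1), ("c", 2), ("d", 3)]
subset = SubsetWithTargetsSketch(full, indices=[3, 1], labels=[3, 1])
```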
- class trust.utils.utils.SubsetWithTargetsSingleChannel(dataset, indices, labels)[source]
Provides a convenience torch.utils.data.Dataset subclass that allows one to access a targets field while creating subsets of the dataset. Additionally, the single-channel images from the wrapped dataset are expanded to three channels for compatibility with three-channel model input.
- Parameters
dataset (torch.utils.data.Dataset) – The dataset from which to pull a subset
indices (sequence) – A sequence of indices of the passed dataset from which to select a subset
labels (sequence) – A sequence of labels for each of the elements drawn for the subset. This sequence should be the same length as indices.
- trust.utils.utils.get_macro_roc_auc(target, output, n_classes)[source]
Function to compute the macro-average ROC (Receiver Operating Characteristic) for a list of predicted outputs and ground truth targets. The output is in the form of three dictionaries with class numbers as keys and the values of precision, recall, area under precision recall curve.
- Parameters
target (numpy.ndarray) – The ground truth label of the set
output (sequence) – Predicted output of the set
n_classes (int) – The number of classes in the dataset
- trust.utils.utils.get_pr_auc(target, output, n_classes)[source]
Function to compute precision, recall, and area under the precision-recall curve for a list of predicted outputs and ground truth targets. The output is in the form of three dictionaries with class numbers as keys and the values of precision, recall, and area under the precision-recall curve.
- Parameters
target (numpy.ndarray) – The ground truth label of the set
output (sequence) – Predicted output of the set
n_classes (int) – The number of classes in the dataset
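To make the quantities concrete, here is a small pure-Python sketch of precision and recall points for a single class treated one-vs-rest, plus a step-wise area summary. It is not the TRUST implementation, and the function names are hypothetical. The area is computed as the sum of precision weighted by the increase in recall (the average-precision form); a library that integrates the curve with the trapezoidal rule would give a slightly different number.

```python
def precision_recall_points(labels, scores):
    """Precision and recall at each distinct score threshold, highest first.

    labels: 1 for the class of interest, 0 for every other class.
    scores: predicted score for the class of interest.
    """
    total_pos = sum(labels)
    if total_pos == 0:
        raise ValueError("need at least one positive example")
    precision, recall = [], []
    for threshold in sorted(set(scores), reverse=True):
        # Everything scoring at or above the threshold is predicted positive.
        tp = sum(1 for y, s in zip(labels, scores) if s >= threshold and y == 1)
        fp = sum(1 for y, s in zip(labels, scores) if s >= threshold and y == 0)
        precision.append(tp / (tp + fp))
        recall.append(tp / total_pos)
    return precision, recall


def stepwise_pr_area(precision, recall):
    """Sum of precision times the recall gained at each threshold."""
    area, previous_recall = 0.0, 0.0
    for p, r in zip(precision, recall):
        area += (r - previous_recall) * p
        previous_recall = r
    return area


labels = [1, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.1]
precision, recall = precision_recall_points(labels, scores)
area = stepwise_pr_area(precision, recall)
```

Repeating this for every class index and storing the results under that index would give dictionaries keyed by class number, as described above.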
- trust.utils.utils.get_roc_auc(target, output, n_classes)[source]
Function to compute false positive rate (fpr), true positive rate (tpr), and area under the ROC (Receiver Operating Characteristic) curve for a list of predicted outputs and ground truth targets. The output is in the form of three dictionaries with class numbers as keys and the values of fpr, tpr, and area under the ROC curve.
- Parameters
target (numpy.ndarray) – The ground truth label of the set
output (sequence) – Predicted output of the set
n_classes (int) – The number of classes in the dataset
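The per-class and macro-average ROC AUC values can be illustrated with a short pure-Python sketch. It is not the TRUST implementation and the function names are hypothetical. It assumes target holds integer class labels and output holds one row of per-class scores per example; the formats actually expected by the functions above are not specified here. Each class is scored one-vs-rest, and the AUC is computed as the probability that a positive example outranks a negative one (ties count as one half), which equals the area under the ROC curve.

```python
def binary_roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


def per_class_roc_auc(target, output, n_classes):
    """Dictionary mapping each class index to its one-vs-rest AUC."""
    auc = {}
    for c in range(n_classes):
        labels = [1 if t == c else 0 for t in target]
        scores = [row[c] for row in output]
        auc[c] = binary_roc_auc(labels, scores)
    return auc


def macro_roc_auc(target, output, n_classes):
    """Unweighted mean of the per-class AUC values."""
    auc = per_class_roc_auc(target, output, n_classes)
    return sum(auc.values()) / n_classes


target = [0, 1, 2, 1]
output = [
    [0.80, 0.10, 0.10],
    [0.20, 0.70, 0.10],
    [0.10, 0.20, 0.70],
    [0.70, 0.15, 0.15],  # true class 1, but scored highest for class 0
]
per_class = per_class_roc_auc(target, output, 3)
macro = macro_roc_auc(target, output, 3)
```

The pairwise loop is quadratic in the number of examples, which is fine for illustration; a sort-based or library routine is preferable for real evaluation sets.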
Models
- TRUST incorporates several neural network architectures, listed below:
densenet
dla
dla_simple
dpn
efficientnet
googlenet
lenet
mobilenet
mobilenetv2
pnasnet
preact_resnet
regnet
resnet
resnext
senet
shufflenet
shufflenetv2
vgg
To use a custom model architecture, ensure the model architecture has the following:
The forward method should accept two additional boolean arguments:
A boolean variable ‘last’ which -
If true: returns the model output and the output of the second-to-last layer
If false: returns the model output only
A boolean variable ‘freeze’ which -
If true: disables the tracking of any calculations required to later calculate a gradient, i.e., skips gradient calculation over the weights
If false: tracks gradients as usual
A get_embedding_dim() method which returns the number of hidden units in the last layer.
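The requirements above can be sketched with a small PyTorch module. TwoLayerNet and its layer sizes are hypothetical and chosen only for illustration. The sketch assumes that freeze applies to the layers up to the embedding while the final layer still tracks gradients; the description above does not say which layers are covered, so adjust this if the whole forward pass should be frozen.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoLayerNet(nn.Module):
    """Hypothetical custom model exposing `last`, `freeze`, and get_embedding_dim()."""

    def __init__(self, input_dim, hidden_dim, n_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_classes)
        self.hidden_dim = hidden_dim

    def forward(self, x, last=False, freeze=False):
        if freeze:
            # Assumption: only the embedding layers run without gradient tracking.
            with torch.no_grad():
                emb = F.relu(self.fc1(x))
        else:
            emb = F.relu(self.fc1(x))
        out = self.fc2(emb)
        if last:
            # Model output plus the output of the second-to-last layer.
            return out, emb
        return out

    def get_embedding_dim(self):
        # Number of hidden units feeding the final layer.
        return self.hidden_dim


model = TwoLayerNet(input_dim=4, hidden_dim=8, n_classes=3)
batch = torch.randn(2, 4)
logits = model(batch)
logits_frozen, embedding = model(batch, last=True, freeze=True)
```

Keeping both flags as keyword arguments with False defaults means the model still behaves like an ordinary module when called as model(batch).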