Utilities

Custom Datasets

class trust.utils.custom_dataset.DataHandler_CIFAR10(X, Y=None, select=True, use_test_transform=False)[source]

Data Handler to load CIFAR10 dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class trust.utils.custom_dataset.DataHandler_MNIST(X, Y=None, select=True, use_test_transform=False)[source]

Data Handler to load MNIST dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class trust.utils.custom_dataset.DataHandler_SVHN(X, Y=None, select=True, use_test_transform=False)[source]

Data Handler to load SVHN dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class trust.utils.custom_dataset.DataHandler_UTKFace(X, Y=None, select=True, use_test_transform=False)[source]

Data Handler to load UTKFace dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise
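The handlers above share one signature, and the select flag decides whether labels are returned. The following is a minimal sketch of that idea, not the TRUST implementation: the class name SketchDataHandler is hypothetical, the return shapes are assumptions, transforms are omitted, and it is written without importing torch so that it runs standalone (a real handler would subclass torch.utils.data.Dataset).

```python
class SketchDataHandler:
    """Hypothetical stand-in for the DataHandler_* classes (assumed behavior)."""

    def __init__(self, X, Y=None, select=True):
        # Labels are only needed when the handler is asked to return them.
        if not select and Y is None:
            raise ValueError("Y is required when select=False")
        self.X = X
        self.Y = Y
        self.select = select

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.select:
            # Unlabeled mode: return the data point only.
            return self.X[index]
        # Labeled mode: return the data point together with its label.
        return self.X[index], self.Y[index]


unlabeled = SketchDataHandler([[0.1], [0.2]])
labeled = SketchDataHandler([[0.1], [0.2]], Y=[3, 7], select=False)
```

With these two objects, unlabeled[1] yields only the data point, while labeled[0] yields a (data, label) pair.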

trust.utils.custom_dataset.load_dataset_custom(datadir, dset_name, feature, split_cfg, augVal=False, dataAug=True)[source]

Loads a common dataset with additional options to create class imbalances, out-of-distribution classes, and redundancies.

Parameters
  • datadir (string) – The root directory in which the data is stored (or should be downloaded)

  • dset_name (string) – The name of the dataset. This should be one of ‘cifar10’, ‘mnist’, ‘svhn’, ‘cifar100’, ‘breast-density’.

  • feature (string) – The modification that should be applied to the dataset. This should be one of ‘classimb’, ‘ood’, ‘duplicate’, ‘vanilla’

  • split_cfg (dict) –

    Contains information relating to the dataset splits that should be created. Some of the keys for this dictionary are as follows:
    ’per_imbclass_train’: int

    The number of examples in the train set for each imbalanced class (classimb)

    ’per_imbclass_val’: int

    The number of examples in the validation set for each imbalanced class (classimb)

    ’per_imbclass_lake’: int

    The number of examples in the lake set for each imbalanced class (classimb)

    ’per_class_train’: int

    The number of examples in the train set for each balanced class (classimb)

    ’per_class_val’: int

    The number of examples in the validation set for each balanced class (classimb)

    ’per_class_lake’: int

    The number of examples in the lake set for each balanced class (classimb)

    ’sel_cls_idx’: list

    A list of classes that are affected by class imbalance. (classimb)

    ’train_size’: int

    The size of the train set (vanilla, duplicate)

    ’val_size’: int

    The size of the validation set (vanilla, duplicate)

    ’lake_size’: int

    The size of the lake set (vanilla, duplicate)

    ’num_rep’: int

    The number of times to repeat a selection in the lake set (duplicate)

    ’lake_subset_repeat_size’: int

    The size of the repeated selection in the lake set (duplicate)

    ’num_cls_imbalance’: int

    The number of classes to randomly affect by class imbalance. (classimb)

    ’num_cls_idc’: int

    The number of in-distribution classes to keep (ood)

    ’per_idc_train’: int

    The number of in-distribution examples to keep in the train set per class (ood)

    ’per_idc_val’: int

    The number of in-distribution examples to keep in the validation set per class (ood)

    ’per_idc_lake’: int

    The number of in-distribution examples to keep in the lake set per class (ood)

    ’per_ood_train’: int

    The number of OOD examples to keep in the train set per class (ood)

    ’per_ood_val’: int

    The number of OOD examples to keep in the validation set per class (ood)

    ’per_ood_lake’: int

    The number of OOD examples to keep in the lake set per class (ood)

  • augVal (bool, optional) – If True, the train set will also contain affected classes from the validation set. The default is False.

  • dataAug (bool, optional) – If True, all sets except the test set will be affected by random cropping and random horizontal flips. The default is True.

Returns

Returns a train set, validation set, test set, lake set, and number of classes. The number of returned items depends on the specific configuration. Each set is an instance of torch.utils.data.Dataset.

Return type

tuple
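As an illustration of how split_cfg might be assembled, the sketch below builds one dictionary for the 'classimb' feature and one for the 'ood' feature. Only the key names come from the parameter list above; every numeric value, the variable names, and the "data/" directory are hypothetical. The call is left commented out because it needs TRUST installed and may download data.

```python
# Hypothetical values; only the key names are taken from the documentation.
classimb_cfg = {
    "per_imbclass_train": 10,   # train examples per imbalanced class
    "per_imbclass_val": 5,      # validation examples per imbalanced class
    "per_imbclass_lake": 50,    # lake examples per imbalanced class
    "per_class_train": 200,     # train examples per balanced class
    "per_class_val": 5,         # validation examples per balanced class
    "per_class_lake": 500,      # lake examples per balanced class
    "sel_cls_idx": [0, 1],      # classes affected by the imbalance
}

ood_cfg = {
    "num_cls_idc": 8,           # number of in-distribution classes to keep
    "per_idc_train": 200,
    "per_idc_val": 10,
    "per_idc_lake": 500,
    "per_ood_train": 0,
    "per_ood_val": 0,
    "per_ood_lake": 5000,
}

# from trust.utils.custom_dataset import load_dataset_custom
# result = load_dataset_custom("data/", "cifar10", "classimb", classimb_cfg)
# The number of items in `result` depends on the configuration, so it is
# kept as a single tuple here instead of being unpacked.
```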

Other Utilities

class trust.utils.utils.ConcatWithTargets(dataset1, dataset2)[source]

Provides a convenience torch.utils.data.Dataset subclass that allows one to access a targets field while creating concatenations of two datasets.

Parameters
  • dataset1 (torch.utils.data.Dataset) – The first dataset to concatenate. Must have a targets field.

  • dataset2 (torch.utils.data.Dataset) – The second dataset to concatenate. Must have a targets field.

class trust.utils.utils.LabeledToUnlabeledDataset(wrapped_dataset)[source]

Provides a convenience torch.utils.data.Dataset subclass that allows one to ignore the labels in a labeled dataset, thereby making it unlabeled.

Parameters

wrapped_dataset (torch.utils.data.Dataset) – The labeled dataset in which only the data will be returned.
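The idea of dropping labels from a labeled dataset can be sketched as follows. This is an assumption-laden illustration, not the library's code: the class name UnlabeledView is hypothetical, the wrapped dataset is assumed to yield (data, label) pairs, and plain Python is used in place of a torch.utils.data.Dataset subclass so the snippet runs standalone.

```python
class UnlabeledView:
    """Hypothetical wrapper that hides the labels of a (data, label) dataset."""

    def __init__(self, wrapped_dataset):
        self.wrapped_dataset = wrapped_dataset

    def __len__(self):
        return len(self.wrapped_dataset)

    def __getitem__(self, index):
        # Assumes each item is a (data, label) pair; the label is discarded.
        data, _label = self.wrapped_dataset[index]
        return data


labeled_pairs = [("img0", 4), ("img1", 9), ("img2", 1)]
unlabeled_view = UnlabeledView(labeled_pairs)
```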

class trust.utils.utils.SubsetWithTargets(dataset, indices, labels)[source]

Provides a convenience torch.utils.data.Dataset subclass that allows one to access a targets field while creating subsets of the dataset.

Parameters
  • dataset (torch.utils.data.Dataset) – The dataset from which to pull a subset

  • indices (sequence) – A sequence of indices of the passed dataset from which to select a subset

  • labels (sequence) – A sequence of labels for each of the elements drawn for the subset. This sequence should be the same length as indices.
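A subset that also exposes a targets field could look roughly like the sketch below. The class name SubsetWithTargetsSketch is hypothetical, and two behaviors are assumptions and are not confirmed by the description above: the wrapped dataset yields (data, label) pairs, and the supplied labels replace the original ones. Plain Python is used so the snippet runs without torch.

```python
class SubsetWithTargetsSketch:
    """Hypothetical subset wrapper exposing a `targets` attribute."""

    def __init__(self, dataset, indices, labels):
        if len(indices) != len(labels):
            raise ValueError("indices and labels must have the same length")
        self.dataset = dataset
        self.indices = list(indices)
        self.targets = list(labels)  # one label per selected element

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Map the subset position to the position in the wrapped dataset.
        data, _original_label = self.dataset[self.indices[idx]]
        # Assumption: the label passed in `labels` overrides the original.
        return data, self.targets[idx]


base = [("a", 0), ("b", 1), ("c", 2), ("d", 3)]
subset = SubsetWithTargetsSketch(base, indices=[3, 1], labels=[9, 8])
```

Keeping targets as a plain attribute lets downstream code inspect the label distribution of the subset without iterating over the data.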

class trust.utils.utils.SubsetWithTargetsSingleChannel(dataset, indices, labels)[source]

Provides a convenience torch.utils.data.Dataset subclass that allows one to access a targets field while creating subsets of the dataset. Additionally, the single-channel images from the wrapped dataset are expanded to three channels for compatibility with three-channel model input.

Parameters
  • dataset (torch.utils.data.Dataset) – The dataset from which to pull a subset

  • indices (sequence) – A sequence of indices of the passed dataset from which to select a subset

  • labels (sequence) – A sequence of labels for each of the elements drawn for the subset. This sequence should be the same length as indices.

trust.utils.utils.get_macro_roc_auc(target, output, n_classes)[source]

Function to compute the macro-average ROC (Receiver Operating Characteristic) for a list of predicted outputs and ground truth targets. The output is in the form of three dictionaries with class numbers as keys and the values of precision, recall, area under precision recall curve.

Parameters
  • target (numpy.ndarray) – The ground truth label of the set

  • output (sequence) – Predicted output of the set

  • n_classes (int) – The number of classes in the dataset

trust.utils.utils.get_pr_auc(target, output, n_classes)[source]

Function to compute precision, recall, and area under Precision Recall curve for a list of predicted outputs and ground truth targets. The output is in the form of three dictionaries with class numbers as keys and the values of precision, recall, area under precision recall curve.

Parameters
  • target (numpy.ndarray) – The ground truth label of the set

  • output (sequence) – Predicted output of the set

  • n_classes (int) – The number of classes in the dataset

trust.utils.utils.get_roc_auc(target, output, n_classes)[source]

Function to compute the false positive rate (fpr), true positive rate (tpr), and area under the ROC (Receiver Operating Characteristic) curve for a list of predicted outputs and ground truth targets. The output is in the form of three dictionaries with class numbers as keys and the values of fpr, tpr, area under ROC curve.

Parameters
  • target (numpy.ndarray) – The ground truth label of the set

  • output (sequence) – Predicted output of the set

  • n_classes (int) – The number of classes in the dataset
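To make the per-class dictionaries and the macro average concrete, the sketch below computes a one-vs-rest ROC AUC for each class and then takes their unweighted mean. It is not the TRUST implementation: the function names are hypothetical, output is assumed to hold one score per class for every sample, and the AUC is obtained from the rank (Mann-Whitney) formulation in plain Python. The library's own macro average may be defined differently, for example by averaging interpolated curves.

```python
def one_vs_rest_auc(target, output, n_classes):
    """Return {class index: ROC AUC}, treating each class as positive in turn."""
    auc_by_class = {}
    for c in range(n_classes):
        pos = [scores[c] for scores, t in zip(output, target) if t == c]
        neg = [scores[c] for scores, t in zip(output, target) if t != c]
        if not pos or not neg:
            continue  # AUC is undefined without both positives and negatives
        # AUC equals the probability that a positive outranks a negative;
        # ties contribute one half.
        wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
        auc_by_class[c] = wins / (len(pos) * len(neg))
    return auc_by_class


def macro_average(auc_by_class):
    """Unweighted mean of the per-class AUC values."""
    return sum(auc_by_class.values()) / len(auc_by_class)


target = [0, 0, 1, 1, 2]
output = [
    [0.7, 0.2, 0.1],
    [0.4, 0.5, 0.1],
    [0.2, 0.6, 0.2],
    [0.3, 0.3, 0.4],
    [0.1, 0.2, 0.7],
]
per_class = one_vs_rest_auc(target, output, n_classes=3)
macro = macro_average(per_class)
```

In this toy data, classes 0 and 2 are perfectly separated, while one class-1 sample is outranked by a negative, so the class-1 AUC is 5/6 and the macro average is 17/18.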

Models

TRUST incorporates several neural network architectures, listed below:
  • densenet

  • dla

  • dla_simple

  • dpn

  • efficientnet

  • googlenet

  • lenet

  • mobilenet

  • mobilenetv2

  • pnasnet

  • preact_resnet

  • regnet

  • resnet

  • resnext

  • senet

  • shufflenet

  • shufflenetv2

  • vgg

To use a custom model architecture, ensure the model has the following:

  1. A forward method that accepts two additional boolean arguments:

    last –

      If true: returns the model output and the output of the second-to-last layer.

      If false: returns only the model output.

    freeze –

      If true: disables the tracking of any calculations required to later calculate a gradient, i.e., skips gradient calculation over the weights.

      If false: gradients are tracked as usual.

  2. A get_embedding_dim() method that returns the number of hidden units in the last layer.
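The requirements above can be sketched with a small PyTorch module. This is an illustrative example and not one of the bundled architectures: the class name TwoLayerNet, the layer sizes, and the False defaults for last and freeze are assumptions. It also assumes that freeze applies only to the layers before the final linear layer, so the classifier output still tracks gradients.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoLayerNet(nn.Module):
    """Hypothetical custom model exposing last, freeze and get_embedding_dim."""

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self.embedding_dim = hidden_dim

    def forward(self, x, last=False, freeze=False):
        if freeze:
            # Skip gradient tracking for the feature extractor (assumed scope).
            with torch.no_grad():
                emb = F.relu(self.fc1(x))
        else:
            emb = F.relu(self.fc1(x))
        out = self.fc2(emb)
        if last:
            # Return the logits together with the second-to-last layer output.
            return out, emb
        return out

    def get_embedding_dim(self):
        return self.embedding_dim


model = TwoLayerNet(input_dim=4, hidden_dim=8, num_classes=3)
x = torch.randn(2, 4)
logits = model(x)
logits_frozen, embedding = model(x, last=True, freeze=True)
```

Here logits has shape (2, 3), and embedding has shape (2, 8) and carries no gradient history because it was computed under torch.no_grad().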