Source code for trust.utils.custom_dataset

import numpy as np
import os
import torch
import torchvision
from sklearn import datasets
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import PIL.Image as Image
from .utils import *
np.random.seed(42)
torch.manual_seed(42)

[docs]class DataHandler_MNIST(Dataset): """ Data Handler to load MNIST dataset. This class extends :class:`torch.utils.data.Dataset` to handle loading data even without labels Parameters ---------- X: numpy array Data to be loaded y: numpy array, optional Labels to be loaded (default: None) select: bool True if loading data without labels, False otherwise """ def __init__(self, X, Y=None, select=True, use_test_transform=False): """ Constructor """ self.select = select self.use_test_transform=use_test_transform self.training_gen_transform = transforms.Compose([transforms.Resize((32, 32)), transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) self.test_gen_transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) if not self.select: self.X = X self.targets = Y else: self.X = X def __getitem__(self, index): if not self.select: x, y = self.X[index], self.targets[index] x = Image.fromarray(x) if self.use_test_transform: x = self.test_gen_transform(x) else: x = self.training_gen_transform(x) if(x.shape[0]==1): x = torch.repeat_interleave(x, 3, 0) y=y.long() return (x, y.long()) else: x = self.X[index] x = Image.fromarray(x) if self.use_test_transform: x = self.test_gen_transform(x) else: x = self.training_gen_transform(x) if(x.shape[0]==1): x = torch.repeat_interleave(x, 3, 0) return x def __len__(self): return len(self.X)
[docs]class DataHandler_CIFAR10(Dataset): """ Data Handler to load CIFAR10 dataset. This class extends :class:`torch.utils.data.Dataset` to handle loading data even without labels Parameters ---------- X: numpy array Data to be loaded y: numpy array, optional Labels to be loaded (default: None) select: bool True if loading data without labels, False otherwise """ def __init__(self, X, Y=None, select=True, use_test_transform = False): """ Constructor """ self.select = select if(use_test_transform): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) else: transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) if not self.select: self.X = X self.targets = Y self.transform = transform else: self.X = X self.transform = transform def __getitem__(self, index): if not self.select: x, y = self.X[index], self.targets[index] x = Image.fromarray(x) x = self.transform(x) return (x, y) else: x = self.X[index] x = Image.fromarray(x) x = self.transform(x) return x def __len__(self): return len(self.X)
[docs]class DataHandler_SVHN(Dataset): """ Data Handler to load SVHN dataset. This class extends :class:`torch.utils.data.Dataset` to handle loading data even without labels Parameters ---------- X: numpy array Data to be loaded y: numpy array, optional Labels to be loaded (default: None) select: bool True if loading data without labels, False otherwise """ def __init__(self, X, Y=None, select=True, use_test_transform=False): """ Constructor """ self.select = select self.use_test_transform=use_test_transform self.training_gen_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) self.test_gen_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # ImageNet mean/std if not self.select: self.X = X self.targets = Y else: self.X = X def __getitem__(self, index): if not self.select: x, y = self.X[index], self.targets[index] x = Image.fromarray(np.transpose(x, (1, 2, 0))) if self.use_test_transform: x = self.test_gen_transform(x) else: x = self.training_gen_transform(x) return (x, y) else: x = self.X[index] x = Image.fromarray(np.transpose(x, (1, 2, 0))) if self.use_test_transform: x = self.test_gen_transform(x) else: x = self.training_gen_transform(x) return x def __len__(self): return len(self.X)
[docs]class DataHandler_UTKFace(Dataset): """ Data Handler to load UTKFace dataset. This class extends :class:`torch.utils.data.Dataset` to handle loading data even without labels Parameters ---------- X: numpy array Data to be loaded y: numpy array, optional Labels to be loaded (default: None) select: bool True if loading data without labels, False otherwise """ def __init__(self, X, Y=None, select=True, use_test_transform = False): """ Constructor """ self.select = select if(use_test_transform): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # ImageNet mean/std else: transform = transforms.Compose([transforms.RandomCrop(200, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # ImageNet mean/std if not self.select: self.X = X self.targets = Y self.transform = transform else: self.X = X self.transform = transform def __getitem__(self, index): if not self.select: x, y = self.X[index], self.targets[index] x = Image.fromarray(np.transpose(x, (1,2,0))) x = self.transform(x) return (x, y) else: x = self.X[index] x = Image.fromarray(x) x = self.transform(x) return x def __len__(self): return len(self.X)
class DuplicateChannels(object): """ Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8 In the other cases, tensors are returned without scaling. """ def __call__(self, pic): """ Args: pic (PIL Image or numpy.ndarray): Image to be converted to tensor. Returns: Tensor: Converted image. """ return torch.repeat_interleave(pic.unsqueeze(1), 3, 1).float() def __repr__(self): return self.__class__.__name__ + '()' def getOODtargets(targets, sel_cls_idx, ood_cls_id): ood_targets = [] targets_list = list(targets) for i in range(len(targets_list)): if(targets_list[i] in list(sel_cls_idx)): ood_targets.append(targets_list[i]) else: ood_targets.append(ood_cls_id) print("num ood samples: ", ood_targets.count(ood_cls_id)) return torch.Tensor(ood_targets) def create_ood_data(dset_name, fullset, testset, split_cfg, num_cls, augVal): np.random.seed(42) train_idx = [] val_idx = [] lake_idx = [] test_idx = [] selected_classes = np.array(list(range(split_cfg['num_cls_idc']))) for i in range(num_cls): #all_classes if(dset_name=="mnist"): full_idx_class = list(torch.where(torch.Tensor(fullset.targets.float()) == i)[0].cpu().numpy()) else: full_idx_class = list(torch.where(torch.Tensor(fullset.targets) == i)[0].cpu().numpy()) if(i in selected_classes): if(dset_name=="mnist"): test_idx_class = list(torch.where(torch.Tensor(testset.targets.float()) == i)[0].cpu().numpy()) else: test_idx_class = list(torch.where(torch.Tensor(testset.targets) == i)[0].cpu().numpy()) test_idx += test_idx_class class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_idc_train'], replace=False)) train_idx += class_train_idx remain_idx = list(set(full_idx_class) - set(class_train_idx)) class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_idc_val'], replace=False)) remain_idx = list(set(remain_idx) - set(class_val_idx)) class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_idc_lake'], replace=False)) else: class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_ood_train'], replace=False)) #always 0 remain_idx = list(set(full_idx_class) - set(class_train_idx)) class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_ood_val'], replace=False)) #Only for CG ood val has samples remain_idx = list(set(remain_idx) - set(class_val_idx)) class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_ood_lake'], replace=False)) #many ood samples in lake if(augVal and (i in selected_classes)): #augment with samples only from the imbalanced classes train_idx += class_val_idx val_idx += class_val_idx lake_idx += class_lake_idx if(dset_name=="mnist"): train_set = SubsetWithTargetsSingleChannel(fullset, train_idx, torch.Tensor(fullset.targets.float())[train_idx]) val_set = SubsetWithTargetsSingleChannel(fullset, val_idx, torch.Tensor(fullset.targets.float())[val_idx]) lake_set = SubsetWithTargetsSingleChannel(fullset, lake_idx, getOODtargets(torch.Tensor(fullset.targets.float())[lake_idx], selected_classes, split_cfg['num_cls_idc'])) test_set = SubsetWithTargetsSingleChannel(testset, test_idx, torch.Tensor(testset.targets.float())[test_idx]) else: train_set = SubsetWithTargets(fullset, train_idx, torch.Tensor(fullset.targets)[train_idx]) val_set = SubsetWithTargets(fullset, val_idx, torch.Tensor(fullset.targets)[val_idx]) lake_set = SubsetWithTargets(fullset, lake_idx, getOODtargets(torch.Tensor(fullset.targets)[lake_idx], selected_classes, split_cfg['num_cls_idc'])) test_set = SubsetWithTargets(testset, test_idx, torch.Tensor(testset.targets)[test_idx]) return train_set, val_set, test_set, lake_set, selected_classes def create_class_imb(dset_name, fullset, split_cfg, num_cls, augVal): np.random.seed(42) train_idx = [] val_idx = [] lake_idx = [] if(dset_name=="mnist"): selected_classes=np.array([5,8]) else: selected_classes = np.random.choice(np.arange(num_cls), size=split_cfg['num_cls_imbalance'], replace=False) #classes to imbalance for i in range(num_cls): #all_classes if(dset_name=="mnist"): full_idx_class = list(torch.where(torch.Tensor(fullset.targets.float()) == i)[0].cpu().numpy()) elif(dset_name=="svhn"): full_idx_class = list(torch.where(torch.Tensor(fullset.labels) == i)[0].cpu().numpy()) else: full_idx_class = list(torch.where(torch.Tensor(fullset.targets) == i)[0].cpu().numpy()) if(i in selected_classes): class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_imbclass_train'], replace=False)) remain_idx = list(set(full_idx_class) - set(class_train_idx)) class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_imbclass_val'], replace=False)) remain_idx = list(set(remain_idx) - set(class_val_idx)) class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_imbclass_lake'], replace=False)) else: class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_class_train'], replace=False)) remain_idx = list(set(full_idx_class) - set(class_train_idx)) class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_class_val'], replace=False)) remain_idx = list(set(remain_idx) - set(class_val_idx)) class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_class_lake'], replace=False)) train_idx += class_train_idx if(augVal and (i in selected_classes)): #augment with samples only from the imbalanced classes train_idx += class_val_idx val_idx += class_val_idx lake_idx += class_lake_idx if(dset_name=="mnist"): train_set = SubsetWithTargetsSingleChannel(fullset, train_idx, torch.Tensor(fullset.targets.float())[train_idx]) val_set = SubsetWithTargetsSingleChannel(fullset, val_idx, torch.Tensor(fullset.targets.float())[val_idx]) lake_set = SubsetWithTargetsSingleChannel(fullset, lake_idx, torch.Tensor(fullset.targets.float())[lake_idx]) elif(dset_name=="svhn"): train_set = SubsetWithTargets(fullset, train_idx, torch.Tensor(fullset.labels)[train_idx]) val_set = SubsetWithTargets(fullset, val_idx, torch.Tensor(fullset.labels)[val_idx]) lake_set = SubsetWithTargets(fullset, lake_idx, torch.Tensor(fullset.labels)[lake_idx]) else: train_set = SubsetWithTargets(fullset, train_idx, torch.Tensor(fullset.targets)[train_idx]) val_set = SubsetWithTargets(fullset, val_idx, torch.Tensor(fullset.targets)[val_idx]) lake_set = SubsetWithTargets(fullset, lake_idx, torch.Tensor(fullset.targets)[lake_idx]) return train_set, val_set, lake_set, selected_classes def getDuplicateData(dset_name, fullset, split_cfg): num_rep=split_cfg['num_rep'] if(dset_name=="mnist"): # X = np.resize(fullset.data.float().cpu().numpy(), (len(fullset),32,32)) X = fullset.data.numpy() y = torch.from_numpy(np.array(fullset.targets.float())) elif(dset_name == "svhn"): X = fullset.data y = torch.from_numpy(np.array(fullset.labels)) else: X = fullset.data y = torch.from_numpy(np.array(fullset.targets)) X_tr = X[:split_cfg['train_size']] y_tr = y[:split_cfg['train_size']] X_unlabeled = X[split_cfg['train_size']:len(X)-split_cfg['val_size']] y_unlabeled = y[split_cfg['train_size']:len(X)-split_cfg['val_size']] X_val = X[len(X)-split_cfg['val_size']:] y_val = y[len(X)-split_cfg['val_size']:] X_unlabeled_rep = np.repeat(X_unlabeled[:split_cfg['lake_subset_repeat_size']], num_rep, axis=0) y_unlabeled_rep = np.repeat(y_unlabeled[:split_cfg['lake_subset_repeat_size']], num_rep, axis=0) assert((X_unlabeled_rep[0]==X_unlabeled_rep[num_rep-1]).all()) assert((y_unlabeled_rep[0]==y_unlabeled_rep[num_rep-1]).all()) X_unlabeled_rep = np.concatenate((X_unlabeled_rep, X_unlabeled[split_cfg['lake_subset_repeat_size']:split_cfg['lake_size']]), axis=0) y_unlabeled_rep = torch.from_numpy(np.concatenate((y_unlabeled_rep, y_unlabeled[split_cfg['lake_subset_repeat_size']:split_cfg['lake_size']]), axis=0)) if(dset_name=="mnist"): train_set = DataHandler_MNIST(X_tr, y_tr, False) lake_set = DataHandler_MNIST(X_unlabeled_rep, y_unlabeled_rep, False) val_set = DataHandler_MNIST(X_val, y_val, False) elif(dset_name=="svhn"): train_set = DataHandler_SVHN(X_tr, y_tr, False) lake_set = DataHandler_SVHN(X_unlabeled_rep, y_unlabeled_rep, False) val_set = DataHandler_SVHN(X_val, y_val, False) else: train_set = DataHandler_CIFAR10(X_tr, y_tr, False) lake_set = DataHandler_CIFAR10(X_unlabeled_rep, y_unlabeled_rep, False) val_set = DataHandler_CIFAR10(X_val, y_val, False) return X_tr, y_tr, X_val, y_val, X_unlabeled_rep, y_unlabeled_rep, train_set, val_set, lake_set def getVanillaData(dset_name, fullset, split_cfg): if(dset_name=="mnist"): # X = np.resize(fullset.data.float().cpu().numpy(), (len(fullset),32,32)) X = fullset.data.numpy() y = torch.from_numpy(np.array(fullset.targets.float())) else: X = fullset.data y = torch.from_numpy(np.array(fullset.targets)) X_tr = X[:split_cfg['train_size']] y_tr = y[:split_cfg['train_size']] X_unlabeled = X[split_cfg['train_size']:len(X)-split_cfg['val_size']] y_unlabeled = y[split_cfg['train_size']:len(X)-split_cfg['val_size']] X_val = X[len(X)-split_cfg['val_size']:] y_val = y[len(X)-split_cfg['val_size']:] if(dset_name=="mnist"): train_set = DataHandler_MNIST(X_tr, y_tr, False) lake_set = DataHandler_MNIST(X_unlabeled, y_unlabeled, False) val_set = DataHandler_MNIST(X_val, y_val, False) else: train_set = DataHandler_CIFAR10(X_tr, y_tr, False) lake_set = DataHandler_CIFAR10(X_unlabeled, y_unlabeled, False) val_set = DataHandler_CIFAR10(X_val, y_val, False) return X_tr, y_tr, X_val, y_val, X_unlabeled[:split_cfg['lake_size']], y_unlabeled[:split_cfg['lake_size']], train_set, val_set, lake_set def create_perclass_imb(dset_name, fullset, split_cfg, num_cls, augVal): np.random.seed(42) train_idx = [] val_idx = [] lake_idx = [] selected_classes=split_cfg['sel_cls_idx'] for i in range(num_cls): #all_classes full_idx_class = list(torch.where(torch.Tensor(fullset.targets) == i)[0].cpu().numpy()) class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_class_train'][i], replace=False)) remain_idx = list(set(full_idx_class) - set(class_train_idx)) class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_class_val'][i], replace=False)) remain_idx = list(set(remain_idx) - set(class_val_idx)) class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_class_lake'][i], replace=False)) train_idx += class_train_idx if(augVal and (i in selected_classes)): #augment with samples only from the imbalanced classes train_idx += class_val_idx val_idx += class_val_idx lake_idx += class_lake_idx train_set = SubsetWithTargets(fullset, train_idx, torch.Tensor(fullset.targets)[train_idx]) val_set = SubsetWithTargets(fullset, val_idx, torch.Tensor(fullset.targets)[val_idx]) lake_set = SubsetWithTargets(fullset, lake_idx, torch.Tensor(fullset.targets)[lake_idx]) return train_set, val_set, lake_set, selected_classes
[docs]def load_dataset_custom(datadir, dset_name, feature, split_cfg, augVal=False, dataAug=True): """ Loads a common dataset with additional options to create class imbalances, out-of-distribution classes, and redundancies. Parameters ---------- datadir : string The root directory in which the data is stored (or should be downloaded) dset_name : string The name of the dataset. This should be one of 'cifar10', 'mnist', 'svhn', 'cifar100', 'breast-density'. feature : string The modification that should be applied to the dataset. This should be one of 'classimb', 'ood', 'duplicate', 'vanilla' split_cfg : dict Contains information relating to the dataset splits that should be created. Some of the keys for this dictionary are as follows: 'per_imbclass_train': int The number of examples in the train set for each imbalanced class (classimb) 'per_imbclass_val': int The number of examples in the validation set for each imbalanced class (classimb) 'per_imbclass_lake': int The number of examples in the lake set for each imbalanced class (classimb) 'per_class_train': int The number of examples in the train set for each balanced class (classimb) 'per_class_val': int The number of examples in the validation set for each balanced class (classimb) 'per_class_lake': int The number of examples in the lake set for each balanced class (classimb) 'sel_cls_idx': list A list of classes that are affected by class imbalance. (classimb) 'train_size': int The size of the train set (vanilla, duplicate) 'val_size': int The size of the validation set (vanilla, duplicate) 'lake_size': int The size of the lake set (vanilla, duplicate) 'num_rep': int The number of times to repeat a selection in the lake set (duplicate) 'lake_subset_repeat_size': int The size of the repeated selection in the lake set (duplicate) 'num_cls_imbalance': int The number of classes to randomly affect by class imbalance. (classimb) 'num_cls_idc': int The number of in-distribution classes to keep (ood) 'per_idc_train': int The number of in-distribution examples to keep in the train set per class (ood) 'per_idc_val': int The number of in-distribution examples to keep in the validation set per class (ood) 'per_idc_lake': int The number of in-distribution examples to keep in the lake set per class (ood) 'per_ood_train': int The number of OOD examples to keep in the train set per class (ood) 'per_ood_val': int The number of OOD examples to keep in the validation set per class (ood) 'per_ood_lake': int The number of OOD examples to keep in the lake set per class (ood) augVal : bool, optional If True, the train set will also contain affected classes from the validation set. The default is False. dataAug : bool, optional If True, the all but the test set will be affected by random cropping and random horizontal flip. The default is True. Returns ------- tuple Returns a train set, validation set, test set, lake set, and number of classes. Amount of returned items depends on specific configuration. Each set is an instance of torch.utils.data.Dataset """ if(not(os.path.exists(datadir))): os.mkdir(datadir) if(dset_name=="cifar10"): num_cls=10 cifar_test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) if(dataAug): cifar_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) else: cifar_transform = cifar_test_transform fullset = torchvision.datasets.CIFAR10(root=datadir, train=True, download=True, transform=cifar_transform) test_set = torchvision.datasets.CIFAR10(root=datadir, train=False, download=True, transform=cifar_test_transform) if(feature=="classimb"): if("sel_cls_idx" in split_cfg): train_set, val_set, lake_set, imb_cls_idx = create_perclass_imb(dset_name, fullset, split_cfg, num_cls, augVal) else: train_set, val_set, lake_set, imb_cls_idx = create_class_imb(dset_name, fullset, split_cfg, num_cls, augVal) print("CIFAR-10 Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, imb_cls_idx, num_cls if(feature=="ood"): train_set, val_set, test_set, lake_set, ood_cls_idx = create_ood_data(dset_name, fullset, test_set, split_cfg, num_cls, augVal) print("CIFAR-10 Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set), "Test set: ", len(test_set)) return train_set, val_set, test_set, lake_set, ood_cls_idx, split_cfg['num_cls_idc'] if(feature=="vanilla"): X_tr, y_tr, X_val, y_val, X_unlabeled, y_unlabeled, train_set, val_set, lake_set = getVanillaData(dset_name, fullset, split_cfg) print("CIFAR-10 Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, num_cls if(feature=="duplicate"): X_tr, y_tr, X_val, y_val, X_unlabeled_rep, y_unlabeled_rep, train_set, val_set, lake_set = getDuplicateData(dset_name, fullset, split_cfg) print("CIFAR-10 Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, num_cls if(dset_name=="mnist"): num_cls=10 mnist_test_transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) if(dataAug): mnist_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) else: mnist_transform = mnist_test_transform fullset = torchvision.datasets.MNIST(root=datadir, train=True, download=True, transform=mnist_transform) test_set = torchvision.datasets.MNIST(root=datadir, train=False, download=True, transform=mnist_test_transform) # fullset.data = torch.repeat_interleave(fullset.data.unsqueeze(1), 3, 1).float() if(feature=="classimb"): if("sel_cls_idx" in split_cfg): train_set, val_set, lake_set, imb_cls_idx = create_perclass_imb(dset_name, fullset, split_cfg, num_cls, augVal) else: train_set, val_set, lake_set, imb_cls_idx = create_class_imb(dset_name, fullset, split_cfg, num_cls, augVal) print("MNIST Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, imb_cls_idx, num_cls if(feature=="ood"): train_set, val_set, test_set, lake_set, ood_cls_idx = create_ood_data(dset_name, fullset, test_set, split_cfg, num_cls, augVal) print("MNIST Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set), "Test set: ", len(test_set)) return train_set, val_set, test_set, lake_set, ood_cls_idx, split_cfg['num_cls_idc'] if(feature=="vanilla"): X_tr, y_tr, X_val, y_val, X_unlabeled, y_unlabeled, train_set, val_set, lake_set = getVanillaData(dset_name, fullset, split_cfg) print("MNIST Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, num_cls if(feature=="duplicate"): X_tr, y_tr, X_val, y_val, X_unlabeled_rep, y_unlabeled_rep, train_set, val_set, lake_set = getDuplicateData(dset_name, fullset, split_cfg) print("MNIST Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, num_cls if(dset_name=="svhn"): num_cls=10 SVHN_test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) if(dataAug): SVHN_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) else: SVHN_transform = SVHN_test_transform fullset = torchvision.datasets.SVHN(root=datadir, split="train", download=True, transform=SVHN_transform) test_set = torchvision.datasets.SVHN(root=datadir, split="test", download=True, transform=SVHN_test_transform) if(feature=="classimb"): if("sel_cls_idx" in split_cfg): train_set, val_set, lake_set, imb_cls_idx = create_perclass_imb(dset_name, fullset, split_cfg, num_cls, augVal) else: train_set, val_set, lake_set, imb_cls_idx = create_class_imb(dset_name, fullset, split_cfg, num_cls, augVal) print("SVHN Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, imb_cls_idx, num_cls if(feature=="ood"): train_set, val_set, test_set, lake_set, ood_cls_idx = create_ood_data(dset_name, fullset, test_set, split_cfg, num_cls, augVal) print("SVHN Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set), "Test set: ", len(test_set)) return train_set, val_set, test_set, lake_set, ood_cls_idx, split_cfg['num_cls_idc'] if(feature=="vanilla"): X_tr, y_tr, X_val, y_val, X_unlabeled, y_unlabeled, train_set, val_set, lake_set = getVanillaData(dset_name, fullset, split_cfg) print("SVHN Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, num_cls if(feature=="duplicate"): X_tr, y_tr, X_val, y_val, X_unlabeled_rep, y_unlabeled_rep, train_set, val_set, lake_set = getDuplicateData(dset_name, fullset, split_cfg) print("SVHN Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, num_cls if(dset_name=="cifar100"): num_cls=100 cifar100_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) fullset = torchvision.datasets.CIFAR100(root=datadir, train=True, download=True, transform=cifar100_transform) test_set = torchvision.datasets.CIFAR100(root=datadir, train=False, download=True, transform=cifar100_transform) if(feature=="classimb"): if("sel_cls_idx" in split_cfg): train_set, val_set, lake_set, imb_cls_idx = create_perclass_imb(dset_name, fullset, split_cfg, num_cls, augVal) else: train_set, val_set, lake_set, imb_cls_idx = create_class_imb(dset_name, fullset, split_cfg, num_cls, augVal) print("CIFAR-100 Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, imb_cls_idx, num_cls if(feature=="ood"): train_set, val_set, test_set, lake_set, ood_cls_idx = create_ood_data(dset_name, fullset, test_set, split_cfg, num_cls, augVal) print("CIFAR-100 Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set), "Test set: ", len(test_set)) return train_set, val_set, test_set, lake_set, ood_cls_idx, num_cls if(feature=="vanilla"): X_tr, y_tr, X_val, y_val, X_unlabeled, y_unlabeled, train_set, val_set, lake_set = getVanillaData(dset_name, fullset, split_cfg) print("CIFAR-100 Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, num_cls if(feature=="duplicate"): X_tr, y_tr, X_val, y_val, X_unlabeled_rep, y_unlabeled_rep, train_set, val_set, lake_set = getDuplicateData(dset_name, fullset, split_cfg) print("CIFAR-100 Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, num_cls if(dset_name=="breast_density"): num_cls=4 input_size=224 data_transforms = { 'train': transforms.Compose([ transforms.RandomResizedCrop(input_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), 'test': transforms.Compose([ transforms.Resize(input_size), transforms.CenterCrop(input_size), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), } fullset = datasets.ImageFolder(os.path.join(datadir, 'train'), data_transforms['train']) test_set = datasets.ImageFolder(os.path.join(datadir, 'test'), data_transforms['test']) if(feature=="classimb"): train_set, val_set, lake_set, imb_cls_idx = create_perclass_imb(dset_name, fullset, split_cfg, num_cls, augVal) print("Breast-density Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set)) return train_set, val_set, test_set, lake_set, imb_cls_idx, num_cls