TorchUtils

Some handy utilities for PyTorch


device_by_name

 device_by_name (name:str)

Return a reference to a CUDA device, found by part of its name.

Args: name: part of the CUDA device name (should be distinct)

Returns: reference to the CUDA device

Updated: Yuval 12/10/19

How to use

device_by_name("Tesla")
device(type='cuda', index=0)

If no CUDA device name contains the given string, an AssertionError is raised:

error = False
try:
    device_by_name("fff")
except AssertionError:
    error = True
assert error
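The lookup can be sketched as below, assuming it scans the name of every visible CUDA device and asserts that exactly one of them contains the given string. This is a hypothetical re-implementation, not the library's actual code: the helper names `match_device_index` and `device_by_name_sketch`, and the "exactly one match" rule, are assumptions. The matching logic lives in a torch-free helper so it can be checked without a GPU.

```python
from typing import Sequence


def match_device_index(name: str, device_names: Sequence[str]) -> int:
    """Return the index of the single device whose full name contains `name`.

    Hypothetical helper; assumes the partial name must be unambiguous.
    """
    matches = [i for i, full in enumerate(device_names) if name in full]
    # Assumption: zero or several matches are both treated as errors.
    assert len(matches) == 1, (
        f"expected exactly one CUDA device matching {name!r}, found {len(matches)}"
    )
    return matches[0]


def device_by_name_sketch(name: str):
    """Hypothetical torch wrapper around match_device_index."""
    import torch  # imported lazily so the matching helper works without torch

    names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
    return torch.device("cuda", match_device_index(name, names))
```

On a machine whose first GPU is a Tesla, `device_by_name_sketch("Tesla")` would return `device(type='cuda', index=0)`, mirroring the example above.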


DatasetCat

 DatasetCat (*datasets)

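Concatenate datasets column-wise for use with a PyTorch DataLoader.

PyTorch's built-in torch.utils.data.ConcatDataset concatenates only by rows (it appends the samples of one dataset after those of another). This is a “column” implementation: item i of the result joins item i of every input dataset.

Args: datasets: the datasets to join, passed as separate arguments and all of the same length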

Updated: Yuval 12/10/2019

How to use

First, a dataset with five samples:

import torch

dataset1 = torch.utils.data.TensorDataset(torch.ones(5, 1), torch.randn(5, 1))
print(len(dataset1))
print(dataset1[0])
5
(tensor([1.]), tensor([-1.2270]))

Then a second dataset of the same length:

dataset2 = torch.utils.data.TensorDataset(torch.zeros(5, 1), torch.randn(5, 1))
print(len(dataset2))
print(dataset2[0])
5
(tensor([0.]), tensor([1.0632]))

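Now we concatenate them column-wise: each item of the result joins the corresponding items of both datasets, and the length stays the same.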

dataset3 = DatasetCat(dataset1, dataset2)
print(len(dataset3))
print(dataset3[0])
# item i of dataset3 is item i of dataset1 followed by item i of dataset2
assert dataset3[3] == (*dataset1[3], *dataset2[3])
assert len(dataset3) == len(dataset1)
5
(tensor([1.]), tensor([-1.2270]), tensor([0.]), tensor([1.0632]))
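A minimal sketch of the column-wise behaviour shown above, assuming the class checks that all inputs have the same length and flattens tuple items (as returned by TensorDataset) into one tuple. `DatasetCatSketch` is a hypothetical name, not the library's code; it avoids a torch import so it can run anywhere, whereas in practice you would likely subclass torch.utils.data.Dataset.

```python
from typing import Any, Sequence, Tuple


class DatasetCatSketch:
    """Column-wise concatenation: item i joins item i of every dataset.

    Hypothetical re-implementation for illustration only.
    """

    def __init__(self, *datasets: Sequence[Any]) -> None:
        assert datasets, "need at least one dataset"
        lengths = {len(d) for d in datasets}
        assert len(lengths) == 1, f"datasets must have the same length, got {sorted(lengths)}"
        self.datasets = datasets

    def __len__(self) -> int:
        # All datasets share one length, so the first one is representative.
        return len(self.datasets[0])

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        out: Tuple[Any, ...] = ()
        for d in self.datasets:
            item = d[index]
            # Assumption: tuple items are spliced in; any other item is one column.
            out += item if isinstance(item, tuple) else (item,)
        return out
```

With the datasets above, `DatasetCatSketch(dataset1, dataset2)[0]` would give the same four-tensor tuple as `dataset3[0]`.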