Pitfalls/gotchas
nn.CrossEntropyLoss in PyTorch performs two operations: nn.LogSoftmax followed by nn.NLLLoss.
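A quick check of this equivalence (a minimal sketch; shapes are arbitrary):
import torch
import torch.nn as nn
logits = torch.randn(4, 10)            # (batch, classes)
targets = torch.randint(0, 10, (4,))   # class indices
a = nn.CrossEntropyLoss()(logits, targets)
b = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
torch.allclose(a, b)  # True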
Hooks
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.cl1 = nn.Linear(25, 60)
        self.cl2 = nn.Linear(60, 16)
        self.fc1 = nn.Linear(16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.cl1(x))
        x = F.relu(self.cl2(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x
activation = {}

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

model = MyModel()
model.fc2.register_forward_hook(get_activation('fc2'))
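A forward pass then fills the dict (a minimal sketch; the input size 25 matches cl1 above, the batch size is arbitrary):
out = model(torch.randn(1, 25))
activation['fc2'].shape  # torch.Size([1, 84])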
Tensors
Dims
# implicitly selects : for remaining unspecified dims
torch.empty(3,4,5)[:,3].shape # 3,5
t[None] # insert a dimension at beginning
torch.empty(3,4)[:,None,:].shape # torch.Size([3, 1, 4])
torch.empty(3,4,5).t() # fails!
# permute is generalization of .t(), expects all dims listed
torch.empty(3,4,5).permute(2,0,1).shape # 5,3,4
# movedim is similar
torch.movedim(torch.empty(1,2,3,4),(1,0),(2,1)).shape # 3,1,2,4
torch.movedim(torch.empty(1,2,3,4),(1,2),(0,1)).shape # 2,3,1,4
torch.movedim(torch.empty(1,2,3,4),(0,1),(1,0)).shape # 2,1,3,4
torch.movedim(torch.empty(1,2,3,4),(0,2),(2,1)).shape # 2,3,1,4
# unsqueeze same as [None], and squeeze only applies to dims of 1
torch.empty(3,4).unsqueeze(0).shape # 1,3,4
torch.empty(3,4).unsqueeze(1).shape # 3,1,4
torch.empty(3,4).squeeze(0).shape # 3,4
torch.empty(1,3,4).squeeze(0).shape # 3,4
torch.empty(3,1,4).squeeze(1).shape # 3,4
Broadcasting rules
# align shapes from the trailing end; each pair of dims must be equal, or one of them 1 (or missing)
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( 3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
>>> (torch.ones(2,1,4,5) @ torch.ones(3,5,6)).shape
torch.Size([2, 3, 4, 6])
Reshaping
# Returns a tensor with the same data and number of elements as self but with the specified shape. This method returns a view if shape is compatible with the current shape. See torch.Tensor.view() on when it is possible to return a view.
>>> torch.zeros(2,3,4).reshape(2,12).shape
torch.Size([2, 12])
>>> torch.zeros(2,3,4).reshape(3,8).shape
torch.Size([3, 8])
>>> torch.zeros(2,3,4).reshape(4,6).shape
torch.Size([4, 6])
>>> torch.zeros(2,3,4).reshape(3,-1).shape
torch.Size([3, 8])
>>> torch.zeros(2,3,4).reshape(4,-1).shape
torch.Size([4, 6])
# -1 means infer remaining dimensions
>>> torch.zeros(2,3,4).reshape(-1).shape
torch.Size([24])
>>> torch.zeros(2,3,4).reshape(2,-1).shape
torch.Size([2, 12])
>>> torch.zeros(2,3,4).reshape(3,-1).shape
torch.Size([3, 8])
>>> torch.zeros(2,3,4).view(-1,3).shape
torch.Size([8, 3])
>>> torch.zeros(2,3,4).reshape(2,1,3,4,1).shape
torch.Size([2, 1, 3, 4, 1])
# .view is similar but never copies: it fails if a view isn't possible (e.g. non-contiguous input), whereas .reshape falls back to copying
torch.empty(3,4).view(3,1,4).shape # 3,1,4
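For example, a transposed tensor is non-contiguous, so .reshape still works (by copying) while .view does not (a small sketch):
torch.zeros(3,4).t().reshape(12).shape # 12 (data is copied)
torch.zeros(3,4).t().view(12) # fails! RuntimeError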
Reducing
>>> torch.ones(3,4,5).sum()
tensor(60.)
>>> torch.ones(3,4,5).sum(dim=0).shape
torch.Size([4, 5])
>>> torch.ones(3,4,5).sum(dim=1).shape
torch.Size([3, 5])
>>> torch.ones(3,4,5).sum(dim=2).shape
torch.Size([3, 4])
>>> torch.ones(3,4,5).sum(0,keepdim=True).shape
torch.Size([1, 4, 5])
>>> torch.ones(3,4,5).sum([0,1],keepdim=True).shape
torch.Size([1, 1, 5])
Combining
torch.cat([torch.empty(3,4), torch.empty(3,4)]).shape # 6,4
torch.cat([torch.empty(3,4), torch.empty(3,4)], dim=0).shape # 6,4
torch.cat([torch.empty(3,4), torch.empty(3,4)], dim=1).shape # 3,8
# stack adds dimension
# this `dim` specifies where to insert the new dimension
torch.stack([torch.empty(3,4), torch.empty(3,4)]).shape # 2,3,4
torch.stack([torch.empty(3,4), torch.empty(3,4)], dim=1).shape # 3,2,4
torch.stack([torch.empty(3,4), torch.empty(3,4)], dim=2).shape # 3,4,2
Splitting
>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
>>> [tuple(x.shape) for x in torch.zeros(6).chunk(2)]
[(3,), (3,)]
>>> [tuple(x.shape) for x in torch.zeros(6,5).chunk(2)]
[(3, 5), (3, 5)]
>>> [tuple(x.shape) for x in torch.zeros(5,5).chunk(2)]
[(3, 5), (2, 5)]
>>> ip_tensor=torch.tensor([[1,2,3],[4,5,6]])
>>> torch.unbind(ip_tensor,dim=0)
(tensor([1, 2, 3]), tensor([4, 5, 6]))
>>> torch.unbind(ip_tensor,dim=1)
(tensor([1, 4]), tensor([2, 5]), tensor([3, 6]))
Resources
Gather and scatter are not broadcastable: the index tensor must have the same number of dimensions as the input, so you need to make the dimensions align first. Maybe this is for safety.
Gather: read/index along a certain dimension
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Example: index along a specific dimension, i.e. pick one column per row, e.g. to get the logit/score of the correct word/class per example when calculating cross-entropy loss
# similar task/expression that requires hardcoding the shape:
# cross_entropy_loss = -logprobs[range(n), Yb].mean()
num_examples = 5
num_classes = 3
scores = torch.tensor([
[ 1.5410, -0.2934, -2.1788],
[ 0.5684, -1.0845, -1.3986],
[ 0.4033, 0.8380, -0.7193],
[-0.4033, -0.5966, 0.1820],
[-0.8567, 1.1006, -1.0712]])
y = torch.LongTensor([1, 2, 1, 0, 2])
res = scores.gather(1, y.view(-1, 1)).squeeze()
# tensor([-0.2934, -1.3986, 0.8380, -0.4033, -1.0712])
Example: switch between lo and hi in stochastic rounding
def sround(x, p):  # p = precision (bits after the binary point)
    lo = torch.floor(x * (2**p)) / 2**p
    hi = torch.ceil(x * (2**p)) / 2**p
    plo = (hi - x) / (hi - lo)          # probability of rounding down
    islo = torch.rand(x.shape) < plo
    a = torch.stack((hi, lo))           # a[0] = hi, a[1] = lo
    return a.gather(0, islo.to(torch.int64)[None, ...]).squeeze(0)
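A quick sanity check (my addition): stochastic rounding is unbiased in expectation, so the mean over many samples recovers x:
x = torch.full((100000,), 0.3)
sround(x, 2)         # each element is 0.25 or 0.5
sround(x, 2).mean()  # ~0.3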
Scatter: write/place along a certain dimension
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
Example: reverse / backprop the above flows, where you want to scatter some gradient back to the corresponding one-hot element it came from, e.g. for cross-entropy loss, or for softmax's normalization (logits.max(1))
# forward:
logit_maxes = logits.max(1, keepdim=True).values
# backward: scatter dlogit_maxes back to the argmax positions
dlogits = torch.zeros_like(logits).scatter_(1, logits.max(1, keepdim=True).indices, dlogit_maxes)
# alt: dlogits = F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
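A tiny self-contained check that the two forms agree (random stand-ins for logits and dlogit_maxes; F is torch.nn.functional):
logits = torch.randn(4, 5)
dlogit_maxes = torch.randn(4, 1)
a = torch.zeros_like(logits).scatter_(1, logits.max(1, keepdim=True).indices, dlogit_maxes)
b = F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
torch.allclose(a, b)  # True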
Example: in backprop, placing the gradient back into the embedding matrix
# forward:
emb = C[Xb] # embed the characters into vectors
# backward:
dC = torch.zeros_like(C)
for i in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        dC[Xb[i,j]] += demb[i,j]
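The same accumulation can be written without the Python loops using index_add_ (a sketch, assuming demb has shape (*Xb.shape, emb_dim) and C has shape (vocab, emb_dim)):
dC = torch.zeros_like(C)
dC.index_add_(0, Xb.view(-1), demb.view(-1, C.shape[1]))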
Convolutions
>>> import torch.nn.functional as F
>>> from torch import Tensor as T   # shorthand used below (creates float tensors)
>>>
>>> ## basics
>>> F.conv1d(T([[0,1,2,3]]), T([[[1]]]))
tensor([[0., 1., 2., 3.]])
>>> F.conv1d(T([[[0,1,2,3]]]), T([[[1]]]))
tensor([[[0., 1., 2., 3.]]])
>>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]))
tensor([[3., 6.]])
>>>
>>> ## paddings
>>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]), padding=1)
tensor([[1., 3., 6., 5.]])
>>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]), padding=2)
tensor([[0., 1., 3., 6., 5., 3.]])
>>> # can actually be >= kernel size!
>>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]), padding=3)
tensor([[0., 0., 1., 3., 6., 5., 3., 0.]])
>>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]), padding=4)
tensor([[0., 0., 0., 1., 3., 6., 5., 3., 0., 0.]])
>>>
>>> ## multi-kernel: 2 out dims
>>> F.conv1d(T([[0,1,2,3]]), T([[[1,2,3]],[[1,1,1]]]))
tensor([[ 8., 14.],
[ 3., 6.]])
>>>
>>> ## multi-channel input 3xN - channel must be outer/left dim, and conv sums them
>>> F.conv1d(T([[0,1,2,3],[4,5,6,7],[8,9,10,11]]), T([[[1],[1],[1]]]))
tensor([[12., 15., 18., 21.]])
>>>
>>> ## combo: 3 in, 2 out, size-1 kernel
>>> F.conv1d(T([[0,1,2,3],[4,5,6,7],[8,9,10,11]]), T([[[1],[2],[3]],[[1],[1],[1]]]))
tensor([[32., 38., 44., 50.],
[12., 15., 18., 21.]])
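Stride and dilation follow the same pattern; a couple of extra examples I added for reference (same T shorthand as above, outputs worked out by hand):
>>> ## stride & dilation
>>> F.conv1d(T([[0,1,2,3,4,5]]), T([[[1,1,1]]]), stride=2)
tensor([[3., 9.]])
>>> F.conv1d(T([[0,1,2,3,4,5]]), T([[[1,1,1]]]), dilation=2)
tensor([[6., 9.]])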
Einsum
# copied from docs
>>> # trace - sum of diagonal
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)
>>> # diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034, 0.7952, -0.2433, 0.4545])
>>> # outer product
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y)
tensor([[ 0.1156, -0.2897, -0.3918, 0.4963],
[-0.3744, 0.9381, 1.2685, -1.6070],
[ 0.7208, -1.8058, -2.4419, 3.0936],
[ 0.1713, -0.4291, -0.5802, 0.7350],
[ 0.5704, -1.4290, -1.9323, 2.4480]])
>>> # batch matrix multiplication
>>> As = torch.randn(3, 2, 5)
>>> Bs = torch.randn(3, 5, 4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
>>> # with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
>>> # batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])
>>> # equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3, 5, 4)
>>> l = torch.randn(2, 5)
>>> r = torch.randn(2, 4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405, 0.4494],
[ 0.3311, 5.5201, -3.0356]])
Indexing
# base
>>> xs = torch.arange(6).view(2,3)
>>> xs
tensor([[0, 1, 2],
[3, 4, 5]])
# full index gets 1 element
>>> xs[0,0]
tensor(0)
# can use negative indexing
>>> xs[-1,-1]
tensor(5)
# partial index gets slice, drops dim(s)
>>> xs[0]
tensor([0, 1, 2])
>>> xs[0,:]
tensor([0, 1, 2])
>>> xs[0][0]
tensor(0)
>>> xs[:]
tensor([[0, 1, 2],
[3, 4, 5]])
# multiple indexes, can repeat, keeps dim
>>> xs[[0],:]
tensor([[0, 1, 2]])
>>> xs[[-1],:]
tensor([[3, 4, 5]])
>>> xs[[0,1,0,1]]
tensor([[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]])
# multiple indexes over multiple dimensions are *zipped*!
>>> xs[[0,1,0,1],[0,1,0,1]]
tensor([0, 4, 0, 4])
# above is different from this, which gives a subtensor
>>> xs[:,[0,1,0,1]]
tensor([[0, 1, 0, 1],
[3, 4, 3, 4]])
# can also select with bools
>>> xs[[False,True]]
tensor([[3, 4, 5]])
>>> xs[[False,True],[False,True,True]]
tensor([4, 5])
>>> xs[:,[False,True,True]]
tensor([[1, 2],
[4, 5]])
# selecting with tensors
>>> torch.arange(6).view(2,3)[torch.zeros(2,2).int()]
tensor([[[0, 1, 2],
[0, 1, 2]],
[[0, 1, 2],
[0, 1, 2]]])
Index assignment
>>> ys=torch.arange(6).view(2,3)
# can assign to a slice (note: ys[:] is a view of ys, so these writes also modify ys)
>>> xs=ys[:]; xs[:,0]=7; xs
tensor([[7, 1, 2],
[7, 4, 5]])
# partial slice implies : on remaining dims
>>> xs=ys[:]; xs[:]=7; xs
tensor([[7, 7, 7],
[7, 7, 7]])
# can broadcast
>>> xs=ys[:]; xs[:] = torch.arange(3).view(1,3); xs
tensor([[0, 1, 2],
[0, 1, 2]])
>>> xs=ys[:]; xs[:,:]=torch.arange(3).view(1,3); xs
tensor([[0, 1, 2],
[0, 1, 2]])
Autograd
Misc
Einops
rearrange
# suppose we have a set of 32 images in "h w c" format (height-width-channel)
>>> images = [np.random.randn(30, 40, 3) for _ in range(32)]
# stack along first (batch) axis, output is a single array
>>> rearrange(images, 'b h w c -> b h w c').shape
(32, 30, 40, 3)
# concatenate images along height (vertical axis), 960 = 32 * 30
>>> rearrange(images, 'b h w c -> (b h) w c').shape
(960, 40, 3)
# concatenated images along horizontal axis, 1280 = 32 * 40
>>> rearrange(images, 'b h w c -> h (b w) c').shape
(30, 1280, 3)
# reordered axes to "b c h w" format for deep learning
>>> rearrange(images, 'b h w c -> b c h w').shape
(32, 3, 30, 40)
# flattened each image into a vector, 3600 = 30 * 40 * 3
>>> rearrange(images, 'b h w c -> b (c h w)').shape
(32, 3600)
# split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
>>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
(128, 15, 20, 3)
# space-to-depth operation
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
(32, 15, 20, 12)
# use ellipsis
rearrange(attn_mask, '... -> ... 1')
repeat
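A couple of einops.repeat examples (my sketch, reusing the images from above and assuming from einops import repeat):
# tile along height, 60 = 2 * 30
>>> repeat(images[0], 'h w c -> (tile h) w c', tile=2).shape
(60, 40, 3)
# add and repeat a new batch axis
>>> repeat(images[0], 'h w c -> b h w c', b=4).shape
(4, 30, 40, 3)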
Distributed
DDP
RPC: supports sending remote object refs
Point-to-point send/recv
Distributed “collective”/group ops (like scatter/gather/all-reduce), implemented with backends such as NCCL
Built-in DDP and FSDP (ZeRO)
Autograd across RPC calls is handled (distributed autograd), as in the example below
import torch
import torch.distributed.rpc as rpc

def my_add(t1, t2):
    return torch.add(t1, t2)

# On worker 0 (assumes rpc.init_rpc(...) has already been called on each worker):
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
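To backprop through the RPC, the forward would normally run inside a distributed autograd context; a sketch using torch.distributed.autograd:
import torch.distributed.autograd as dist_autograd

with dist_autograd.context() as context_id:
    # ... run the forward computation above inside the context ...
    dist_autograd.backward(context_id, [loss])
    # per-worker gradients are available via dist_autograd.get_gradients(context_id)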
Activation checkpointing
Compiler
Build
Performance optimization
When evaluating model
torch.func: JAX-style higher-order function transforms (e.g. grad, vmap) that take a function (which runs PyTorch code) and return a transformed function
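A minimal sketch of two common transforms (grad and vmap):
import torch
from torch.func import grad, vmap

# grad: turns a scalar-valued function into one that returns the gradient
g = grad(lambda x: (x ** 2).sum())
g(torch.tensor([1.0, 2.0]))  # tensor([2., 4.])

# vmap: maps a function over a new batch dimension
vmap(torch.dot)(torch.randn(3, 5), torch.randn(3, 5)).shape  # torch.Size([3])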
Misc features
Ecosystem