• Pitfalls/gotchas

    • x.T only transposes as expected for 2-D tensors: it is a no-op on 1-D tensors, and for d > 2 it reverses all dims and is deprecated (see the snippet after this list)
    • The most common mistake is a mismatch between the loss function and the output activation function. The loss module nn.CrossEntropyLoss in PyTorch combines nn.LogSoftmax and nn.NLLLoss (see the NLL vs cross entropy example below)
    • Use BinaryAUROC not AUC in torcheval
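
    • Quick check of the .T behavior above (a minimal sketch; the exact warning/error for d > 2 depends on the PyTorch version):

      import torch

      print(torch.arange(3).T.shape)       # torch.Size([3]) - no-op on 1-D
      print(torch.empty(3, 4).T.shape)     # torch.Size([4, 3]) - the intended 2-D use
      print(torch.empty(2, 3, 4).T.shape)  # torch.Size([4, 3, 2]) - reverses dims, deprecated
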
  • Hooks

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.cl1 = nn.Linear(25, 60)
            self.cl2 = nn.Linear(60, 16)
            self.fc1 = nn.Linear(16, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
            
        def forward(self, x):
            x = F.relu(self.cl1(x))
            x = F.relu(self.cl2(x))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = F.log_softmax(self.fc3(x), dim=1)
            return x
    
    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook
    
    model = MyModel()
    model.fc2.register_forward_hook(get_activation('fc2'))
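
    A quick sanity check (runs one forward pass so the hook fires, then reads the captured fc2 output; the shape follows the layer sizes above):

    output = model(torch.randn(1, 25))
    print(activation['fc2'].shape)  # torch.Size([1, 84])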
    
  • Tensors

    • Dims

      # implicitly selects : for remaining unspecified dims
      torch.empty(3,4,5)[:,3].shape # 3,5
      
      t[None] # insert a dimension at beginning
      
      torch.empty(3,4)[:,None,:].shape # torch.Size([3, 1, 4])
      
      torch.empty(3,4,5).t() # fails!
      
      # permute is generalization of .t(), expects all dims listed
      torch.empty(3,4,5).permute(2,0,1).shape # 5,3,4
      
      # movedim is similar
      torch.movedim(torch.empty(1,2,3,4),(1,0),(2,1)).shape # 3,1,2,4
      torch.movedim(torch.empty(1,2,3,4),(1,2),(0,1)).shape # 2,3,1,4
      torch.movedim(torch.empty(1,2,3,4),(0,1),(1,0)).shape # 2,1,3,4
      torch.movedim(torch.empty(1,2,3,4),(0,2),(2,1)).shape # 2,3,1,4
      
      # unsqueeze(d) is the same as indexing with None at position d; squeeze only removes size-1 dims
      torch.empty(3,4).unsqueeze(0).shape # 1,3,4
      torch.empty(3,4).unsqueeze(1).shape # 3,1,4
      torch.empty(3,4).squeeze(0).shape # 3,4
      torch.empty(1,3,4).squeeze(0).shape # 3,4
      torch.empty(3,1,4).squeeze(1).shape # 3,4
      
    • Broadcasting rules

      • If the number of dimensions of x and y are not equal, prepend 1 to the dimensions of the tensor with fewer dimensions to make them equal length.
      • Then, for each dimension size, the resulting dimension size is the max of the sizes of x and y along that dimension.
      • Exception: matmul broadcasts only the batch (leading) dims; the last two dims follow matrix-multiplication shape rules (see the @ example below)
      # can line up trailing dimensions to make reading easier
      >>> x=torch.empty(5,1,4,1)
      >>> y=torch.empty(  3,1,1)
      >>> (x+y).size()
      torch.Size([5, 3, 4, 1])
      
      >>> x=torch.empty(5,2,4,1)
      >>> y=torch.empty(  3,1,1)
      >>> (x+y).size()
      RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
      
      >>> (torch.ones(2,1,4,5) @ torch.ones(3,5,6)).shape
      torch.Size([2, 3, 4, 6])
      
    • Reshaping

      # Returns a tensor with the same data and number of elements as self but with the specified shape. This method returns a view if shape is compatible with the current shape. See torch.Tensor.view() on when it is possible to return a view.
      
      >>> torch.zeros(2,3,4).reshape(2,12).shape
      torch.Size([2, 12])
      >>> torch.zeros(2,3,4).reshape(3,8).shape
      torch.Size([3, 8])
      >>> torch.zeros(2,3,4).reshape(4,6).shape
      torch.Size([4, 6])
      >>> torch.zeros(2,3,4).reshape(3,-1).shape
      torch.Size([3, 8])
      >>> torch.zeros(2,3,4).reshape(4,-1).shape
      torch.Size([4, 6])
      
      # -1 means infer remaining dimensions
      >>> torch.zeros(2,3,4).reshape(-1).shape
      torch.Size([24])
      >>> torch.zeros(2,3,4).reshape(2,-1).shape
      torch.Size([2, 12])
      >>> torch.zeros(2,3,4).reshape(3,-1).shape
      torch.Size([3, 8])
      >>> torch.zeros(2,3,4).view(-1,3).shape
      torch.Size([8, 3])
      
      >>> torch.zeros(2,3,4).reshape(2,1,3,4,1).shape
      torch.Size([2, 1, 3, 4, 1])
      
      # .view is similar but never copies: it requires a compatible (e.g. contiguous) memory layout and errors otherwise, while .reshape falls back to a copy when needed
      torch.empty(3,4).view(3,1,4).shape # 3,1,4
      
    • Reducing

      >>> torch.ones(3,4,5).sum()
      tensor(60.)
      >>> torch.ones(3,4,5).sum(dim=0).shape
      torch.Size([4, 5])
      >>> torch.ones(3,4,5).sum(dim=1).shape
      torch.Size([3, 5])
      >>> torch.ones(3,4,5).sum(dim=2).shape
      torch.Size([3, 4])
      >>> torch.ones(3,4,5).sum(0,keepdim=True).shape
      torch.Size([1, 4, 5])
      >>> torch.ones(3,4,5).sum([0,1],keepdim=True).shape
      torch.Size([1, 1, 5])
      
      
    • Combining

      torch.cat([torch.empty(3,4), torch.empty(3,4)]).shape # 6,4
      torch.cat([torch.empty(3,4), torch.empty(3,4)], dim=0).shape # 6,4
      torch.cat([torch.empty(3,4), torch.empty(3,4)], dim=1).shape # 3,8
      
      # stack adds dimension
      # this `dim` specifies where to insert the new dimension
      torch.stack([torch.empty(3,4), torch.empty(3,4)]).shape # 2,3,4
      torch.stack([torch.empty(3,4), torch.empty(3,4)], dim=1).shape # 3,2,4
      torch.stack([torch.empty(3,4), torch.empty(3,4)], dim=2).shape # 3,4,2
      
    • Splitting

      >>> a = torch.arange(10).reshape(5, 2)
      >>> a
      tensor([[0, 1],
              [2, 3],
              [4, 5],
              [6, 7],
              [8, 9]])
      >>> torch.split(a, 2)
      (tensor([[0, 1],
               [2, 3]]),
       tensor([[4, 5],
               [6, 7]]),
       tensor([[8, 9]]))
      >>> torch.split(a, [1, 4])
      (tensor([[0, 1]]),
       tensor([[2, 3],
               [4, 5],
               [6, 7],
               [8, 9]]))
      
      >>> [tuple(x.shape) for x in torch.zeros(6).chunk(2)]
      [(3,), (3,)]
      >>> [tuple(x.shape) for x in torch.zeros(6,5).chunk(2)]
      [(3, 5), (3, 5)]
      >>> [tuple(x.shape) for x in torch.zeros(5,5).chunk(2)]
      [(3, 5), (2, 5)]
      
      >>> ip_tensor=torch.tensor([[1,2,3],[4,5,6]])
      >>> torch.unbind(ip_tensor,dim=0)
      (tensor([1, 2, 3]), tensor([4, 5, 6]))
      >>> torch.unbind(ip_tensor,dim=1)
      (tensor([1, 4]), tensor([2, 5]), tensor([3, 6]))
      
    • Resources

      • Mastering Tensors in PyTorch: A Comprehensive Guide (Soorya Narayan Satheesh, Medium, Jan 2024)
  • Gather, scatter are not broadcastable: the index tensor must have the same number of dimensions as the input, so you need to make the dimensions all align first. Maybe this is for safety.

    • Gather: read/index along a certain dimension

      out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
      out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
      out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
      
      • Example: index along a specific dimension, i.e. pick one column per row, e.g. to get the logit/score of the correct word/class per example when calculating cross-entropy loss

        # similar task/expression that requires hardcoding the shape:
        # cross_entropy_loss = -logprobs[range(n), Yb].mean()
        
        num_examples = 5
        num_classes = 3
        scores = torch.tensor([
                [ 1.5410, -0.2934, -2.1788],
                [ 0.5684, -1.0845, -1.3986],
                [ 0.4033,  0.8380, -0.7193],
                [-0.4033, -0.5966,  0.1820],
                [-0.8567,  1.1006, -1.0712]])
        y = torch.LongTensor([1, 2, 1, 0, 2])
        res = scores.gather(1, y.view(-1, 1)).squeeze()
        # tensor([-0.2934, -1.3986,  0.8380, -0.4033, -1.0712])
        
      • Example: switch between lo and hi in stochastic rounding

        def sround(x, p):  # stochastically round x to p fractional bits
            lo = torch.floor(x * (2**p)) / 2**p   # nearest representable value below x
            hi = torch.ceil(x * (2**p)) / 2**p    # nearest representable value above x
            plo = (hi - x) / (hi - lo)            # prob of rounding down (NaN when x is exactly representable)
            islo = torch.rand(x.shape) < plo
            a = torch.stack((hi, lo))             # row 0 -> hi, row 1 -> lo
            return a.gather(0, islo.to(torch.int64)[None, ...]).squeeze(0)
        
    • Scatter: write/place along a certain dimension

      self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
      self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
      self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
      
      • Example: reverse / backprop the above flows, where you want to scatter some gradient back to the one-hot element it came from, e.g. for cross-entropy loss or for softmax's normalization (logits.max(1))

        # forward:
        logit_maxes = logits.max(1,keepdim=True).values
        
        # backward:
        # backward: scatter dlogit_maxes back to the argmax positions
        dlogits = torch.zeros_like(logits).scatter_(1, logits.max(1, keepdim=True).indices, dlogit_maxes)
        
        # alt: dlogits = F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
        
      • Example: in backprop, placing the gradient back to the embedding

        # forward:
        emb = C[Xb] # embed the characters into vectors
        
        # backward:
        dC = torch.zeros_like(C)
        for i in range(Xb.shape[0]):
          for j in range(Xb.shape[1]):
            dC[Xb[i,j]] += demb[i,j]
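
        A vectorized alternative to the loop, sketched with scatter_add_ (assumes Xb is (batch, block_size) of int64 indices and C/demb rows are emb_dim wide; index_add_ works too):

        idx = Xb.reshape(-1, 1).expand(-1, C.shape[1])  # flat indices broadcast across the embedding dim
        dC = torch.zeros_like(C).scatter_add_(0, idx, demb.reshape(-1, C.shape[1]))
        # or: dC = torch.zeros_like(C).index_add_(0, Xb.reshape(-1), demb.reshape(-1, C.shape[1]))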
        
  • Convolutions

    >>> 
    >>> T = torch.Tensor  # shorthand assumed for the examples below (float tensors)
    >>> ## basics
    >>> F.conv1d(T([[0,1,2,3]]), T([[[1]]]))
    tensor([[0., 1., 2., 3.]])
    
    >>> F.conv1d(T([[[0,1,2,3]]]), T([[[1]]]))
    tensor([[[0., 1., 2., 3.]]])
    
    >>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]))
    tensor([[3., 6.]])
    
    >>> 
    >>> ## paddings
    >>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]), padding=1)
    tensor([[1., 3., 6., 5.]])
    
    >>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]), padding=2)
    tensor([[0., 1., 3., 6., 5., 3.]])
    
    >>> # can actually be >= kernel size!
    >>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]), padding=3)
    tensor([[0., 0., 1., 3., 6., 5., 3., 0.]])
    
    >>> F.conv1d(T([[0,1,2,3]]), T([[[1,1,1]]]), padding=4)
    tensor([[0., 0., 0., 1., 3., 6., 5., 3., 0., 0.]])
    
    >>> 
    >>> ## multi-kernel: 2 out dims
    >>> F.conv1d(T([[0,1,2,3]]), T([[[1,2,3]],[[1,1,1]]]))
    tensor([[ 8., 14.],
            [ 3.,  6.]])
    
    >>> 
    >>> ## multi-channel input 3xN - channel must be outer/left dim, and conv sums them
    >>> F.conv1d(T([[0,1,2,3],[4,5,6,7],[8,9,10,11]]), T([[[1],[1],[1]]]))
    tensor([[12., 15., 18., 21.]])
    
    >>> 
    >>> ## combo: 3 in, 2 out, size-1 kernel
    >>> F.conv1d(T([[0,1,2,3],[4,5,6,7],[8,9,10,11]]), T([[[1],[2],[3]],[[1],[1],[1]]]))
    tensor([[32., 38., 44., 50.],
            [12., 15., 18., 21.]])
    
  • Einsum

    # copied from docs
    >>> # trace - sum of diagonal
    >>> torch.einsum('ii', torch.randn(4, 4))
    tensor(-1.2104)
    
    >>> # diagonal
    >>> torch.einsum('ii->i', torch.randn(4, 4))
    tensor([-0.1034,  0.7952, -0.2433,  0.4545])
    
    >>> # outer product
    >>> x = torch.randn(5)
    >>> y = torch.randn(4)
    >>> torch.einsum('i,j->ij', x, y)
    tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
            [-0.3744,  0.9381,  1.2685, -1.6070],
            [ 0.7208, -1.8058, -2.4419,  3.0936],
            [ 0.1713, -0.4291, -0.5802,  0.7350],
            [ 0.5704, -1.4290, -1.9323,  2.4480]])
    
    >>> # batch matrix multiplication
    >>> As = torch.randn(3, 2, 5)
    >>> Bs = torch.randn(3, 5, 4)
    >>> torch.einsum('bij,bjk->bik', As, Bs)
    tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
            [-1.6706, -0.8097, -0.8025, -2.1183]],
    
            [[ 4.2239,  0.3107, -0.5756, -0.2354],
            [-1.4558, -0.3460,  1.5087, -0.8530]],
    
            [[ 2.8153,  1.8787, -4.3839, -1.2112],
            [ 0.3728, -2.1131,  0.0921,  0.8305]]])
    
    >>> # with sublist format and ellipsis
    >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
    tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
            [-1.6706, -0.8097, -0.8025, -2.1183]],
    
            [[ 4.2239,  0.3107, -0.5756, -0.2354],
            [-1.4558, -0.3460,  1.5087, -0.8530]],
    
            [[ 2.8153,  1.8787, -4.3839, -1.2112],
            [ 0.3728, -2.1131,  0.0921,  0.8305]]])
    
    >>> # batch permute
    >>> A = torch.randn(2, 3, 4, 5)
    >>> torch.einsum('...ij->...ji', A).shape
    torch.Size([2, 3, 5, 4])
    
    >>> # equivalent to torch.nn.functional.bilinear
    >>> A = torch.randn(3, 5, 4)
    >>> l = torch.randn(2, 5)
    >>> r = torch.randn(2, 4)
    >>> torch.einsum('bn,anm,bm->ba', l, A, r)
    tensor([[-0.3430, -5.2405,  0.4494],
            [ 0.3311,  5.5201, -3.0356]])
    
  • Indexing

    # base
    >>> xs
    tensor([[0, 1, 2],
            [3, 4, 5]])
    
    # full index gets 1 element
    >>> xs[0,0]
    tensor(0)
    # can use negative indexing
    >>> xs[-1,-1]
    tensor(5)
    
    # partial index gets slice, drops dim(s)
    >>> xs[0]
    tensor([0, 1, 2])
    >>> xs[0,:]
    tensor([0, 1, 2])
    >>> xs[0][0]
    tensor(0)
    >>> xs[:]
    tensor([[0, 1, 2],
            [3, 4, 5]])
    
    # multiple indexes, can repeat, keeps dim
    >>> xs[[0],:]
    tensor([[0, 1, 2]])
    >>> xs[[-1],:]
    tensor([[3, 4, 5]])
    >>> xs[[0,1,0,1]]
    tensor([[0, 1, 2],
            [3, 4, 5],
            [0, 1, 2],
            [3, 4, 5]])
    # multiple indexes over multiple dimensions are *zipped*!
    >>> xs[[0,1,0,1],[0,1,0,1]]
    tensor([0, 4, 0, 4])
    # above is different from this, which gives a subtensor
    >>> xs[:,[0,1,0,1]]
    tensor([[0, 1, 0, 1],
            [3, 4, 3, 4]])
    
    # can also select with bools
    >>> xs[[False,True]]
    tensor([[3, 4, 5]])
    >>> xs[[False,True],[False,True,True]]
    tensor([4, 5])
    >>> xs[:,[False,True,True]]
    tensor([[1, 2],
            [4, 5]])
    
    # selecting with tensors
    >>> torch.arange(6).view(2,3)[torch.zeros(2,2).int()]
    tensor([[[0, 1, 2],
             [0, 1, 2]],
    
            [[0, 1, 2],
             [0, 1, 2]]])
    
  • Index assignment

    >>> ys=torch.arange(6).view(2,3)
    
    # can assign to a slice (note: ys[:] is a view, so these writes modify ys too)
    >>> xs=ys[:]; xs[:,0]=7; xs
    tensor([[7, 1, 2],
            [7, 4, 5]])
    
    # partial slice implies : on remaining dims
    >>> xs=ys[:]; xs[:]=7; xs
    tensor([[7, 7, 7],
            [7, 7, 7]])
    
    # can broadcast
    >>> xs=ys[:]; xs[:] = torch.arange(3).view(1,3); xs
    tensor([[0, 1, 2],
            [0, 1, 2]])
    >>> xs=ys[:]; xs[:,:]=torch.arange(3).view(1,3); xs
    tensor([[0, 1, 2],
            [0, 1, 2]])
    
  • Autograd

    • Freeze parameters by setting requires_grad=False, but backprop will still flow through them to earlier tensors that do require grad! https://stackoverflow.com/questions/72665429/pytorch-if-you-detach-a-nn-module-in-the-middle-of-a-network-do-all-the-module
    • requires_grad propagates to intermediate computed tensors; it is usually False on input tensors holding raw data values.
    • t.detach() creates a tensor that shares storage with the original tensor but does not require grad
    • torch.no_grad() disables grad tracking on all tensors created within the block
    • https://stackoverflow.com/questions/56816241/difference-between-detach-and-with-torch-nograd-in-pytorch
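    • A minimal sketch of the above (requires_grad propagation, detach, no_grad):

      import torch

      w = torch.randn(3, requires_grad=True)
      x = torch.randn(3)                        # raw data: requires_grad defaults to False
      y = (w * x).sum()                         # requires_grad propagates to computed tensors
      print(y.requires_grad)                    # True

      z = y.detach()                            # shares storage, but cut off from the graph
      print(z.requires_grad)                    # False

      with torch.no_grad():                     # nothing created in here tracks gradients
          print((w * x).sum().requires_grad)    # False
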
  • Misc

    • nn.Parameter is a subclass of torch.Tensor (historically, of Variable), so most behaviors are the same. The most important difference is that if you assign an nn.Parameter as an attribute in an nn.Module's constructor, it is automatically registered in the module's parameters, just like nn.Module attributes are registered as submodules.
    • Variables are not needed anymore; you can simply use Tensors. A Parameter is a Tensor that is marked as belonging to an nn.Module and so will be returned when calling .parameters() on that Module.
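    • A minimal sketch of the registration behavior (module and attribute names are made up):

      import torch
      import torch.nn as nn

      class M(nn.Module):
          def __init__(self):
              super().__init__()
              self.w = nn.Parameter(torch.randn(3))  # registered as a parameter
              self.b = torch.randn(3)                # plain tensor attribute: not registered

      print([name for name, _ in M().named_parameters()])  # ['w']
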
  • Einops

    • rearrange

      # suppose we have a set of 32 images in "h w c" format (height-width-channel)
      >>> images = [np.random.randn(30, 40, 3) for _ in range(32)]
      
      # stack along first (batch) axis, output is a single array
      >>> rearrange(images, 'b h w c -> b h w c').shape
      (32, 30, 40, 3)
      
      # concatenate images along height (vertical axis), 960 = 32 * 30
      >>> rearrange(images, 'b h w c -> (b h) w c').shape
      (960, 40, 3)
      
      # concatenated images along horizontal axis, 1280 = 32 * 40
      >>> rearrange(images, 'b h w c -> h (b w) c').shape
      (30, 1280, 3)
      
      # reordered axes to "b c h w" format for deep learning
      >>> rearrange(images, 'b h w c -> b c h w').shape
      (32, 3, 30, 40)
      
      # flattened each image into a vector, 3600 = 30 * 40 * 3
      >>> rearrange(images, 'b h w c -> b (c h w)').shape
      (32, 3600)
      
      # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
      >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
      (128, 15, 20, 3)
      
      # space-to-depth operation
      >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
      (32, 15, 20, 12)
      
      # use ellipsis
      rearrange(attn_mask, '... -> ... 1')
      
    • repeat
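
      A couple of minimal repeat examples in the same spirit (assuming from einops import repeat, and the images list from the rearrange examples above):

      # repeat an image along a new (batch) axis
      >>> repeat(images[0], 'h w c -> b h w c', b=4).shape
      (4, 30, 40, 3)

      # repeat each row 2 times along height, 60 = 30 * 2
      >>> repeat(images[0], 'h w c -> (h rep) w c', rep=2).shape
      (60, 40, 3)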

  • Math

    • NLL vs cross entropy (source)

      The torch.nn.functional.cross_entropy function combines log_softmax (softmax followed by a logarithm) and nll_loss (negative log likelihood loss) in a single function, i.e. it is equivalent to F.nll_loss(F.log_softmax(x, 1), y).

      Code:

      import torch
      import torch.nn.functional as F
      
      x = torch.FloatTensor([[1.,0.,0.],
                             [0.,1.,0.],
                             [0.,0.,1.]])
      y = torch.LongTensor([0,1,2])
      
      print(torch.nn.functional.cross_entropy(x, y))
      
      print(F.softmax(x, 1).log())
      print(F.log_softmax(x, 1))
      
      print(F.nll_loss(F.log_softmax(x, 1), y))
      
      

      output:

      tensor(0.5514)
      tensor([[-0.5514, -1.5514, -1.5514],
              [-1.5514, -0.5514, -1.5514],
              [-1.5514, -1.5514, -0.5514]])
      tensor([[-0.5514, -1.5514, -1.5514],
              [-1.5514, -0.5514, -1.5514],
              [-1.5514, -1.5514, -0.5514]])
      tensor(0.5514)
      
      
  • Distributed

    • DDP

      • the DDP constructor broadcasts rank 0's model states to all ranks, so any effect of different per-rank random seeds on the initial weights is overwritten (source, source)
      • registers an autograd hook for each parameter given by model.parameters(); each hook fires when the corresponding gradient is computed in the backward pass, and DDP uses that signal to trigger gradient synchronization (all-reduce) across processes (minimal sketch below)
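
      A minimal sketch (assumes launch via torchrun with one GPU per process; model and sizes are illustrative):

      import os
      import torch
      import torch.distributed as dist
      from torch.nn.parallel import DistributedDataParallel as DDP

      dist.init_process_group("nccl")
      local_rank = int(os.environ["LOCAL_RANK"])       # set by torchrun
      torch.cuda.set_device(local_rank)
      model = torch.nn.Linear(10, 10).cuda()
      ddp_model = DDP(model, device_ids=[local_rank])  # rank 0's weights are broadcast here
      out = ddp_model(torch.randn(8, 10).cuda())
      out.sum().backward()                             # per-parameter hooks trigger gradient all-reduce
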
    • Groups

    • RPC: supports sending remote object refs

    • Point-to-point send/recv

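      A minimal blocking point-to-point sketch (assumes a gloo process group is already initialized and world size >= 2):

      import torch
      import torch.distributed as dist

      t = torch.zeros(3)
      if dist.get_rank() == 0:
          dist.send(t + 42, dst=1)    # blocking send to rank 1
      elif dist.get_rank() == 1:
          dist.recv(t, src=0)         # blocking receive into t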

    • Distributed “collective”/group ops (like scatter/gather/all-reduce, implemented with NCCL); see Distributed and parallel computation

    • Built-in DDP and FSDP (ZeRO)

    • Autograd is handled across RPC boundaries via distributed autograd

      import torch
      import torch.distributed.autograd as dist_autograd
      import torch.distributed.rpc as rpc
      
      def my_add(t1, t2):
        return torch.add(t1, t2)
      
      # On worker 0:
      with dist_autograd.context() as context_id:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        
        # Perform some computation remotely.
        t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
        
        # Perform some computation locally based on remote result.
        t4 = torch.rand((3, 3), requires_grad=True)
        t5 = torch.mul(t3, t4)
        
        # Compute some loss.
        loss = t5.sum()
        
        # Run the distributed backward pass; gradients accumulate per autograd context.
        dist_autograd.backward(context_id, [loss])
      


  • Activation checkpointing

    • Think of the API as wrapping a block of the forward pass in a "transaction": inside it, activations are not stored (as if no_grad were set); only the block's inputs and outputs are kept, and the block is recomputed during the backward pass.
    • There's also a simpler checkpoint_sequential API that takes an nn.Sequential plus a number of segments and checkpoints each segment
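    • A minimal sketch (torch.utils.checkpoint; the tiny Sequential is just illustrative, and use_reentrant is available in recent PyTorch versions):

      import torch
      from torch.utils.checkpoint import checkpoint, checkpoint_sequential

      block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16))
      x = torch.randn(4, 16, requires_grad=True)

      # activations inside `block` are not stored; they are recomputed during backward
      y = checkpoint(block, x, use_reentrant=False)
      y.sum().backward()

      # checkpoint_sequential splits a Sequential into N segments and checkpoints each one
      y2 = checkpoint_sequential(block, 2, x, use_reentrant=False)
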
  • Compiler

    • torch.compile uses TorchDynamo for graph capture (with TorchInductor as the default backend) and supersedes torch.jit, both tracing and scripting https://dev-discuss.pytorch.org/t/the-nuances-of-pytorch-graph-capture/501/9
    • Able to compile much more Python than TorchScript https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
    • TorchScript (torch.jit) comes in either scripting or tracing flavors and achieves roughly 2-3x speedups
    • The captured compute graph needs to be static; dynamic tensor shapes break the graph and cause recompilation/deopt
    • Dynamo
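    • A minimal sketch of torch.compile:

      import torch

      def f(x):
          return torch.sin(x) + torch.cos(x)

      compiled_f = torch.compile(f)   # Dynamo captures the graph, Inductor compiles it
      print(compiled_f(torch.randn(8)))
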
  • Build

  • Performance optimization

  • When evaluating model

  • torch.func: JAX-style higher order functions that take a function (that runs pytorch code) and return a function
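
    A minimal sketch (grad and vmap from torch.func):

    import torch
    from torch.func import grad, vmap

    def f(x):
        return (x ** 2).sum()

    x = torch.randn(3)
    print(grad(f)(x))                        # gradient of f at x, i.e. 2 * x
    print(vmap(grad(f))(torch.randn(5, 3)))  # per-example gradients over a batch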

  • Misc features