The All-You-Need Guide to PyTorch Broadcasting

Jino Rohit
3 min read · Apr 26, 2023


An in-depth look at how PyTorch broadcasting works

Broadcasting is a fundamental operation that you come across whenever you deal with matrices and tensors. This article discusses how broadcasting works, along with what happens behind the scenes of the operation.

(Figure: the broadcasting operation)

What is broadcasting?

Broadcasting is a concept that describes how arrays of different shapes are handled during arithmetic operations. But there are a few rules the arrays must satisfy for broadcasting to happen:

  1. Each array should have at least one dimension.
  2. When iterating over the dimension sizes from right to left, the sizes must either be equal, one of them must be 1, or one of them must not exist.
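
A quick way to check these rules is torch.broadcast_shapes, which computes the broadcast result of the given shapes and raises an error when they are incompatible. A small sketch:

```python
import torch

# Compatible: rightmost dims are 3 and 3; the other dim is missing
print(torch.broadcast_shapes((2, 3), (3,)))    # torch.Size([2, 3])

# Compatible: dims of size 1 get stretched to match
print(torch.broadcast_shapes((2, 1), (1, 3)))  # torch.Size([2, 3])

# Incompatible: 3 vs 2, and neither is 1
# torch.broadcast_shapes((2, 3), (2,))  # raises RuntimeError
```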

Let’s take an example

What happens if we try to add a tensor a of shape (2, 3) and a tensor b of shape (3,)?

First, let’s see if the broadcasting rules check out:

  1. Both arrays have at least one dimension. βœ”οΈ
  2. Iterating from right to left, the trailing dimensions 3 and 3 are equal, and the leading dimension of (3,) does not exist, which is allowed. βœ”οΈ

Now the smaller tensor b conceptually expands over the missing dimension, i.e. its values are repeated across the rows, and the addition happens.
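
Here is a minimal sketch of this example (the names a and b are just for illustration):

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])   # shape (2, 3)
b = torch.tensor([10, 20, 30])  # shape (3,)

# b behaves as if it were repeated for every row of a
print(a + b)
# tensor([[11, 22, 33],
#         [14, 25, 36]])
```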

Is it really efficient to make a copy of the tensor like that? What if the tensor were really huge? That would mean a huge memory and computational overhead. How is this tackled efficiently?

Striding

The actual data in a torch tensor is stored in a contiguous block of memory, regardless of its multidimensionality.

This means that if you create a tensor of size (3, 3), the actual data is stored sequentially, like a flat tensor of shape (9,).
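
A minimal sketch of this layout (the tensor t is just illustrative):

```python
import torch

t = torch.arange(1, 10).reshape(3, 3)
print(t)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

# The underlying data is one flat, contiguous block of 9 elements
print(t.flatten())
# tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
```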

Striding is what gives the notion of multi-dimensionality to these flat blocks of data. A stride is the number of elements to skip past in memory to move one step along a particular axis of the tensor. This is used by a lot of frameworks to efficiently retrieve data from the underlying storage.

You can use tensor.stride() to check this in PyTorch:
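
For instance, with the same 3 × 3 tensor from above (a sketch):

```python
import torch

t = torch.arange(1, 10).reshape(3, 3)

# stride()[0]: elements to skip to reach the next row
# stride()[1]: elements to skip to reach the next column
print(t.stride())
# (3, 1)
```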

This is a tuple of the steps needed to skip to reach the next row and the next column.

3 denotes that we need to skip past 3 elements in memory to reach the next row

1 denotes that we need to skip past 1 element to reach the next column

Remember, the data is stored sequentially.

Now start at the element 1:

  1. On adding 3 steps, we reach 4, i.e. the start of the next row in the original tensor we created
  2. On adding 1 step, we reach 2, i.e. the next column

Take a moment to convince yourself this is true
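
To convince yourself in code, here is a small sketch that recovers an element from the flat data using only stride arithmetic (flat, s0, and s1 are illustrative names):

```python
import torch

t = torch.arange(1, 10).reshape(3, 3)
flat = t.flatten()    # the data as one sequential block
s0, s1 = t.stride()   # (3, 1)

i, j = 1, 2           # row 1, column 2
# Move i rows and j columns through the flat block
print(flat[i * s0 + j * s1])  # tensor(6)
print(t[i, j])                # tensor(6), the same element
```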

And tada! We can now broadcast any N-dimensional shape without having to make copies across the dimensions; we simply use the notion of strides to make the operation faster and more memory-efficient.
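
One way to see this in PyTorch is expand(), which creates a broadcast view without copying: the expanded dimension simply gets a stride of 0, so moving to the "next row" skips zero elements in memory. A minimal sketch:

```python
import torch

b = torch.tensor([10, 20, 30])   # shape (3,)
view = b.expand(2, 3)            # broadcast view, no copy is made
print(view)
# tensor([[10, 20, 30],
#         [10, 20, 30]])

# Stride 0 along dim 0: both rows read the same underlying data
print(view.stride())
# (0, 1)
```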

Thank you for making it this far! You can find me on LinkedIn: https://www.linkedin.com/in/jino-rohit-6032541b5/


Written by Jino Rohit

Neurons that fire together, wire together
