Linear regression is a widely used statistical method for modeling and studying the relationship between two continuous variables. Its properties and a plain Python implementation have been covered in a previous article. Now, we shall see how to implement it in PyTorch, a popular deep learning library originally developed by Facebook's AI Research lab.
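For a single input variable, the model is a straight line with weight w and bias b. It is fitted by minimizing the mean squared error over n training pairs (x_i, y_i). The notation below is a standard formulation, not taken from the original article:

$$
\hat{y} = w x + b, \qquad
\mathcal{L}(w, b) = \frac{1}{n} \sum_{i=1}^{n} \left( w x_i + b - y_i \right)^2
$$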
First, you will need to install PyTorch into your Python environment. The easiest way to do this is with a package manager such as pip or conda. Visit pytorch.org and use the install command that matches your operating system, Python version, and package manager.
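A quick way to confirm that the installation worked is to import the library and run a tiny tensor computation. This is a minimal sanity check added here, not part of the original article:

```python
import torch

# Print the installed version and whether a CUDA GPU is visible
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# A tiny tensor computation to confirm the library works
total = torch.tensor([1.0, 2.0]).sum().item()
print("1.0 + 2.0 =", total)
```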
With PyTorch installed, let us now have a look at the code.
Write the two lines given below to import the necessary library functions and objects.
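The two import lines themselves are not reproduced in this copy. A minimal equivalent, assuming only the core package and its neural-network module are needed, is sketched below. Some older tutorials also import Variable from torch.autograd; since PyTorch 0.4, plain tensors support autograd directly and Variable is deprecated, so it is omitted here.

```python
import torch
from torch import nn  # neural-network building blocks such as nn.Linear and nn.MSELoss

# Tensors track gradients directly; no Variable wrapper is needed
t = torch.tensor([1.0], requires_grad=True)
print(t.requires_grad)  # True
```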
We also define some data and assign them to variables x_data and y_data as given below:
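The actual values are not preserved in this copy. As an assumption, the sketch below uses three points lying on the line y = 2x, which is consistent with the prediction of roughly 8.0 for x = 4.0 reported later. Each tensor has shape (3, 1): one row per sample and one column per feature, which is the layout torch.nn.Linear expects.

```python
import torch

# Assumed toy dataset on the line y = 2x (one sample per row, one feature per column)
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

print(x_data.shape)  # torch.Size([3, 1])
```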
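Here, x_data is our independent variable and y_data is our dependent variable; together they form our dataset for now. Next, we need to define our model. There are two main steps associated with defining our model:
1. Initializing the model and its layers.
2. Declaring the forward pass, that is, how an input is turned into a prediction.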
We use the class given below:
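The class itself is not reproduced in this copy. A minimal sketch follows; the name LinearRegressionModel is an assumption. Its __init__ method creates the layer (initialization) and its forward method maps an input batch to predictions (the forward pass).

```python
import torch


class LinearRegressionModel(torch.nn.Module):  # hypothetical class name
    def __init__(self):
        super().__init__()
        # One input feature and one output feature: y_hat = w * x + b
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        # Forward pass: apply the linear layer to the input batch
        return self.linear(x)


model = LinearRegressionModel()
out = model(torch.tensor([[1.0], [2.0], [3.0]]))
print(out.shape)  # torch.Size([3, 1])
```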
As you can see, our model class is a subclass of torch.nn.Module. Also, since we have only one input and one output here, we use a linear layer, torch.nn.Linear, with both the input and output dimensions set to 1.
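To make concrete what a linear layer with input and output dimension 1 computes, the sketch below checks that torch.nn.Linear(1, 1) returns x W^T + b, where W is a 1 x 1 weight matrix and b is a one-element bias vector:

```python
import torch

torch.manual_seed(0)  # reproducible random initialization
layer = torch.nn.Linear(1, 1)
x = torch.tensor([[1.0], [2.0], [3.0]])

print(layer.weight.shape, layer.bias.shape)  # torch.Size([1, 1]) torch.Size([1])

# nn.Linear computes x @ W^T + b
manual = x @ layer.weight.T + layer.bias
print(torch.allclose(layer(x), manual))  # True
```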
Next, we create an object of this model.
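Creating the object initializes its weight and bias to random values. The sketch below redefines the hypothetical LinearRegressionModel class so it runs on its own, instantiates it, and lists its two trainable parameters:

```python
import torch


class LinearRegressionModel(torch.nn.Module):  # hypothetical class name
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


our_model = LinearRegressionModel()  # weight and bias start at random values

# Print each trainable parameter and its current value
for name, param in our_model.named_parameters():
    print(name, param.data)

n_params = sum(p.numel() for p in our_model.parameters())
print("trainable parameters:", n_params)  # 2
```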
After this, we select the loss criterion and the optimizer. Here, we use the mean squared error (MSE) as our loss function and stochastic gradient descent (SGD) as our optimizer, with an arbitrarily chosen learning rate of 0.01.
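A sketch of the loss and optimizer setup follows. It assumes the squared errors are summed over samples (reduction="sum") rather than averaged; with the default mean reduction, the gradients on this three-point dataset are three times smaller, so 500 steps at a learning rate of 0.01 would typically leave the prediction further from 8.0. Because each step uses the whole dataset, SGD here behaves as plain full-batch gradient descent. The last lines confirm that the summed loss equals the sum of squared errors written out by hand.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(1, 1)  # stands in for the model object created earlier
x_data = torch.tensor([[1.0], [2.0], [3.0]])  # assumed toy data (y = 2x)
y_data = torch.tensor([[2.0], [4.0], [6.0]])

# Squared-error loss summed over samples, and SGD with learning rate 0.01
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

pred = model(x_data)
loss = criterion(pred, y_data)
manual = ((pred - y_data) ** 2).sum()  # the same quantity, computed by hand
print(loss.item(), manual.item())
```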
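We now arrive at the training step. We perform the following tasks 500 times during training:
1. Perform a forward pass: pass x_data through the model to obtain the predicted values of y.
2. Compute the MSE loss between the predictions and y_data.
3. Reset all gradients to zero, backpropagate the loss, and update the weights with the optimizer.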
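The training tasks can be sketched as the loop below. The snippet assumes toy data on the line y = 2x, a summed MSE loss, and SGD with lr=0.01; torch.nn.Linear(1, 1) stands in for the model class so that it runs on its own.

```python
import torch

torch.manual_seed(0)
x_data = torch.tensor([[1.0], [2.0], [3.0]])  # assumed toy data (y = 2x)
y_data = torch.tensor([[2.0], [4.0], [6.0]])

model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

losses = []
for epoch in range(500):
    pred_y = model(x_data)            # 1. forward pass
    loss = criterion(pred_y, y_data)  # 2. compute the MSE loss
    optimizer.zero_grad()             # 3a. reset gradients from the previous step
    loss.backward()                   # 3b. backpropagate
    optimizer.step()                  # 3c. update weight and bias
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```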
Once training is complete, we check whether the model gives correct results by testing it on a value of x that is not in the training data, in this case 4.0.
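Gradients are not needed for prediction, so a common pattern (an addition here, not necessarily what the original did) is to run inference inside torch.no_grad(). The input must keep the (batch, features) shape, so 4.0 is passed as [[4.0]]. This sketch uses an untrained torch.nn.Linear(1, 1) purely to show the mechanics and checks that the output equals w * 4.0 + b.

```python
import torch

model = torch.nn.Linear(1, 1)  # stands in for the trained model
new_x = torch.tensor([[4.0]])  # one sample, one feature

with torch.no_grad():  # inference only: no computation graph is recorded
    pred_y = model(new_x)

print("prediction for x = 4.0:", pred_y.item())

# The output is simply w * x + b for the layer's current parameters
w, b = model.weight.item(), model.bias.item()
print("w * 4.0 + b =", 4.0 * w + b)
```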
If you performed all the steps correctly, you will see that for the input 4.0 the model outputs a value very close to 8.0, as shown below. The model has learned the relationship between the input and output data from examples, without that relationship being programmed explicitly.
predict (after training) 4 7.966438293457031
For your reference, you can find the entire code of this article given below:
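The complete code is not reproduced in this copy. The consolidated sketch below combines the steps under the same assumptions: toy data on the line y = 2x, a summed MSE loss, SGD with a learning rate of 0.01, 500 iterations, and the hypothetical class name LinearRegressionModel. The exact numbers depend on the random initialization, but the prediction for 4.0 should land close to 8.0, and the learned weight and bias close to 2 and 0.

```python
import torch


class LinearRegressionModel(torch.nn.Module):  # hypothetical class name
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)  # one input, one output

    def forward(self, x):
        return self.linear(x)


torch.manual_seed(0)  # optional: reproducible initialization

# Assumed toy dataset lying on y = 2x
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

our_model = LinearRegressionModel()
criterion = torch.nn.MSELoss(reduction="sum")  # summed squared error
optimizer = torch.optim.SGD(our_model.parameters(), lr=0.01)

for epoch in range(500):
    pred_y = our_model(x_data)        # forward pass
    loss = criterion(pred_y, y_data)  # loss
    optimizer.zero_grad()             # clear old gradients
    loss.backward()                   # backpropagation
    optimizer.step()                  # parameter update
    if epoch % 100 == 0:
        print(f"epoch {epoch}, loss {loss.item():.6f}")

# Test on an input that is not in the training data
with torch.no_grad():
    prediction = our_model(torch.tensor([[4.0]])).item()
print("predict (after training)", 4, prediction)

# The learned parameters should approach weight 2 and bias 0
weight = our_model.linear.weight.item()
bias = our_model.linear.bias.item()
print("weight:", weight, "bias:", bias)
```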