
URL: https://towardsdatascience.com/fit-vs-predict-vs-fit-predict-in-python-scikit-learn-f15a34a8d39f/

fit() vs predict() vs fit_predict() in Python scikit-learn | Towards Data Science



What's the difference between fit, predict and fit_predict methods in sklearn

Photo by Kelly Sikkema on Unsplash

scikit-learn (commonly referred to as sklearn) is probably one of the most powerful and widely used Machine Learning libraries in Python. It comes with a comprehensive set of tools and ready-to-train models, ranging from pre-processing utilities to model training and model evaluation utilities.

Many sklearn objects implement three specific methods, namely fit(), predict() and fit_predict(). Essentially, these are conventions applied throughout scikit-learn and its API. In this article, we are going to explore how each of them works and when to use one over the other.

Note that in this article we are going to explore the aforementioned functions using specific examples, but the concepts explained here are applicable to most (if not all) objects that implement these methods.




Before explaining the intuition behind fit(), predict() and fit_predict(), it is important to first understand what an estimator is in the scikit-learn API. We need to know about estimators simply because they are the objects that implement the methods we are interested in.

What are estimators in scikit-learn

In scikit-learn, an estimator is an object that fits a model based on the input data (i.e. training data) and uses what it has learned to compute properties of new, unseen data. For example, an estimator can be a regressor, a classifier or a clustering model.

The library comes with the base class [sklearn.base.BaseEstimator](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html#) and all estimators should inherit from that class. The base class comes with two methods, namely get_params() and set_params(), that can be used to get and set the parameters of an estimator respectively. Note that estimators must explicitly declare all of their parameters in the constructor (i.e. the __init__ method).
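To make this concrete, here is a minimal sketch of a hypothetical estimator (ThresholdClassifier is made up for illustration and is not part of scikit-learn) that inherits from BaseEstimator. Because its parameters are declared explicitly in __init__ and stored unchanged, the inherited get_params() and set_params() work without any extra code.

```python
from sklearn.base import BaseEstimator


class ThresholdClassifier(BaseEstimator):
    """Hypothetical estimator, used only to illustrate get_params()/set_params()."""

    def __init__(self, threshold=0.5, positive_label=1):
        # Every parameter is an explicit keyword argument stored as-is,
        # so BaseEstimator can discover it by inspecting __init__.
        self.threshold = threshold
        self.positive_label = positive_label


clf = ThresholdClassifier(threshold=0.3)
print(clf.get_params())         # {'positive_label': 1, 'threshold': 0.3}

clf.set_params(threshold=0.7)   # returns the estimator itself
print(clf.threshold)            # 0.7
```

Note that there is no fit() here yet; the example only shows the parameter-handling part that BaseEstimator provides.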


What does fit() do

fit() is implemented by every estimator. It accepts an input for the sample data (X) and, for supervised models, an additional argument for the labels (i.e. the target data y). Optionally, it can also accept additional sample properties, such as sample weights.

fit methods are usually responsible for numerous operations. Typically, they should start by clearing any attributes already stored on the estimator and then perform parameter and data validation. They are also responsible for estimating the model attributes from the input data, storing them on the estimator and, finally, returning the fitted estimator.
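The steps above can be sketched with a toy estimator. MeanRegressor is hypothetical (not a scikit-learn class) and simply learns the mean of the targets, but its fit() follows the same pattern: validate the input, estimate attributes with a trailing underscore, and return self. check_X_y is a real scikit-learn validation helper; using it here is our choice for the sketch.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y


class MeanRegressor(BaseEstimator):
    """Hypothetical baseline regressor that learns the (weighted) mean of y."""

    def fit(self, X, y, sample_weight=None):
        # 1. Validate the input data (converts to NumPy arrays, checks shapes).
        X, y = check_X_y(X, y)
        # 2. Estimate the model attributes from the data. Assigning them anew
        #    overwrites anything left over from a previous call to fit().
        self.mean_ = np.average(y, weights=sample_weight)
        self.n_features_in_ = X.shape[1]
        # 3. Return the fitted estimator itself, so calls can be chained.
        return self


model = MeanRegressor().fit([[1.0], [2.0], [3.0]], [10.0, 20.0, 30.0])
print(model.mean_)  # 20.0
```

Returning self is what allows the common one-liner `model = Estimator().fit(X, y)`.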

Now, as an example, let's consider a classification problem where we need to train an SVC model to recognise handwritten images. We first load our data and split it into training and testing sets. Then we instantiate an SVC classifier and finally call fit() to train the model on the training data.
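A minimal sketch of these steps follows. It assumes the small 8x8 handwritten digits dataset bundled with scikit-learn (load_digits); the 80/20 split and random_state value are illustrative choices, not taken from the original example.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Assumption: the handwritten digits dataset that ships with scikit-learn.
X, y = load_digits(return_X_y=True)

# Hold out 20% of the samples as a test set (ratio chosen for illustration).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Instantiate the classifier with default hyperparameters and train it.
model = SVC()
model.fit(X_train, y_train)
```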

[fit](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html#sklearn.svm.SVC.fit)(X, y, sample_weight=None): Fit the SVM model according to the given training data.

X – Training vectors of shape (n_samples, n_features), where n_samples is the number of samples and n_features is the number of features.

y – Target values of shape (n_samples,) (class labels in classification, real numbers in regression).

sample_weight – Per-sample weights of shape (n_samples,). Rescale C per sample. Higher weights force the classifier to put more emphasis on these points.
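As a sketch of the optional sample_weight argument, the snippet below gives the samples of one class more weight during training. The choice of class 9 and the factor of 3 are purely illustrative assumptions; the only requirement is one weight per training sample.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)

# Illustrative weighting: samples of class 9 count three times as much,
# so misclassifying them is penalised more heavily during training.
weights = np.where(y == 9, 3.0, 1.0)

model = SVC()
model.fit(X, y, sample_weight=weights)  # weights has shape (n_samples,)
```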

Now that we have successfully trained our model, we can access its fitted parameters.

Note that every estimator may expose different attributes once it is fitted. You can find which ones are available in the 'Attributes' section of the official documentation for the specific estimator you are working with; for the SVC classifier, these include attributes such as support_vectors_ and n_support_. By convention, fitted attributes use an underscore _ as a suffix.
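The snippet below sketches how a few of SVC's fitted attributes can be inspected after training. It repeats the illustrative load_digits setup so it runs on its own; the exact counts depend on the data split, so no specific values are shown.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
model = SVC().fit(X_train, y_train)

# Fitted attributes carry a trailing underscore.
print(model.classes_)               # the class labels seen during fit
print(model.support_vectors_.shape) # (number of support vectors, n_features)
print(model.n_support_)             # number of support vectors per class
```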

What does predict() do

Now that we have trained our model, the next step typically involves making predictions over the testing set. To do so, we need to call the predict() method, which uses the parameters learned by fit() to perform predictions on new, unseen test data points.

Essentially, predict() performs a prediction for each test instance and usually accepts only a single input (X). For classifiers and regressors, the predicted values lie in the same space as the targets seen in the training set. For clustering estimators, the predicted value is an integer (the index of the assigned cluster). The predictions for the provided test instances are returned as an array or a sparse matrix.
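Continuing the illustrative digits example, a sketch of predict() on the held-out test set might look like this. accuracy_score is used here only as one convenient way to check the predictions.

```python
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
model = SVC().fit(X_train, y_train)

# predict() takes only X and returns one predicted label per test instance.
y_pred = model.predict(X_test)
print(y_pred.shape == y_test.shape)    # True
print(accuracy_score(y_test, y_pred))  # fraction of correct predictions
```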

Note that if you attempt to call predict() without first calling fit(), you will receive an [exceptions.NotFittedError](https://scikit-learn.org/stable/modules/generated/sklearn.exceptions.NotFittedError.html#sklearn.exceptions.NotFittedError).
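A small sketch of this behaviour: an SVC that was never fitted raises NotFittedError as soon as predict() is called. The all-zeros input with 64 features is a dummy chosen for illustration.

```python
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.svm import SVC

model = SVC()  # instantiated, but fit() is never called

raised = False
try:
    model.predict(np.zeros((1, 64)))  # dummy input with 64 features
except NotFittedError as err:
    raised = True
    print(f"NotFittedError: {err}")
```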

What does fit_predict() do

fit_predict(), on the other hand, is mostly relevant to unsupervised or transductive estimators. This method fits the model and performs predictions over the same training data in a single call, which makes it particularly appropriate for operations such as clustering.

[fit_predict](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html?highlight=k%20means#sklearn.cluster.KMeans.fit_predict)(X, y=None, sample_weight=None): Compute cluster centers and predict cluster index for each sample. Convenience method; equivalent to calling fit(X) followed by predict(X).
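The equivalence can be sketched with KMeans on a tiny, made-up dataset of two well-separated groups. With the same random_state, fit_predict(X) and fit(X).predict(X) should assign the same labels. The label numbering itself (which group is 0 and which is 1) is arbitrary, so it is not shown.

```python
import numpy as np
from sklearn.cluster import KMeans

# Two well-separated toy groups (illustrative data).
X = np.array([[1.0, 1.0], [1.2, 0.8], [0.9, 1.1],
              [8.0, 8.0], [8.2, 7.9], [7.8, 8.1]])

# One call: fit the clustering and return a cluster index per sample.
labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)

# Two steps: fit(), then predict() on the same data.
km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
labels_two_step = km.predict(X)

print(labels)
print((labels == labels_two_step).all())  # True
```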

Note that

  • clustering estimators in scikit-learn must implement the fit_predict() method, but not all estimators do so
  • the arguments passed to fit_predict() are the same as those passed to fit()
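A quick way to check which of these methods an estimator offers is hasattr(). In the sketch below, KMeans offers both methods, DBSCAN offers fit_predict() but no separate predict(), and SVC offers predict() but no fit_predict(). The DBSCAN data and parameters at the end are illustrative.

```python
import numpy as np
from sklearn.cluster import DBSCAN, KMeans
from sklearn.svm import SVC

for est in (KMeans(), DBSCAN(), SVC()):
    print(type(est).__name__,
          "fit_predict:", hasattr(est, "fit_predict"),
          "predict:", hasattr(est, "predict"))

# DBSCAN can only label the data it was fitted on, so fit_predict() is the
# natural entry point; it takes the same arguments as fit().
X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
labels = DBSCAN(eps=0.5, min_samples=2).fit_predict(X)
print(labels)
```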

Conclusion

In this article, we discussed the purpose of three of the most commonly implemented methods in sklearn, namely fit(), predict() and fit_predict(). We explored what each of them does, how they differ and in which use-cases you should use one over the other.

As mentioned in the introduction of this article, even though we used specific examples to demonstrate their behaviour, the concepts explained in the article are applicable to pretty much all estimators implementing these methods in scikit-learn.

The fit() method fits the model to the input training instances, while predict() performs predictions on the testing instances based on the parameters learned during fitting. fit_predict(), on the other hand, is more relevant to unsupervised learning, where we don't have labelled inputs.


A closely related topic is the comparison between the fit(), transform() and fit_transform() methods, which are implemented by scikit-learn transformers used to transform features. If you want to learn more about them, you can read my Medium article below.

fit() vs transform() vs fit_transform() in Python scikit-learn




Written By

Giorgos Myrianthous
