Sklearn fit() vs transform() vs fit_transform() – What’s the Difference?

Scikit-learn provides a library of transformers for preprocessing a data set. These transformers clean, generate, reduce or expand the feature representation of the data set, and they share the fit(), transform() and fit_transform() methods.

  • The fit() method learns the transformer's parameters from a training data set. For example, it learns the mean and standard deviation for standardization, or the minimum and maximum for scaling features to a given range.
  • The transform() method applies the parameters learned by fit(). It transforms both the training data and the test data (aka unseen data).
  • The fit_transform() method fits and then transforms the same data set in a single call. It is equivalent to calling fit() followed by transform() on that data, and some transformers implement it more efficiently than the two separate calls. As a best practice, fit_transform() is only used on the training data set.
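As a quick preview, the following sketch runs both workflows on the toy data-sets used later in this article and checks that fit() followed by transform() gives the same numbers as fit_transform() followed by transform(). It assumes scikit-learn and NumPy are installed; converting the data to NumPy arrays and the variable names are choices made for illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Same toy data as the REPL sessions below (rows = examples, columns = features).
trn_ds = np.array([[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], dtype=float)
tst_ds = np.array([[30, 34, 38], [31, 35, 39], [32, 36, 40], [33, 37, 41]], dtype=float)

# Option A: fit() on the training data, then transform() each data-set.
ss_a = StandardScaler()
ss_a.fit(trn_ds)                # learns mean_, var_ and scale_ from trn_ds only
trn_a = ss_a.transform(trn_ds)
tst_a = ss_a.transform(tst_ds)  # reuses the training parameters

# Option B: fit_transform() on the training data, transform() on the test data.
ss_b = StandardScaler()
trn_b = ss_b.fit_transform(trn_ds)
tst_b = ss_b.transform(tst_ds)

print(np.allclose(trn_a, trn_b), np.allclose(tst_a, tst_b))  # True True
```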

Note: All the solutions provided below have been verified using Python 3.9.0b5.

Problem Formulation

What is the difference between the fit(), transform() and fit_transform() methods in scikit-learn transformer classes?

Background

Scikit-learn is an open-source Machine Learning library. It supports supervised and unsupervised learning. 

Scikit-learn provides excellent tools for model fitting, selection, and evaluation. It also provides a multitude of helpful utilities for data preprocessing and analysis. It is distributed under a BSD license, which permits commercial use.

The developers of Scikit-learn work hard to keep the API uniform across the library. Scikit-learn provides a User Guide, many tutorials, and examples.  Scikit-learn is an excellent resource for Pythonistas who wish to master Machine Learning.  

That is great!! But You Have Not Told Me Anything About fit(), transform() and fit_transform()

When implementing Machine Learning algorithms, one often needs to preprocess the data set. The preprocessing can take various forms, such as:

  • Cleaning
  • Centering
  • Imputing
  • Reduction
  • Expansion
  • Generation
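Several of these forms map directly to scikit-learn transformers. The sketch below picks three examples (StandardScaler for centering, PCA for reduction and PolynomialFeatures for expansion) and calls the same fit_transform() method on each, printing how the shape of the data-set changes. The choice of transformers and their parameters is illustrative only, not a recommendation; it assumes scikit-learn and NumPy are installed.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

trn_ds = np.array([[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], dtype=float)

# Different preprocessing forms, one shared method name.
transformers = {
    "centering (StandardScaler)": StandardScaler(),
    "reduction (PCA, 1 component)": PCA(n_components=1),
    "expansion (PolynomialFeatures, degree 2)": PolynomialFeatures(degree=2),
}

for name, transformer in transformers.items():
    out = transformer.fit_transform(trn_ds)
    print(f"{name}: {trn_ds.shape} -> {out.shape}")
```

With its default include_bias=True, PolynomialFeatures(degree=2) turns 3 features into 10 (a bias column, the 3 original features and 6 degree-2 terms), while PCA(n_components=1) keeps a single column and StandardScaler leaves the shape unchanged.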

The Scikit-learn library provides a multitude of classes called transformers for preprocessing. Most of these transformers share a common API. A common API provides simplicity and clarity to a given library. fit(), transform() and fit_transform() are common API methods for transformer classes. Let’s examine these methods one at a time. 
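To see what a common API means in practice, here is a hypothetical, dependency-free CenteringTransformer. It is *not* a scikit-learn class; it only mimics the shape of the API for a list-of-lists data-set: fit() stores the learned parameters in attributes whose names end with an underscore and returns self, transform() applies them, and fit_transform() chains the two. The fallback fit_transform() that scikit-learn's TransformerMixin provides works the same way, by calling fit() and then transform(). The error type raised here is an assumption; scikit-learn itself raises NotFittedError.

```python
class CenteringTransformer:
    """Hypothetical transformer that subtracts the per-feature mean.

    Not part of scikit-learn; it only illustrates the fit/transform/
    fit_transform pattern on a list of rows (rows = examples).
    """

    def fit(self, X):
        n_rows = len(X)
        n_cols = len(X[0])
        # Learned parameters get a trailing underscore, as in scikit-learn.
        self.mean_ = [sum(row[j] for row in X) / n_rows for j in range(n_cols)]
        return self  # returning self allows chaining: fit(X).transform(X)

    def transform(self, X):
        if not hasattr(self, "mean_"):
            raise RuntimeError("Call fit() before transform().")
        return [[value - mean for value, mean in zip(row, self.mean_)] for row in X]

    def fit_transform(self, X):
        # Simplest possible implementation: fit, then transform the same data.
        return self.fit(X).transform(X)


trn_ds = [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]
ct = CenteringTransformer()
print(ct.fit_transform(trn_ds))  # [[-1.5, -1.5, -1.5], [-0.5, -0.5, -0.5], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]]
print(ct.mean_)                  # [1.5, 5.5, 9.5]
```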

Ok Good!! First Tell Me About The fit() Method

In Machine Learning projects, the data is often split into training and test data-sets. The fit() method identifies and learns the transformer's parameters only from the training data-set. For example, it learns the mean and standard deviation for standardization, or the minimum and maximum for scaling the features to a given range. The fit() method is best demonstrated with an example. Let's use the Centering preprocessing step on a data-set to show how fit() works.

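Centering the data-set is one example of preprocessing. In this article, Centering means making the data-set have zero mean and unit standard deviation (strictly, centering alone only removes the mean; also scaling to unit standard deviation is usually called standardization). To do this, one first subtracts the mean of each feature (column) from every example, then divides the result by that feature's standard deviation. To summarize: for each feature, z = (x - μ) / σ, where x is a value of that feature, μ is the feature's mean and σ is its standard deviation.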
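The formula can be checked directly with NumPy before reaching for scikit-learn. This sketch computes the per-feature mean and standard deviation along axis 0 (the columns) and relies on broadcasting to apply z = (x - μ) / σ to every row. np.std defaults to the population standard deviation (ddof=0), which is the convention StandardScaler uses, so the result should match its output.

```python
import numpy as np

trn_ds = np.array([[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], dtype=float)

# Per-feature (column-wise) mean and standard deviation.
mu = trn_ds.mean(axis=0)     # [1.5 5.5 9.5]
sigma = trn_ds.std(axis=0)   # population std (ddof=0), same convention as StandardScaler

# z = (x - mu) / sigma, broadcast across every row (example).
z = (trn_ds - mu) / sigma
print(z)

# Each transformed column should now have (numerically) zero mean and unit std.
print(z.mean(axis=0), z.std(axis=0))
```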

This calculation is easily handled using the StandardScaler class, a utility class from scikit-learn's sklearn.preprocessing module.

StandardScaler().fit() learns the mean and standard deviation from the given training data set and stores these learned parameters in the object. fit() also returns the fitted object itself, which is why the call below prints StandardScaler(). Let's see how this works with an example…

$ python
Python 3.9.0b5 (default, Oct 19 2020, 11:11:59)
>>>
>>> ## Import the StandardScaler class from the sklearn.preprocessing
>>> ## module. StandardScaler standardizes the features. It does this by
>>> ## removing the mean (i.e. centering the data-set around 0) and
>>> ## scaling to unit variance (i.e. a standard deviation of 1.0).
>>> from sklearn.preprocessing import StandardScaler
>>> 
>>> ## trn_ds is a 4x3 matrix. Each column is referred to as a feature.
>>> ## Each row is referred to as an example (or a sample).
>>> ## trn_ds is referred to as the training data-set. It is used to train
>>> ## the machine learning model.
>>> trn_ds = [[0, 4, 8],
...           [1, 5, 9],
...           [2, 6, 10],
...           [3, 7, 11]]
>>> 
>>> ## Initialize a StandardScaler object.
>>> ss = StandardScaler()
>>> 
>>> ## Call the fit() routine to identify and learn the Mean and Standard
>>> ## Deviation of trn_ds. fit() returns the fitted object itself.
>>> print(ss.fit(trn_ds))
StandardScaler()
>>> 
>>> ## Mean is a learned value that is stored as the attribute mean_
>>> ## in the ss object.
>>> print(ss.mean_)
[1.5 5.5 9.5]
>>> 
>>> ## Variance is another learned value that is stored as the
>>> ## attribute var_ in the ss object.
>>> print(ss.var_)
[1.25 1.25 1.25]
>>> 
>>> ## The Standard Deviation (the square root of var_) is stored as the
>>> ## attribute scale_. transform() divides by this value.
>>> print(ss.scale_)
[1.11803399 1.11803399 1.11803399]
>>> 

Note: The fit() method only examines the data-set to learn (i.e. extract) the underlying parameters; it does not change the data. For the StandardScaler class, these parameters are the mean (mean_), the variance (var_) and the standard deviation (scale_). The fit() methods of other transformer classes learn other parameters specific to those classes.
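The sketch below illustrates the note with two other transformers, chosen for illustration: MinMaxScaler, whose fit() learns the per-feature minimum (data_min_) and maximum (data_max_), and SimpleImputer, whose fit() learns one fill value per feature (statistics_). It also checks that fit() leaves the input array unchanged. The small array with missing values is made up for this example, and the sketch assumes scikit-learn and NumPy are installed.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler

trn_ds = np.array([[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], dtype=float)
original = trn_ds.copy()

# MinMaxScaler.fit() learns the per-feature minimum and maximum.
mms = MinMaxScaler().fit(trn_ds)
print(mms.data_min_, mms.data_max_)

# SimpleImputer.fit() learns one fill value per feature (here, the mean,
# computed while ignoring the missing NaN entries).
with_gaps = np.array([[0, 4, np.nan], [2, np.nan, 10]], dtype=float)
imp = SimpleImputer(strategy="mean").fit(with_gaps)
print(imp.statistics_)

# fit() only examines the data; the input array itself is left unchanged.
print(np.array_equal(trn_ds, original))  # True
```

For trn_ds, data_min_ holds [0, 4, 8] and data_max_ holds [3, 7, 11]; for the array with gaps, the per-column means ignoring NaN are [1, 4, 10].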

Ok Got That!! Now What About transform()?

Transforming the training and the test data-sets is the next preprocessing step. The transform() method uses the parameters learned by fit() to transform the data-sets. Let's continue with the Centering example shown above. The learned parameters are already stored in the StandardScaler object. The transform() method is first used on the training data-set, where it centers the data around 0 and scales it to unit variance. The transform() method is then used on the test data-set, which it centers and scales in the same way, i.e. using the mean and standard deviation learned from the training data-set.

$ python
Python 3.9.0b5 (default, Oct 19 2020, 11:11:59)
>>>
>>> ## Import the StandardScaler class from the sklearn.preprocessing
>>> ## module. This step is the same as in the previous example.
>>> from sklearn.preprocessing import StandardScaler
>>> 
>>> ## We keep the same training data-set for comparison purposes.
>>> trn_ds = [[0, 4, 8],
...           [1, 5, 9],
...           [2, 6, 10],
...           [3, 7, 11]]
>>> 
>>> ## Initialize a StandardScaler object.
>>> ss = StandardScaler()
>>> 
>>> ## Call the fit() routine to identify and learn the Mean and Standard
>>> ## Deviation of trn_ds.
>>> print(ss.fit(trn_ds))
StandardScaler()
>>> 
>>> ## As before, Mean is a learned value that is stored as the
>>> ## attribute mean_ in the ss object.
>>> print(ss.mean_)
[1.5 5.5 9.5]
>>> 
>>> ## Variance is another learned value that is stored as the
>>> ## attribute var_ in the ss object.
>>> print(ss.var_)
[1.25 1.25 1.25]
>>> 
>>> ## Ok!! So far, so good!! Next, transform the training data.
>>> print(ss.transform(trn_ds))
[[-1.34164079 -1.34164079 -1.34164079]
 [-0.4472136  -0.4472136  -0.4472136 ]
 [ 0.4472136   0.4472136   0.4472136 ]
 [ 1.34164079  1.34164079  1.34164079]]
>>> 
>>> ## It worked!! The transformed trn_ds data-set is now centered
>>> ## around 0, i.e. it has zero mean. It has also been scaled to have
>>> ## unit variance (and therefore unit standard deviation).
>>> 
>>> ## Next, let's see how the test data-set is transformed. Note that
>>> ## the mean and standard deviation were calculated using *only* the
>>> ## trn_ds data-set. So transform() centers and scales this new,
>>> ## unseen data (i.e. tst_ds) using the parameters learned from
>>> ## the trn_ds data-set.
>>> tst_ds = [[30, 34, 38],
...           [31, 35, 39],
...           [32, 36, 40],
...           [33, 37, 41]]
>>> 
>>> print(ss.transform(tst_ds))
[[25.49117494 25.49117494 25.49117494]
 [26.38560213 26.38560213 26.38560213]
 [27.28002933 27.28002933 27.28002933]
 [28.17445652 28.17445652 28.17445652]]
>>> 
>>> ## The transformed tst_ds is *not* centered around 0. Its values lie
>>> ## far from the training data, so they land far from 0 when scaled
>>> ## with the training mean and standard deviation.
>>>

To reiterate the steps performed so far:

  • Call the fit() method once (on the training data-set only).
    • The fit() method learns the underlying parameters from the training data-set *only*.
  • Call the transform() method twice (once on the training data-set, once on the test data-set).
    • The transform() method first transforms the training data-set.
    • The transform() method then transforms the test data-set, using the same learned parameters.
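To confirm that transform() really reuses the stored training parameters, the sketch below recomputes the transformed test data by hand as (x - mean_) / scale_ and compares the two. It also shows that calling transform() on an unfitted StandardScaler raises scikit-learn's NotFittedError, because there are no learned parameters to apply yet. Converting the data to NumPy arrays is a choice made for convenience, and the sketch assumes scikit-learn and NumPy are installed.

```python
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler

trn_ds = np.array([[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], dtype=float)
tst_ds = np.array([[30, 34, 38], [31, 35, 39], [32, 36, 40], [33, 37, 41]], dtype=float)

ss = StandardScaler()

# transform() needs learned parameters, so calling it before fit() fails.
try:
    ss.transform(tst_ds)
except NotFittedError as err:
    print("Not fitted yet:", err)

ss.fit(trn_ds)

# transform() applies (x - mean_) / scale_ using the *training* statistics.
manual = (tst_ds - ss.mean_) / ss.scale_
print(np.allclose(ss.transform(tst_ds), manual))  # True

# Consequently, the transformed test data is not centered around 0.
print(ss.transform(tst_ds).mean(axis=0))
```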

Ah! OK!! So What Does fit_transform() Do?

The developers of scikit-learn look for ways to make the library convenient to use. In practice, fit() and transform() are *always* both applied to the training data-set, so scikit-learn offers a fit_transform() method that does both in one call. For many transformers, including StandardScaler, fit_transform() simply calls fit() and then transform(); some transformers, such as PCA, implement it in a way that is more efficient than the two separate calls. As a best practice, fit_transform() is *only* called on the training data-set. Let's see how this works for the data-sets used in the above example. The final result should be exactly the same for both use cases.

$ python
Python 3.9.0b5 (default, Oct 19 2020, 11:11:59)
>>> 
>>> ## Import the StandardScaler class from the sklearn.preprocessing
>>> ## module. This step is the same as in the previous examples.
>>> from sklearn.preprocessing import StandardScaler
>>>
>>> ## We keep the same training data-set for comparison purposes.
>>> trn_ds = [[0, 4, 8],
...           [1, 5, 9],
...           [2, 6, 10],
...           [3, 7, 11]]
>>> 
>>> ## Test data-set is the same as before too.
>>> tst_ds = [[30, 34, 38],
...           [31, 35, 39],
...           [32, 36, 40],
...           [33, 37, 41]]
>>> 
>>> ## Initialize a StandardScaler object.
>>> ss = StandardScaler()
>>> 
>>> ## Call the fit_transform() routine on the training data-set.
>>> ## - The method first identifies and learns the Mean and Standard
>>> ##   Deviation of trn_ds.
>>> ## - Next, it centers and scales the training data.
>>> ## All this is done in a single call to fit_transform().
>>> print(ss.fit_transform(trn_ds))
[[-1.34164079 -1.34164079 -1.34164079]
 [-0.4472136  -0.4472136  -0.4472136 ]
 [ 0.4472136   0.4472136   0.4472136 ]
 [ 1.34164079  1.34164079  1.34164079]]
>>> 
>>> ## As before, Mean is a learned value that is stored as the
>>> ## attribute mean_ in the ss object.
>>> print(ss.mean_)
[1.5 5.5 9.5]
>>> 
>>> ## Variance is another learned value that is stored as the
>>> ## attribute var_ in the ss object.
>>> print(ss.var_)
[1.25 1.25 1.25]
>>> 
>>> ## Ok!! So far, so good!! Everything looks the same.
>>> ## The transformed trn_ds data-set is again centered
>>> ## around 0, i.e. it has zero mean. It has also been scaled to have
>>> ## unit variance (and therefore unit standard deviation).
>>> 
>>> ## Next, let's see how the test data-set is transformed. The result
>>> ## should be the same as in the previous example.
>>> print(ss.transform(tst_ds))
[[25.49117494 25.49117494 25.49117494]
 [26.38560213 26.38560213 26.38560213]
 [27.28002933 27.28002933 27.28002933]
 [28.17445652 28.17445652 28.17445652]]
>>> 
>>> ## Perfect!! So there!! fit_transform() puts the fit() and transform()
>>> ## steps together into one call. A step saved is valuable time earned!!

To reiterate the steps performed in this section:

  • Call the fit_transform() method once (on the training data-set only).
    • The fit_transform() method learns the underlying parameters from the training data-set *only*.
    • Next, it transforms the training data-set *only*. This is all done in one call, in one step!!
  • Call the transform() method on the test data-set *only*.
    • Note how the 3 separate calls (i.e. fit(), transform(on training data-set), transform(on test data-set)) were reduced to 2 calls (i.e. fit_transform(on training data-set), transform(on test data-set)).
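The best practice of calling fit_transform() on the training data-set only is easiest to appreciate by breaking it. In the sketch below, calling fit_transform() on the test data-set re-learns mean_ and var_ from that data, overwriting the training parameters. The result no longer matches the transform() output, and in this toy case it is identical to the scaled training data, which hides how far the test examples lie from anything seen during training. It assumes scikit-learn and NumPy are installed.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

trn_ds = np.array([[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], dtype=float)
tst_ds = np.array([[30, 34, 38], [31, 35, 39], [32, 36, 40], [33, 37, 41]], dtype=float)

ss = StandardScaler()
trn_scaled = ss.fit_transform(trn_ds)

# Recommended: reuse the training parameters for the test data.
tst_scaled = ss.transform(tst_ds)

# Anti-pattern: fit_transform() on the test data re-learns mean_ and var_
# from tst_ds, overwriting the training parameters stored in ss.
tst_refit = ss.fit_transform(tst_ds)

print(np.allclose(tst_scaled, tst_refit))  # False: the two results differ
print(np.allclose(tst_refit, trn_scaled))  # True: test data now looks like training data
print(ss.mean_)                            # [31.5 35.5 39.5], the test-set mean
```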

Conclusion

The scikit-learn community is quite active and continues to improve and update the library. As we saw above, fit_transform() reduces three separate calls (fit(), then transform() on the training data-set and transform() on the test data-set) to two!! This saves time and time is precious. Time is Money!!

Here is a parting thought!! As coders, we spend a lot of time researching and coding. It is easy to forget to nourish our Body *and* Soul. Ignoring the Body *and* the Soul can lead to all sorts of mental and physical illness. Illness could lead us to seek medical care, which is a serious loss of Time *and* Money. So invest in yourself, eat healthily and take frequent breaks to stretch or walk. After all, what is all this Money for, if one cannot enjoy it!!

Finxter Academy

This blog was brought to you by Girish Rao, a student of Finxter Academy. You can find his Upwork profile here.

Reference

All research for this blog article was done using Python Documents, the Google Search Engine and the shared knowledge-base of the Finxter Academy, scikit-learn and the Stack Overflow Communities.