Scikit-learn has a library of transformers to preprocess a data set. These transformers clean, generate, reduce or expand the feature representation of the data set. These transformers provide the `fit()`, `transform()` and `fit_transform()` methods.

- The `fit()` method identifies and learns the model parameters from a training data set. For example, the mean and standard deviation for standardization, or the minimum and maximum for scaling features to a given range.
- The `transform()` method applies the parameters learned by the `fit()` method. The `transform()` method transforms the training data and the test data (a.k.a. unseen data).
- The `fit_transform()` method first fits, then transforms the data set in a single call. It combines the `fit()` and `transform()` methods, and some transformers implement it more efficiently than two separate calls. As a best practice, `fit_transform()` is used only on the training data set.
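Here is a minimal sketch of why the last point matters. It assumes the `StandardScaler` transformer and a small toy data set; the variable names (`tst_scaled`, `tst_refit`) are made up for illustration. Refitting on the test data makes the scaler learn the test set's own mean and variance, so the fact that the test data sits far away from the training data is silently lost.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy training data: 4 examples (rows) x 3 features (columns).
trn_ds = np.array([[0, 4, 8],
                   [1, 5, 9],
                   [2, 6, 10],
                   [3, 7, 11]], dtype=float)

# Toy test data: the training data shifted by 30 in every feature.
tst_ds = trn_ds + 30

ss = StandardScaler()
trn_scaled = ss.fit_transform(trn_ds)   # learn parameters from training data only

# Best practice: reuse the parameters learned from the training data.
tst_scaled = ss.transform(tst_ds)

# Not recommended: a fresh fit on the test data learns different parameters.
tst_refit = StandardScaler().fit_transform(tst_ds)

# The refitted test data is indistinguishable from the scaled training data.
print(np.allclose(tst_refit, trn_scaled))

# The properly transformed test data keeps its offset from the training data.
print(np.allclose(tst_scaled, trn_scaled))
```

In this sketch the first comparison holds and the second does not, which is the behavior one wants a downstream model to see.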

**Note**: All the solutions provided below have been verified using Python 3.9.0b5

## Problem Formulation

What is the difference between the `fit()`, `transform()` and `fit_transform()` methods in scikit-learn transformer classes?

## Background

Scikit-learn is an open-source Machine Learning library. It supports supervised and unsupervised learning.

Scikit-learn provides excellent tools for model fitting, selection, and evaluation. It also provides a multitude of helpful utilities for data preprocessing and analysis. It is released under a commercially usable BSD license.

The developers of Scikit-learn work hard to keep the API uniform across the library. Scikit-learn provides a User Guide, many tutorials, and examples. Scikit-learn is an excellent resource for Pythonistas who wish to master Machine Learning.

## That is great!! But You Have Not Told Me Anything About fit(), transform() and fit_transform()

When implementing Machine Learning algorithms, one finds the need for preprocessing the data set. The preprocessing can take various forms such as

- Cleaning
- Centering
- Imputing
- Reduction
- Expansion
- Generation

The Scikit-learn library provides a multitude of classes called transformers for preprocessing. Most of these transformers share a common API. A common API provides simplicity and clarity to a given library. `fit()`, `transform()` and `fit_transform()` are common API methods for transformer classes. Let’s examine these methods one at a time.

## Ok Good!! First Tell Me About The fit() Method

In Machine Learning projects, the data is often split into training and test data-sets. The `fit()` method identifies and learns the model parameters only from the training data-set. For example, it identifies and learns the **mean** and **standard deviation** for standardization, or the minimum and maximum for scaling the features to a given range. The `fit()` method is best demonstrated by using it in an example. Let’s use the Centering preprocessing step on a data-set to show how `fit()` works.

Centering the data-set is one example of preprocessing. Here it means standardizing each feature so that it has zero mean and unit standard deviation. To do this, one first subtracts the feature's mean from each example. Next, one divides the result by the feature's standard deviation. To summarize: `z = (x - mean) / standard_deviation`.
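The two steps described above can be sketched by hand with NumPy. This is only an illustration, not how scikit-learn implements it internally; the names `mean`, `std` and `standardized` are chosen for this sketch. It uses the population standard deviation (`ddof=0`), which matches what `StandardScaler` computes.

```python
import numpy as np

# A 4x3 matrix: 4 examples (rows), 3 features (columns).
trn_ds = np.array([[0, 4, 8],
                   [1, 5, 9],
                   [2, 6, 10],
                   [3, 7, 11]], dtype=float)

# Step 0: compute the per-feature (column-wise) statistics.
mean = trn_ds.mean(axis=0)
std = trn_ds.std(axis=0)          # population standard deviation (ddof=0)

# Step 1: subtract the mean. Step 2: divide by the standard deviation.
standardized = (trn_ds - mean) / std

print(mean)                        # per-feature mean of the raw data
print(standardized.mean(axis=0))   # approximately 0 for every feature
print(standardized.std(axis=0))    # approximately 1 for every feature
```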

This calculation is easily handled using the `StandardScaler` class, a utility class from the **scikit-learn transformer library**.

**`StandardScaler().fit()`** learns the **mean** and **standard deviation** from the given training data set. It stores these learned parameters in the object. Let’s see how this works with an example…

```python
$ python
Python 3.9.0b5 (default, Oct 19 2020, 11:11:59)
>>>
>>> ## Import the StandardScaler class from the sklearn.preprocessing
>>> ## module. StandardScaler standardizes the features. It does this
>>> ## by removing the mean (i.e. centering the data-set around 0).
>>> ## It also scales to unit variance (i.e. standard deviation is 1.0).
>>> from sklearn.preprocessing import StandardScaler
>>>
>>> ## trn_ds is a 4x3 matrix. Each column is referred to as a feature.
>>> ## Each row is referred to as an example (or a sample).
>>> ## trn_ds is referred to as the training data-set. It is used to
>>> ## train the machine learning model.
>>> trn_ds = [[0, 4, 8],
...           [1, 5, 9],
...           [2, 6, 10],
...           [3, 7, 11]]
>>>
>>> ## Initialize a StandardScaler object.
>>> ss = StandardScaler()
>>>
>>> ## Call the fit() routine to identify and learn the mean and
>>> ## variance of the trn_ds.
>>> print(ss.fit(trn_ds))
StandardScaler()
>>>
>>> ## Mean is a learned value that is stored as the attribute mean_
>>> ## in the ss object.
>>> print(ss.mean_)
[1.5 5.5 9.5]
>>>
>>> ## Variance (the square of the standard deviation) is another
>>> ## learned value. It is stored as the attribute var_ in the ss
>>> ## object. The standard deviation itself is stored as scale_.
>>> print(ss.var_)
[1.25 1.25 1.25]
>>>
```

**Note**: The `fit()` method only examines the data-set, to learn (i.e. extract) the underlying parameters. For the `StandardScaler` class, these parameters are the mean and the variance (from which the standard deviation is derived). `fit()` methods of other transformer classes learn other parameters specific to those classes.
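As a small sketch of that last sentence, here is `fit()` on a different transformer, `MinMaxScaler`, which scales features to a given range (by default 0 to 1). Instead of a mean and a variance, it learns the per-feature minimum and maximum, stored in the `data_min_` and `data_max_` attributes. The variable name `mms` is just chosen for this illustration.

```python
from sklearn.preprocessing import MinMaxScaler

# Same style of training data-set: 4 examples (rows), 3 features (columns).
trn_ds = [[0, 4, 8],
          [1, 5, 9],
          [2, 6, 10],
          [3, 7, 11]]

# Default feature_range is (0, 1).
mms = MinMaxScaler()

# fit() only learns parameters; it does not change trn_ds.
mms.fit(trn_ds)

print(mms.data_min_)   # per-feature minimum seen during fit()
print(mms.data_max_)   # per-feature maximum seen during fit()
```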

## Ok Got That!! Now What About transform()?

Transforming the training and the test data-sets is the next preprocessing step. The `transform()` method uses the learned parameters from `fit()` to transform the data-sets. Let’s continue to work on the Centering example shown above. The learned parameters are already stored in the `StandardScaler` object. The `transform()` method is first used on the training data-set. The `transform()` method centers the data-set around 0 and scales it to have unit variance. The `transform()` method is also used on the test data-set, to center and scale it using the same learned parameters.

```python
$ python
Python 3.9.0b5 (default, Oct 19 2020, 11:11:59)
>>>
>>> ## Import the StandardScaler class from the sklearn.preprocessing
>>> ## module. This step is the same as in the previous example.
>>> from sklearn.preprocessing import StandardScaler
>>>
>>> ## We keep the same training data-set for comparison purposes.
>>> trn_ds = [[0, 4, 8],
...           [1, 5, 9],
...           [2, 6, 10],
...           [3, 7, 11]]
>>>
>>> ## Initialize a StandardScaler object.
>>> ss = StandardScaler()
>>>
>>> ## Call the fit() routine to identify and learn the mean and
>>> ## variance of the trn_ds.
>>> print(ss.fit(trn_ds))
StandardScaler()
>>>
>>> ## As before, Mean is a learned value that is stored as the
>>> ## attribute mean_ in the ss object.
>>> print(ss.mean_)
[1.5 5.5 9.5]
>>>
>>> ## Variance (the square of the standard deviation) is another
>>> ## learned value. It is stored as the attribute var_ in the ss object.
>>> print(ss.var_)
[1.25 1.25 1.25]
>>>
>>> ## Ok!! So far, so good!!. Next, transform the training data.
>>> print(ss.transform(trn_ds))
[[-1.34164079 -1.34164079 -1.34164079]
 [-0.4472136  -0.4472136  -0.4472136 ]
 [ 0.4472136   0.4472136   0.4472136 ]
 [ 1.34164079  1.34164079  1.34164079]]
>>>
>>> ## It worked!! The transformed trn_ds data-set is now centered
>>> ## around 0, i.e. has 0 mean. It has also been scaled to have unit
>>> ## variance (and hence unit standard deviation).
>>>
>>> ## Next, let's see how the test data-set is transformed. Note that
>>> ## the mean and std were calculated using *only* the trn_ds data-set.
>>> ## So the transform() function will center and scale this
>>> ## new unseen data (i.e. tst_ds) using the parameters learned from
>>> ## the trn_ds data-set.
>>> tst_ds = [[30, 34, 38],
...           [31, 35, 39],
...           [32, 36, 40],
...           [33, 37, 41]]
>>>
>>> print(ss.transform(tst_ds))
[[25.49117494 25.49117494 25.49117494]
 [26.38560213 26.38560213 26.38560213]
 [27.28002933 27.28002933 27.28002933]
 [28.17445652 28.17445652 28.17445652]]
>>>
```

To reiterate the steps performed so far,

- Call the `fit()` method once (on training data-set only).
  - The `fit()` method learned the underlying parameters from the training data-set *only*.
- Call the `transform()` method twice (once on training data-set, once on test data-set).
  - The `transform()` method first transformed the training data-set.
  - The `transform()` method also transformed the test data-set.
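The order of these steps matters. The sketch below, which is only an illustration with made-up variable names (`unfitted`, `raised`, `manual`), shows two things: calling `transform()` on a scaler that has not been fitted raises scikit-learn's `NotFittedError`, and after `fit()` the result of `transform()` on the test data-set can be reproduced by hand from the stored `mean_` and `scale_` attributes (`scale_` holds the standard deviation, i.e. the square root of `var_`).

```python
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler

trn_ds = np.array([[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], dtype=float)
tst_ds = np.array([[30, 34, 38], [31, 35, 39], [32, 36, 40], [33, 37, 41]],
                  dtype=float)

# 1. transform() needs learned parameters, so it fails before fit().
unfitted = StandardScaler()
try:
    unfitted.transform(tst_ds)
    raised = False
except NotFittedError:
    raised = True
print("transform() before fit() raised NotFittedError:", raised)

# 2. After fit(), transform() applies the parameters learned from trn_ds.
ss = StandardScaler().fit(trn_ds)

# Reproduce transform() by hand with the stored training parameters.
manual = (tst_ds - ss.mean_) / ss.scale_
print(np.allclose(manual, ss.transform(tst_ds)))
```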

## Ah! OK!! So What Does fit_transform() Do?

The developers of scikit-learn are always thinking of ways to optimize the library. The `fit()` and the `transform()` methods are ***always*** applied on the training data-set. So why not offer a `fit_transform()` method that does both, they thought. `fit_transform()` combines `fit()` and `transform()` into a single call, which some transformers implement more efficiently than two separate calls. `fit_transform()` is called **only** for the training data-set. Let’s see how this works for the data-sets used in the above example. The final result should be exactly the same for both use cases.

```python
$ python
Python 3.9.0b5 (default, Oct 19 2020, 11:11:59)
>>>
>>> ## Import the StandardScaler class from the sklearn.preprocessing
>>> ## module. This step is the same as in the previous examples.
>>> from sklearn.preprocessing import StandardScaler
>>>
>>> ## We keep the same training data-set for comparison purposes.
>>> trn_ds = [[0, 4, 8],
...           [1, 5, 9],
...           [2, 6, 10],
...           [3, 7, 11]]
>>>
>>> ## Test data-set is the same as before too.
>>> tst_ds = [[30, 34, 38],
...           [31, 35, 39],
...           [32, 36, 40],
...           [33, 37, 41]]
>>>
>>> ## Initialize a StandardScaler object.
>>> ss = StandardScaler()
>>>
>>> ## Call the fit_transform() routine on the training data-set.
>>> ## - The method first identifies and learns the mean and
>>> ##   variance of the trn_ds.
>>> ## - Next it centers and scales the training data.
>>> ## All this is done in one step, by using the
>>> ## fit_transform() method.
>>> print(ss.fit_transform(trn_ds))
[[-1.34164079 -1.34164079 -1.34164079]
 [-0.4472136  -0.4472136  -0.4472136 ]
 [ 0.4472136   0.4472136   0.4472136 ]
 [ 1.34164079  1.34164079  1.34164079]]
>>>
>>> ## As before, Mean is a learned value that is stored as the
>>> ## attribute mean_ in the ss object.
>>> print(ss.mean_)
[1.5 5.5 9.5]
>>>
>>> ## Variance (the square of the standard deviation) is another
>>> ## learned value. It is stored as the attribute var_ in the ss object.
>>> print(ss.var_)
[1.25 1.25 1.25]
>>>
>>> ## Ok!! So far, so good!!. Everything looks to be the same.
>>> ## The transformed trn_ds data-set continues to be centered
>>> ## around 0, i.e. has 0 mean. It has also been scaled to have unit
>>> ## variance (and hence unit standard deviation).
>>>
>>> ## Next, let's see how the test data-set is transformed. The result
>>> ## should be the same as in the previous example.
>>> print(ss.transform(tst_ds))
[[25.49117494 25.49117494 25.49117494]
 [26.38560213 26.38560213 26.38560213]
 [27.28002933 27.28002933 27.28002933]
 [28.17445652 28.17445652 28.17445652]]
>>>
>>> ## Perfect!! So there!! fit_transform() is the fit() and transform()
>>> ## steps put together into one function. A step saved
>>> ## is valuable time earned!!
```

To reiterate the steps performed in this section,

- Call the `fit_transform()` method once (on training data-set only).
  - The `fit_transform()` method learned the underlying parameters from the training data-set *only*.
  - Next, it transformed the training data-set *only*. This is all done in one call, in one step!!
- Call the `transform()` method on the test data-set *only*.
  - Note how the 3 separate calls (i.e. `fit()`, `transform(on training data-set)`, `transform(on test data-set)`) got reduced to 2 calls (i.e. `fit_transform(on training data-set)`, `transform(on test data-set)`).
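The equivalence of the two workflows can also be checked programmatically rather than by eye. This is a minimal sketch, assuming two independent `StandardScaler` objects; the names `two_step` and `one_step` are made up for this illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

trn_ds = [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]
tst_ds = [[30, 34, 38], [31, 35, 39], [32, 36, 40], [33, 37, 41]]

# Workflow 1: three calls (fit, transform on training, transform on test).
two_step = StandardScaler()
two_step.fit(trn_ds)
trn_two_step = two_step.transform(trn_ds)
tst_two_step = two_step.transform(tst_ds)

# Workflow 2: two calls (fit_transform on training, transform on test).
one_step = StandardScaler()
trn_one_step = one_step.fit_transform(trn_ds)
tst_one_step = one_step.transform(tst_ds)

# Both workflows learn the same parameters and give the same results.
same_train = np.allclose(trn_two_step, trn_one_step)
same_test = np.allclose(tst_two_step, tst_one_step)
print(same_train, same_test)
```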

## Conclusion

The **scikit-learn** community is quite active about optimizing the library. They continue to improve and update the library. As we saw above, three separate pre-processing steps are now done in two steps!! This saves time and time is precious. **Time is Money!!**

Here is a parting thought!! As coders, we spend a lot of time researching and coding. It is easy to forget to nourish our Body *and* Soul. Ignoring the Body *and* the Soul can lead to all sorts of mental and physical illness. Illness could lead us to seek medical care. That is a serious loss of Time *and* Money. So invest in yourself, eat healthily and take frequent breaks to stretch or walk. After all, what is all this Money for, if one cannot enjoy it!!

## Finxter Academy

This blog was brought to you by **Girish Rao**, a student of Finxter Academy. You can find his Upwork profile here.

## Reference

All research for this blog article was done using Python Documents, the Google Search Engine and the shared knowledge-base of the Finxter Academy, scikit-learn and the Stack Overflow Communities.