Exploring the Titanic Dataset with TensorFlow Estimators

💡 Problem Formulation: How can TensorFlow, a powerful machine learning library, be used with Estimators to analyze historical data such as the Titanic dataset? This article shows how to use Python to predict whether a passenger survived based on features such as age, class, and sex. For example, given a passenger’s information, the model should output a survival prediction (1 for survived, 0 for did not survive).

Method 1: Data Preprocessing with tf.feature_column

An essential step in machine learning is converting the raw data into a format the model can use. TensorFlow provides the tf.feature_column module, which defines how the model should treat each feature. Feature columns can normalize numeric inputs (through the normalizer_fn argument of numeric_column), bucketize continuous values, encode categorical variables, and supply default values for missing inputs (through the default_value argument).

Here’s an example:

import tensorflow as tf

# Numeric column for passenger age
age = tf.feature_column.numeric_column('age')

# Categorical columns with a fixed vocabulary
gender = tf.feature_column.categorical_column_with_vocabulary_list('sex', ['male', 'female'])
# `class` is a reserved word in Python, so the variable needs a different name
pclass = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])

# Split age into ranges: (-inf, 20), [20, 30), ..., [80, +inf)
age_buckets = tf.feature_column.bucketized_column(age, boundaries=[20, 30, 40, 50, 60, 70, 80])

# Convert categorical columns to indicator (one-hot) columns
gender_one_hot = tf.feature_column.indicator_column(gender)
class_one_hot = tf.feature_column.indicator_column(pclass)

Feature columns do not produce tensors on their own: they describe transformations that the estimator applies to the raw input features, producing the processed tensors that are fed to the model.

This code defines feature columns with TensorFlow’s API: it groups the continuous age values into buckets and converts the categorical variables into one-hot encodings. This step turns the raw data into a numeric form that the models in the following methods can consume.
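To make the bucketing and one-hot steps concrete, the sketch below reproduces in plain Python what bucketized_column and indicator_column compute for a single passenger. It illustrates the semantics only and is not TensorFlow's implementation; the helper names (bucket_index, one_hot, encode_passenger) are hypothetical. With the seven boundaries above, bucketized_column creates eight half-open buckets, and a value equal to a boundary falls into the bucket that starts at that boundary.

```python
from bisect import bisect_right

# Same boundaries as the age_buckets column above
AGE_BOUNDARIES = [20, 30, 40, 50, 60, 70, 80]

def bucket_index(value, boundaries):
    """Index of the half-open bucket [boundaries[i-1], boundaries[i]) containing value."""
    # bisect_right sends a value equal to a boundary into the bucket starting there
    return bisect_right(boundaries, value)

def one_hot(index, size):
    """A list of `size` zeros with a 1 at position `index`."""
    vec = [0] * size
    vec[index] = 1
    return vec

def encode_passenger(age, sex, pclass):
    """Hypothetical helper: concatenate the age-bucket, sex, and class one-hot vectors."""
    age_vec = one_hot(bucket_index(age, AGE_BOUNDARIES), len(AGE_BOUNDARIES) + 1)
    sex_vec = one_hot(['male', 'female'].index(sex), 2)
    class_vec = one_hot(['First', 'Second', 'Third'].index(pclass), 3)
    return age_vec + sex_vec + class_vec

print(encode_passenger(30, 'female', 'Third'))
# → [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1]
```

Age 30 lands in bucket [30, 40), the third of eight buckets, so each passenger becomes a 13-element vector (8 age buckets + 2 sexes + 3 classes).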

Method 2: Building a Linear Model with TensorFlow Estimators

After preprocessing the data, the next step is to create a predictive model. TensorFlow Estimators provide an easy-to-use abstraction for creating models. The tf.estimator.LinearClassifier is suitable for binary classification tasks, such as predicting survival on the Titanic.

Here’s an example:

import tensorflow as tf

# feature_columns reuses the columns defined in Method 1
feature_columns = [age_buckets, gender_one_hot, class_one_hot]

# Build a linear estimator (logistic regression with the default n_classes=2)
linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns)

# Train the model on the Titanic data
# Assume train_input_fn is an input function that yields (features, labels) batches
linear_est.train(input_fn=train_input_fn)

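train() returns the estimator itself; the learned weights are stored as checkpoints in the estimator’s model_dir (a temporary directory if none is specified).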

This simple example creates a linear classifier with the prepared feature columns and then trains it on the Titanic training data through an input function. The trained model can then be used to predict each passenger’s probability of survival.
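For two classes, a LinearClassifier is a logistic regression over the transformed features: it adds a learned weight for every active feature to a bias, and passes that sum (the logit) through the sigmoid function to get the probability of survival. The plain-Python sketch below shows that arithmetic for one passenger; the feature layout, weights, and bias are made up for illustration, not learned from the data.

```python
import math

def sigmoid(z):
    """Logistic function mapping a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def survival_probability(features, weights, bias):
    """Logit = bias + dot(weights, features); probability = sigmoid(logit)."""
    logit = bias + sum(w * x for w, x in zip(weights, features))
    return sigmoid(logit)

# Hypothetical one-hot layout: [male, female, First, Second, Third]
passenger = [0, 1, 0, 0, 1]          # a female passenger in third class
# Made-up weights and bias, for illustration only
weights = [-1.2, 1.3, 0.9, 0.1, -0.8]
bias = 0.2

p = survival_probability(passenger, weights, bias)
predicted_class = int(p > 0.5)       # class 1 when the logit is positive
print(round(p, 3), predicted_class)
# → 0.668 1
```

Only the active one-hot entries contribute, so the logit here is 0.2 + 1.3 - 0.8 = 0.7; training adjusts the weights and bias so that these probabilities match the survival labels.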

Method 3: Evaluating Model Performance

Once the model is trained, evaluating it on a held-out test set is crucial for estimating how well it generalizes. TensorFlow Estimators make this easy with the evaluate method, which returns a dictionary of metrics for the test data, such as accuracy, precision, recall, and AUC.

Here’s an example:

# Assume eval_input_fn is the input function for the test data
evaluation_results = linear_est.evaluate(input_fn=eval_input_fn)
print(evaluation_results['accuracy'])

The output will be a dictionary with evaluation metrics for the model.

In this snippet, we use the trained linear estimator to evaluate its performance on the test data. The resulting metrics are useful for judging the model’s effectiveness and can guide further tuning or selection of models.
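The accuracy, precision, and recall in the results dictionary follow the standard definitions for binary classification. As an aid to reading those numbers, the plain-Python sketch below computes them from hypothetical labels and predicted classes; it is not TensorFlow's implementation, and the eight-passenger data is invented.

```python
def binary_metrics(labels, predictions):
    """Accuracy, precision, and recall for 0/1 labels and 0/1 predicted classes."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)   # survivors predicted to survive
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)   # non-survivors predicted to survive
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)   # survivors predicted not to survive
    correct = sum(1 for y, p in pairs if y == p)
    return {
        'accuracy': correct / len(pairs),
        # Guard against division by zero when there are no predicted/actual positives
        'precision': tp / (tp + fp) if tp + fp else 0.0,
        'recall': tp / (tp + fn) if tp + fn else 0.0,
    }

# Hypothetical survival labels and predicted classes for eight passengers
labels      = [1, 0, 0, 1, 1, 0, 0, 1]
predictions = [1, 0, 1, 1, 0, 0, 0, 0]
print(binary_metrics(labels, predictions))
```

Here 5 of 8 predictions are correct (accuracy 0.625), 2 of the 3 predicted survivors really survived (precision about 0.667), and 2 of the 4 actual survivors were found (recall 0.5), which shows how a model can look reasonable on accuracy while still missing many survivors.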

Method 4: Predicting Outcomes with the Model

The final step in the workflow is to use the trained model to make predictions. TensorFlow Estimators simplify this with the predict method, which takes an input function and returns a generator that yields, for each passenger, a dictionary containing the predicted class ('class_ids') and the class probabilities ('probabilities').

Here’s an example:

# Assume pred_input_fn is the input function for new data
predictions = linear_est.predict(pred_input_fn)

for pred in predictions:
    print(pred['class_ids'][0])  # predicted class
    print(pred['probabilities'])  # class probabilities

The predicted class and its probabilities will be displayed for each passenger.

This code demonstrates how to make predictions with a trained estimator based on new data. It’s useful for applying the model to make real-world predictions or to validate the model against unseen data.
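Because predict returns a generator, its results can be consumed only once, so it is often convenient to collect them into a list. The helper below is a hypothetical sketch that gathers the predicted class and the survival probability (index 1 of 'probabilities' for a binary classifier) for each passenger. To keep it testable without a trained model, it is exercised here with hand-written dictionaries that mimic the two keys used above; the values are invented.

```python
def summarize_predictions(predictions):
    """Collect (predicted class, survival probability) pairs from predict() output.

    `predictions` is any iterable of dicts with 'class_ids' and 'probabilities'
    keys, such as the generator returned by linear_est.predict(pred_input_fn).
    """
    results = []
    for pred in predictions:
        class_id = int(pred['class_ids'][0])
        # For a binary classifier, index 1 holds P(survived)
        survival_prob = float(pred['probabilities'][1])
        results.append((class_id, survival_prob))
    return results

# Hand-written stand-ins for two prediction dictionaries (hypothetical values)
fake_predictions = iter([
    {'class_ids': [1], 'probabilities': [0.2, 0.8]},
    {'class_ids': [0], 'probabilities': [0.7, 0.3]},
])
for class_id, prob in summarize_predictions(fake_predictions):
    print(f'class={class_id}  P(survived)={prob:.2f}')
```

With a real estimator, the same function can be called as summarize_predictions(linear_est.predict(pred_input_fn)), where the int() and float() calls convert the NumPy values in each dictionary to plain Python numbers.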

Bonus One-Liner Method 5: Simplified Data Input with TensorFlow’s Dataset API

TensorFlow’s tf.data API is an efficient way to feed data into estimators, and a dataset can be built from a DataFrame in a single line.

Here’s an example:

import tensorflow as tf

# Assume df is a pandas DataFrame with the Titanic data and a 'survived' label column
dataset = tf.data.Dataset.from_tensor_slices((dict(df.drop(columns='survived')), df['survived']))

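This creates a TensorFlow Dataset object that can be returned from an Estimator’s input function; it is usually shuffled and batched there first, because Estimators consume batches of examples.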

This line generates a TensorFlow Dataset from the Titanic DataFrame, linking features with the target variable, which greatly simplifies the input process for model training and evaluation.
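tf.data.Dataset.from_tensor_slices slices its input along the first dimension, so the (feature dictionary, labels) pair becomes a stream of single-passenger elements, and Dataset.batch later stacks consecutive elements back into column-wise batches. The plain-Python sketch below mimics those two steps on a hypothetical three-row table to show the element structure an input function hands to the estimator. It is an illustration only, not the tf.data implementation, and the function names are hypothetical.

```python
def slice_rows(features, labels):
    """Mimic from_tensor_slices((features, labels)): yield one (row_dict, label) per row."""
    names = list(features)
    for i, label in enumerate(labels):
        yield {name: features[name][i] for name in names}, label

def stack_rows(rows):
    """Turn a list of (row_dict, label) pairs back into (column_dict, label_list)."""
    names = rows[0][0].keys()
    columns = {name: [row[name] for row, _ in rows] for name in names}
    return columns, [label for _, label in rows]

def batch(elements, batch_size):
    """Mimic Dataset.batch: group consecutive elements into stacked batches."""
    rows = []
    for element in elements:
        rows.append(element)
        if len(rows) == batch_size:
            yield stack_rows(rows)
            rows = []
    if rows:  # like drop_remainder=False, keep the smaller final batch
        yield stack_rows(rows)

# Hypothetical three-passenger table, stored column-wise like dict(df)
features = {'age': [22.0, 38.0, 26.0], 'sex': ['male', 'female', 'female']}
labels = [0, 1, 1]

for features_batch, labels_batch in batch(slice_rows(features, labels), batch_size=2):
    print(features_batch, labels_batch)
```

The first batch holds the first two passengers and the second holds the remaining one, matching the ({column: values}, labels) structure that the feature columns from Method 1 read by name.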

Summary/Discussion

  • Method 1: Data Preprocessing with tf.feature_column. Strengths: This method allows for detailed customization of data preprocessing, which is essential for accurate model training. Weaknesses: It requires a good understanding of the data and preprocessing techniques.
  • Method 2: Building a Linear Model with TensorFlow Estimators. Strengths: Estimators provide a high-level API for model construction, making it quick and easy to build and train models. Weaknesses: It may not offer the same level of customization as lower-level TensorFlow APIs.
  • Method 3: Evaluating Model Performance. Strengths: Provides key insights into model accuracy and generalizability. Weaknesses: Evaluation only indicates performance on the current test set and may not reflect real-world performance.
  • Method 4: Predicting Outcomes with the Model. Strengths: Enables the application of the trained model to make predictions on new data, which is the end goal of most machine learning projects. Weaknesses: Prediction accuracy depends heavily on the quality and representativeness of the training data.
  • Method 5: Simplified Data Input with TensorFlow’s Dataset API. Strengths: Increases efficiency and simplicity of feeding data into the model. Weaknesses: Assumes the dataset is already cleaned and may require additional steps for complex preprocessing.