Predicting Churn for Users of a Music Application

Parsai Anu
10 min read · Nov 23, 2020

Project Overview

This project aims to predict churn for users of a fictitious music application called Sparkify. The problem statement and the dataset are provided by Udacity. The dataset consists of logs of actions performed by users of the application: each action is captured across 18 columns, covering 226 users. We want to use this data to build an ML model that predicts whether a user will churn, using these attributes directly or through some transformation.

The log data comes in three sizes, to suit the data-handling capability of the platform used. We used the smallest size of the data.

Problem Statement

The log dataset contains 286,500 rows and 18 columns for 226 users. The columns capture user information (e.g., gender, subscription level), session details (e.g., session ID, item number in session) and action details (e.g., the page visited, the artist, the song and its length).

We first want to define what churn means. Then we want to analyze these attributes, and combinations of them, at the user and overall level to identify key features that may cause users to churn. After feature selection, we aim to build multiple ML models and select the one that performs best on the test and validation sets based on the metric given below.

This solution will help Sparkify identify the major reasons users churn, as well as users at risk of churning, so that it can target them with personalized offers and retain them.

Metric

Since churn prediction is a binary 0/1 classification problem, we will use classification models, and we have chosen the F1 score as our metric. Only 47 of 226 users are tagged as churned, which makes the dataset imbalanced. A model may therefore predict few churned users (few positives) and many negatives, and this will not be clearly reflected in accuracy; hence we chose the F1 score, which accounts for false positives as well as false negatives.
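For reference, F1 is the harmonic mean of precision and recall; the sketch below (plain Python, with illustrative counts) shows why a model that mostly predicts the majority class can have high accuracy but a poor F1:

```python
# F1 balances precision (penalized by false positives) and
# recall (penalized by false negatives).
def f1_score(tp: int, fp: int, fn: int) -> float:
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Illustrative: a model that finds only 1 of 4 churned users.
print(f1_score(tp=1, fp=1, fn=3))  # ~0.33, even if overall accuracy is high
```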

Data Exploration & Data Visualization

The dataset contains the log of user actions on a music platform, stored in 18 columns across 286,500 rows.

Null Data Identification

We first wanted to understand the number of nulls in the dataset. We found 8,346 of the 286,500 rows with nulls in the user-info columns. These rows were recorded when a user was browsing as a Guest or was logged out of the application.

After removing these rows, there were still ~50,046 rows with nulls in the artist-related columns. On further investigation we found that the artist information is null when the user is engaged in an action not related to listening to songs, such as Error, Help or Add Friend. We decided to keep these rows, since they represent valid actions rather than missing user data.
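A minimal sketch of these null checks, assuming the log is loaded into a Spark DataFrame `df` with the Sparkify column names (userId, artist, page); missing user info may appear as nulls or empty strings depending on how the JSON is read:

```python
from pyspark.sql import functions as F

# Rows with no user info: recorded for Guest / logged-out events.
df.filter(F.col("userId").isNull() | (F.col("userId") == "")).count()  # 8346

# Artist is null on non-listening pages (Help, Error, Add Friend, ...),
# so these rows are valid actions and are kept.
df.filter(F.col("artist").isNull()).select("page").distinct().show()
```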

Outlier Identification

For outlier treatment, we analyzed how many actions (itemInSession) a user performs in one session, and the average duration of each action. On average a user performs 108 actions per session; we decided to remove rows that were more than one standard deviation above this mean.
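A sketch of how the per-session action counts can be derived (assuming the columns userId, sessionId and itemInSession):

```python
from pyspark.sql import functions as F

# The highest itemInSession per (user, session) approximates the number
# of actions performed in that session.
session_len = (df.groupBy("userId", "sessionId")
                 .agg(F.max("itemInSession").alias("n_items")))

stats = session_len.agg(F.mean("n_items").alias("mu"),
                        F.stddev("n_items").alias("sigma")).first()
threshold = stats["mu"] + stats["sigma"]  # ~231 for this dataset (mean ~108)
```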

Statistics about the Data

The first task was to define churn. We defined churned users as those who clicked on the Cancellation Confirmation page; 47 of 225 users churned under this definition. With churn defined, we analyzed the variables for churned and non-churned users separately.
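A sketch of the churn flag (collecting the user IDs with collect() is fine at this data size):

```python
from pyspark.sql import functions as F

# Users who ever reached the Cancellation Confirmation page are churned.
churned = (df.filter(F.col("page") == "Cancellation Confirmation")
             .select("userId").distinct().collect())
churned_ids = [row.userId for row in churned]

df = df.withColumn("churn",
                   F.when(F.col("userId").isin(churned_ids), 1).otherwise(0))
```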

The first analysis was to see whether the distribution of churn differs between the male and female populations. It turned out that men have a churn rate of 24%, as opposed to 16% for women [refer to Fig 1 below]. Therefore, the gender of a user may be a predictor of churn.

Fig 1: Gender Distribution of Churn/No-Churn Users

The second feature we analyzed was the distribution of actions performed by churn/no-churn users. It turns out that users who stayed clicked more on pages such as ‘Next Song’, ‘Add Friend’ and ‘Add to Playlist’ than users who churned. This makes sense, as these actions indicate how much a user likes the application: a user who enjoys the experience is more likely to play more songs, like more songs and refer the application to friends. [Refer to Fig 2]

Fig 2: Distribution of user actions by churn/no churn
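One way to reproduce this comparison is to normalize page counts within each group, so the churned and retained populations are comparable despite their different sizes (a sketch, assuming the churn column from above):

```python
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Share of each page within the churned vs. retained groups.
page_share = (df.groupBy("churn", "page").count()
                .withColumn("share", F.col("count") /
                            F.sum("count").over(Window.partitionBy("churn")))
                .orderBy("churn", F.desc("share")))
page_share.show(10)
```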

Another observation was about the mean time a user spends in a session. We found that users who churned spent around 175 minutes per session, compared to roughly 225 minutes for users who stayed. Also, paid subscribers had a churn rate of 18%, versus 21% for users on the free tier. [Refer to Fig 3]

Fig 3: Distribution of free/paid users across churn/no churn

Data Pre-processing

Null Treatment

We decided to remove the 8,346 rows with null user information, since user-related information is critical to our model: we build models at the user level. For the ~50,046 rows with nulls in the artist columns, the artist information is null when the user is engaged in an action not related to listening to songs (Error, Help, Add Friend, etc.), so we kept these rows; they represent valid actions, not missing user data.

Outlier Treatment

We removed rows with an itemInSession number above 231, using the mean + 1 standard deviation criterion, because such rows are anomalous: a user can only perform a limited number of actions in one session before it expires.
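In code, the filter is a one-liner once the threshold is known (231 per the criterion above):

```python
from pyspark.sql import functions as F

# Drop anomalous rows: more than mean + 1 std (~231) actions in a session.
df = df.filter(F.col("itemInSession") <= 231)
```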

Feature Engineering

The data exploration steps were key in identifying the features that needed to be created for building the model. As the model has to predict churn at the user level, all the features had to be aggregated or created at the user level. The exploration helped us identify the features below for model building (a sketch of the user-level aggregation follows the list):

1. Gender of the user
2. Mean time spent by a user in a session
3. Number of artists a user listened to
4. Mean time spent by a user on an item in a session
5. No. of times a user clicks on Next Song
6. No. of times a user clicks on Add to Playlist
7. No. of times a user clicks on Thumbs Up
8. No. of times a user clicks on Thumbs Down
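A sketch of the user-level aggregation, assuming the page labels as they appear in the log (e.g., 'NextSong', 'Thumbs Up'); the two session-time features need a per-session aggregate first, omitted here for brevity:

```python
from pyspark.sql import functions as F

# One row per user: the label plus the engineered features.
user_features = df.groupBy("userId").agg(
    F.first("gender").alias("gender"),
    F.countDistinct("artist").alias("n_artists"),
    F.sum((F.col("page") == "NextSong").cast("int")).alias("n_next_song"),
    F.sum((F.col("page") == "Add to Playlist").cast("int")).alias("n_add_playlist"),
    F.sum((F.col("page") == "Thumbs Up").cast("int")).alias("n_thumbs_up"),
    F.sum((F.col("page") == "Thumbs Down").cast("int")).alias("n_thumbs_down"),
    F.max("churn").alias("label"))
```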

Implementation

We split the data into an 80:10:10 train/test/validation split. We used two types of model for our problem: logistic regression and random forest.

The model is trained on 80% of the data, then checked on the 10% test set to select the model with the best F1 score. Since both train and test influence model selection, we kept 10% aside as a validation set to see how well the model performs on unseen data.
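A sketch of the split (the seed is arbitrary but fixed for reproducibility; randomSplit gives approximate proportions):

```python
# 80/10/10 split of the user-level feature table.
train, test, valid = user_features.randomSplit([0.8, 0.1, 0.1], seed=42)
```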

We used two metrics to evaluate our models: accuracy and F1 score. We ultimately used the F1 score to choose between models, because it is a more robust metric that also accounts for false negatives.

Refinement

We used maxIter and regParam in hyperparameter tuning. The original snippet showed the range of parameters chosen; an illustrative sketch follows:
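The original snippet is an image in the source post, so the grid values below are illustrative only; the structure is a standard Spark ML ParamGridBuilder with 3-fold cross-validation (the feature vector is assumed to be assembled into a 'features' column):

```python
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

lr = LogisticRegression(featuresCol="features", labelCol="label")

# Illustrative ranges; the exact grid from the original post is not shown here.
lr_grid = (ParamGridBuilder()
           .addGrid(lr.maxIter, [10, 50, 100])
           .addGrid(lr.regParam, [0.0, 0.01, 0.1])
           .build())

lr_cv = CrossValidator(estimator=lr,
                       estimatorParamMaps=lr_grid,
                       evaluator=MulticlassClassificationEvaluator(metricName="f1"),
                       numFolds=3)
lr_cv_model = lr_cv.fit(train)
```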

regParam controls the regularization strength: the higher the value, the stronger the penalty on large coefficients and the less the model overfits. maxIter caps the number of iterations the optimizer runs to converge.

Based on 3-fold cross-validation, the best parameters were as follows:

The logistic regression model after hyperparameter tuning gave the following results:

Cross-validation selected the configuration with the lowest value of the regularization parameter. We then decided to also train a random forest model, since ensembles of trees are well suited to avoiding overfitting.

First we tried a random forest with default parameters and got the following results, almost the same as the logistic regression results:

Hence, we decided to tune the numTrees parameter, using 3-fold cross-validation to find the best value. We tried a range from a very small number of trees up to the commonly used 100:
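This grid is stated later in the post as [5, 10, 50, 100] with 3-fold cross-validation; in Spark ML it looks like this:

```python
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

rf = RandomForestClassifier(featuresCol="features", labelCol="label")

# numTrees range per the text: [5, 10, 50, 100].
rf_grid = ParamGridBuilder().addGrid(rf.numTrees, [5, 10, 50, 100]).build()

rf_cv = CrossValidator(estimator=rf,
                       estimatorParamMaps=rf_grid,
                       evaluator=MulticlassClassificationEvaluator(metricName="f1"),
                       numFolds=3)
rf_cv_model = rf_cv.fit(train)
```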

The best parameter for the random forest is given below (numTrees = 5, per the cross-validation):

The results for the random forest after hyperparameter tuning are given below:

Justification

The reason for using logistic regression is that it is the simplest and most widely used algorithm, and it quickly establishes a benchmark. However, logistic regression assumes a linear relationship and does not use ensemble learning, so it is more prone to overfitting.

The reason for choosing random forest as the second model is that it can explore complex relationships between variables, and each tree is trained on a sample of the data, which helps ensure the model does not overfit the training set.

Results

Model Evaluation and Validation

After hyperparameter tuning of both LR and RF, we found that the RF model had the best F1 score on the test set. The numTrees parameter of the RF model was cross-validated over [5, 10, 50, 100] using 3-fold cross-validation, and numTrees = 5 turned out to be the best value. The following are the best results obtained on the test set:

The actual versus predicted counts for the 10% test set are shown below:

The precision of the model is 50% (1/(1+1)) and the recall is 25% (1/(1+3)); with the default beta, we get an F1 of 0.87. The F1 score gives weight to false negatives as well: 3 churned users were predicted as no-churn, hence the low recall of 25%. It is not just about correctly predicting churn, but also about ensuring we do not incorrectly label churned users as no-churn; otherwise we will not target those users for retention, believing the model's prediction that they will stay, when in reality they will churn.
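Note that the per-class F1 for the churn class implied by these counts is about 0.33; the reported 0.87 is consistent with a weighted F1 across both classes (the default of Spark's MulticlassClassificationEvaluator, which is presumably what was used), where the large no-churn class dominates:

```python
# Churn-class metrics from the confusion counts above (TP=1, FP=1, FN=3).
precision = 1 / (1 + 1)                                    # 0.50
recall = 1 / (1 + 3)                                       # 0.25
f1_churn = 2 * precision * recall / (precision + recall)   # ~0.33

# 0.87 matches a class-weighted average, dominated by the no-churn class.
```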

The parameters were chosen using 3-fold cross-validation on the training set, hence we expect them to perform consistently out of sample as well. We checked a good range of parameters, from 5 to 100 trees, to find the optimum.

Another way we ensured robustness was by keeping a 10% validation set aside, which was nowhere involved in the model-selection process.

We then checked the performance of our random forest model on the 10% validation set.

Results on the 10% Validation (Out-of-Sample) Set

The actual versus predicted counts for the validation set are:

The model performed consistently on the 10% validation set, with an F1 of 0.86, in line with the test-set F1 of 0.87.

The precision on the validation set is 100% (1/(1+0)) and the recall is 25% (1/(1+3)); with the default beta, we get an F1 of 0.86.

This gives us confidence that the model will hold up on other samples as well.

We also looked at the feature importances to understand which features the random forest picked as most important. No. of Thumbs Up, No. of Thumbs Down and Add Friend were the top 3 factors, which makes sense: these actions indicate that the user likes the application and is interested in improving the suggestions it makes.
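A sketch of how the importances can be read off the fitted model (the feature name list is hypothetical and must mirror the order of the assembled feature vector):

```python
# Hypothetical: names in the same order as the assembled feature vector.
feature_names = ["gender", "mean_session_time", "n_artists", "mean_item_time",
                 "n_next_song", "n_add_playlist", "n_thumbs_up", "n_thumbs_down"]

best_rf = rf_cv_model.bestModel  # fitted model from the cross-validation above
for name, score in sorted(zip(feature_names, best_rf.featureImportances.toArray()),
                          key=lambda kv: -kv[1]):
    print(f"{name}: {score:.3f}")
```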

Justification

The reason we chose RF over LR is its better F1 score. The LR model actually has high accuracy because it slightly overfits the training data, which means it easily captures the no-churn cases, since there are only 47 churns in the original dataset. The RF has accuracy similar to LR but a better F1, because it captures the churns better than LR at the expense of slightly fewer correct no-churns; that is acceptable, as the accuracy is almost the same for both.

Conclusion

Summary & Reflection

We used the user log of a music application to successfully predict churn/no-churn users. We did a thorough data exploration to find anomalies in the dataset and treated them. We also analyzed each variable against the churn variable to see how strongly it relates to churn. Feature engineering was performed to create new features that looked significant in the exploration phase and were also intuitive; in total, 8 new features were created. Two models, logistic regression and random forest, were applied to the dataset with parameter tuning via 3-fold cross-validation on the 80% training set. The models were evaluated on the 10% test set and the one with the best F1 score, the random forest, was chosen. The robustness of the model was checked on the 10% validation set, which showed results consistent with the test set.

The most interesting part of the project was using a Big Data platform to implement ML models. As far as challenges are concerned, inappropriate use (or non-use) of optimizations such as persist() made the program run for a long time.
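For instance, caching a DataFrame that feeds several aggregations avoids recomputing its whole lineage on every action (a minimal sketch):

```python
# Cache df once; subsequent actions reuse the materialized data.
df.persist()
n_users = df.select("userId").distinct().count()  # first action fills the cache
n_rows = df.count()                               # served from the cache
df.unpersist()
```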

Improvement

The current project is built on the small dataset; in the future it could be run on the large dataset, which would allow better model tuning as more data becomes available, and we could also try data-hungry models such as neural networks, which are known for better accuracy.
