Seaborn

Categorical Data Plots

This notebook will give you a detailed overview of how to perform data visualization using the powerful Seaborn module in Python.

To install Seaborn, run pip install seaborn or conda install seaborn in a terminal window.

Let's first look at plotting categorical data. A categorical column takes one of a fixed set of levels or categories (for example, the sex column has two distinct values: Male and Female). The main plot types for this kind of data are:

  • barplot
  • countplot
  • boxplot
  • violinplot
  • stripplot
  • swarmplot
  • factorplot

Let's go through examples of each!

In [2]:
import seaborn as sns     # import the module 
%matplotlib inline
In [3]:
tips = sns.load_dataset('tips') # the tips data ships with seaborn; we use it throughout this notebook
tips.head()                     # first five rows of the data
Out[3]:
   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4

barplot and countplot

These two very similar plots let you aggregate data over a categorical feature. barplot is the general form: it aggregates a numeric column within each category using some function, by default the mean. In this example the mean bill for males is about 21 dollars, compared with about 18 dollars for females.

In [24]:
barplot = sns.barplot(x='sex',y='total_bill',data=tips,palette="Set1") 
barplot.set(xlabel='Sex', ylabel='Mean Bill') # to set x and y labels 
Out[24]:
[<matplotlib.text.Text at 0x11d356b38>, <matplotlib.text.Text at 0x11d324390>]
In [9]:
import numpy as np

In this plot we pass estimator=np.std, so the bars show the standard deviation of the bill instead of the mean.

In [25]:
barplot2 = sns.barplot(x='sex',y='total_bill',data=tips,estimator=np.std,palette="Set1")
barplot2.set(xlabel='Sex', ylabel='Standard Deviation Bill')
Out[25]:
[<matplotlib.text.Text at 0x11d351940>, <matplotlib.text.Text at 0x11d32ea90>]
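
To make the role of the estimator concrete, here is a minimal sketch of the aggregation that barplot performs before it draws the bars, written with pandas directly. The toy DataFrame is made up for illustration and is not the tips data. One thing to keep in mind: np.std computes the population standard deviation (ddof=0) by default, while pandas' own .std() uses ddof=1, so the two can differ slightly.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data (not the tips dataset), just to show the aggregation.
toy = pd.DataFrame({
    "sex": ["Male", "Male", "Male", "Female", "Female"],
    "total_bill": [10.0, 20.0, 30.0, 12.0, 18.0],
})

grouped = toy.groupby("sex")["total_bill"]

# Default estimator: the mean of total_bill within each category.
mean_bill = grouped.mean()

# estimator=np.std: population standard deviation (ddof=0) within each category.
std_bill = grouped.apply(lambda s: float(np.std(s.to_numpy())))

print(mean_bill)
print(std_bill)
```

The height of each bar corresponds to one entry of these Series; seaborn additionally draws an error bar on top, which this sketch does not reproduce.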

countplot

This is essentially the same as barplot, except that the estimator simply counts the number of occurrences in each category. That is why we only pass the x value:

In [26]:
sns.countplot(x='sex',data=tips,palette="Set1") # answers how many males and females in our data 
Out[26]:
<matplotlib.axes._subplots.AxesSubplot at 0x11d351d68>
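
The numbers behind a countplot can be checked with pandas' value_counts. This is only a sketch on a made-up Series (the values are hypothetical, not taken from the tips data):

```python
import pandas as pd

# Hypothetical categorical column.
toy_sex = pd.Series(["Male", "Female", "Male", "Male", "Female"], name="sex")

# One count per category; countplot draws one bar per entry.
counts = toy_sex.value_counts()
print(counts)
```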

boxplot and violinplot

boxplots and violinplots are used to show the distribution of a quantitative variable across the levels of a categorical variable. A box plot shows the distribution of quantitative data in a way that facilitates comparisons between variables or across levels of a categorical variable. The box shows the quartiles of the dataset, while the whiskers extend to show the rest of the distribution, except for points that are determined to be "outliers" using a method that is a function of the inter-quartile range.

In [23]:
sns.boxplot(x="day", y="total_bill", data=tips,palette='coolwarm')
Out[23]:
<matplotlib.axes._subplots.AxesSubplot at 0x11d260fd0>

The above boxplot shows the distribution of the total bill for each day.
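
The pieces of a box plot can be computed by hand. The sketch below assumes the common 1.5 × IQR rule for outliers (which, as far as I know, matches seaborn's default whis=1.5) and uses NumPy's default linear-interpolation percentiles, so the exact quartile values may differ slightly from what a plotting library computes. The data are made up.

```python
import numpy as np

# Hypothetical bills with one obviously large value.
bills = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 30.0])

# The box spans the first to third quartile, with a line at the median.
q1, median, q3 = np.percentile(bills, [25, 50, 75])
iqr = q3 - q1

# Points beyond 1.5 * IQR from the box are treated as outliers.
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Whiskers end at the most extreme observations inside the fences.
inside = bills[(bills >= lower_fence) & (bills <= upper_fence)]
whisker_low, whisker_high = inside.min(), inside.max()
outliers = bills[(bills < lower_fence) | (bills > upper_fence)]

print(q1, median, q3, whisker_low, whisker_high, outliers)
```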

In [22]:
# Passing the whole dataframe draws one horizontal box per numeric column (orient='h')
sns.boxplot(data=tips,palette='coolwarm',orient='h')
Out[22]:
<matplotlib.axes._subplots.AxesSubplot at 0x11d0e6b00>
In [21]:
sns.boxplot(x="day", y="total_bill", hue="smoker",data=tips, palette="coolwarm")
Out[21]:
<matplotlib.axes._subplots.AxesSubplot at 0x11cecd588>

The above boxplot shows the distribution of the total bill for each day, split by whether the person is a smoker or not.

violinplot

A violin plot plays a role similar to that of a boxplot. It shows the distribution of quantitative data across several levels of one (or more) categorical variables such that those distributions can be compared. Unlike a box plot, in which all of the plot components correspond to actual datapoints, the violin plot features a kernel density estimate of the underlying distribution.

In [19]:
sns.violinplot(x="day", y="total_bill", data=tips,palette='Set2')
Out[19]:
<matplotlib.axes._subplots.AxesSubplot at 0x11bc00eb8>
In [5]:
sns.violinplot(x="day", y="total_bill", data=tips,hue='sex',palette='Set2')
Out[5]:
<matplotlib.axes._subplots.AxesSubplot at 0x10e25f4e0>

The above plot shows the distribution of the total bill for each day, split by whether the person is male or female.

In [17]:
sns.violinplot(x="day", y="total_bill", data=tips,hue='sex',split=True,palette='Set2')
Out[17]:
<matplotlib.axes._subplots.AxesSubplot at 0x11c91f588>

Setting split=True draws a single violin per day, with one half for each sex, instead of two separate violins.
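
The outline of a violin is a kernel density estimate (KDE). The following is a minimal sketch of a Gaussian KDE in plain NumPy; the function name gaussian_kde_sketch, the sample values, and the bandwidth of 1.0 are all assumptions made for illustration, and seaborn chooses its own bandwidth automatically.

```python
import numpy as np

def gaussian_kde_sketch(samples, grid, bandwidth):
    """Average one Gaussian bump per sample, evaluated on grid."""
    samples = np.asarray(samples, dtype=float)
    # z[i, j] = scaled distance between grid point i and sample j
    z = (grid[:, None] - samples[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)
    return kernels.mean(axis=1) / bandwidth

# Hypothetical bills for a single day.
bills = [1.0, 2.0, 2.0, 3.0, 7.0]
grid = np.linspace(-10.0, 20.0, 3001)
density = gaussian_kde_sketch(bills, grid, bandwidth=1.0)

# A density should integrate to roughly 1 (simple Riemann sum here).
step = grid[1] - grid[0]
area = float(density.sum() * step)
peak_location = float(grid[density.argmax()])
print(area, peak_location)
```

The violin is this curve mirrored around the category axis; with split=True each half shows the curve for one hue level.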

stripplot and swarmplot

The stripplot will draw a scatterplot where one variable is categorical. A strip plot can be drawn on its own, but it is also a good complement to a box or violin plot in cases where you want to show all observations along with some representation of the underlying distribution.

The swarmplot is similar to stripplot(), but the points are adjusted (only along the categorical axis) so that they don’t overlap. This gives a better representation of the distribution of values, although it does not scale as well to large numbers of observations (both in terms of the ability to show all the points and in terms of the computation needed to arrange them).

In [13]:
sns.stripplot(x="day", y="total_bill", data=tips,palette = "husl")
Out[13]:
<matplotlib.axes._subplots.AxesSubplot at 0x10ef49080>
In [16]:
sns.stripplot(x="day", y="total_bill", data=tips,jitter=True,palette = "deep")
Out[16]:
<matplotlib.axes._subplots.AxesSubplot at 0x10f2ac3c8>
In [21]:
sns.stripplot(x="day", y="total_bill", data=tips,jitter=True,hue='sex',palette='Set1')
Out[21]:
<matplotlib.axes._subplots.AxesSubplot at 0x10f99aac8>

The above stripplot colors the data points by sex.

In [43]:
# split=True draws the hue levels side by side; seaborn 0.8 and later name this parameter dodge
sns.stripplot(x="day", y="total_bill", data=tips,jitter=True,hue='sex',palette='Set1',split=True)
Out[43]:
<matplotlib.axes._subplots.AxesSubplot at 0x12099db70>
In [27]:
sns.swarmplot(x="day", y="total_bill", data=tips,palette = "Paired")
Out[27]:
<matplotlib.axes._subplots.AxesSubplot at 0x10ffbc898>
In [47]:
sns.swarmplot(x="day", y="total_bill",hue='sex',data=tips, palette="Set1", split=True)
Out[47]:
<matplotlib.axes._subplots.AxesSubplot at 0x1211b6da0>

The above plots are swarmplots, which can be used as an alternative to boxplots or violin plots.
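
The idea behind jitter in a stripplot can be sketched in a few lines: each point keeps its value but gets a small random offset along the categorical axis so that overlapping points become visible. The category codes and the jitter width of 0.1 below are assumptions for illustration; seaborn picks its own width, and swarmplot uses a deterministic placement instead of random noise.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is repeatable

# Hypothetical category positions (for example 0 = Thur, 1 = Fri, 2 = Sat).
day_codes = np.array([0, 0, 0, 1, 1, 2])

# Shift every point by a small uniform offset along the categorical axis only.
jitter_width = 0.1
jittered_x = day_codes + rng.uniform(-jitter_width, jitter_width, size=day_codes.size)
print(jittered_x)
```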

Combining Categorical Plots

In [31]:
sns.violinplot(x="tip", y="day", data=tips,palette='rainbow')
sns.swarmplot(x="tip", y="day", data=tips,color='black',size=3)
Out[31]:
<matplotlib.axes._subplots.AxesSubplot at 0x1103567f0>

The above plot combines the violin plot and the swarmplot. The violin plot gives you a visualization of the overall distribution, while the swarmplot shows each individual observation.

factorplot

factorplot is the most general form of a categorical plot (seaborn 0.9 and later call this function catplot). It takes a kind parameter to choose the plot type:

In [15]:
sns.factorplot(x='sex',y='total_bill',data=tips,kind='bar')
Out[15]:
<seaborn.axisgrid.FacetGrid at 0x11d03a278>

With kind='bar', factorplot draws the same plot as the barplot we saw earlier.

Distribution Plots

Next in the notebook, let's discuss some plots that allow us to visualize the distribution of a data set. These plots are:

  • distplot
  • jointplot
  • pairplot

We will use the same tips data for the demonstration.

In [32]:
tips.head()   # let's look again at the tips data
Out[32]:
   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4

distplot

The distplot shows the distribution of a univariate set of observations.

In [50]:
sns.distplot(tips['total_bill'],color="b",hist=True) # this plot shows a histogram with an overlaid density plot.
Out[50]:
<matplotlib.axes._subplots.AxesSubplot at 0x1145dacf8>
In [52]:
sns.distplot(tips['total_bill'],kde=False,color="b") # if you don't want the overlaid density, use kde=False
Out[52]:
<matplotlib.axes._subplots.AxesSubplot at 0x11490cdd8>
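
The difference between the two plots above is mostly one of scale: with kde=False the bars show raw counts, while with the density curve overlaid the histogram is normalized so that its total area is 1. A small sketch with NumPy's histogram function illustrates this; the values, the number of bins, and the range are made up.

```python
import numpy as np

# Hypothetical bills.
bills = np.array([3.0, 7.0, 8.0, 12.0, 13.0, 14.0, 18.0, 22.0])

# Raw counts per bin: what a plain histogram shows.
counts, edges = np.histogram(bills, bins=4, range=(0.0, 24.0))

# Normalized heights: bar areas sum to 1, comparable with a density curve.
density, _ = np.histogram(bills, bins=4, range=(0.0, 24.0), density=True)
total_area = float((density * np.diff(edges)).sum())

print(counts, density, total_area)
```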

jointplot

jointplot() plots bivariate data together with the distribution of each variable. The kind parameter chooses how the two variables are compared:

  • “scatter”
  • “reg”
  • “resid”
  • “kde”
  • “hex”
In [53]:
sns.jointplot(x='total_bill',y='tip',data=tips,kind='scatter')
Out[53]:
<seaborn.axisgrid.JointGrid at 0x114d322e8>

The above plot shows a scatterplot for total_bill and tip.

In [54]:
sns.jointplot(x='total_bill',y='tip',data=tips,kind='hex')
Out[54]:
<seaborn.axisgrid.JointGrid at 0x114d38da0>

You can specify kind="hex" to get a hexbin plot. This plot is useful when you have many data points, because overlapping points are counted per hexagon instead of being drawn on top of each other.

In [55]:
sns.jointplot(x='total_bill',y='tip',data=tips,kind='reg')
Out[55]:
<seaborn.axisgrid.JointGrid at 0x115295240>

You can also use kind="reg", which draws a regression line on top of the scatterplot.

pairplot

pairplot will plot pairwise relationships across an entire dataframe (for the numerical columns) and supports a color hue argument (for categorical columns).

In [61]:
sns.pairplot(tips)
Out[61]:
<seaborn.axisgrid.PairGrid at 0x1171c70b8>
In [59]:
sns.pairplot(tips,hue='sex',palette='Set1')
Out[59]:
<seaborn.axisgrid.PairGrid at 0x1171c7f60>
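
To see which panels a pairplot ends up drawing, it helps to list the numeric columns and their pairwise combinations. This is only a sketch of the selection logic using pandas and itertools on a made-up frame; it is not how seaborn is implemented internally.

```python
from itertools import combinations

import pandas as pd

# Hypothetical frame with three numeric columns and one categorical column.
toy = pd.DataFrame({
    "total_bill": [16.99, 10.34, 21.01],
    "tip": [1.01, 1.66, 3.50],
    "sex": ["Female", "Male", "Male"],
    "size": [2, 3, 3],
})

# Only numeric columns get a row and a column in the grid.
numeric_cols = toy.select_dtypes(include="number").columns.tolist()

# Each unordered pair appears as an off-diagonal scatterplot.
pairs = list(combinations(numeric_cols, 2))
print(numeric_cols)
print(pairs)
```

The categorical column is left out of the grid itself, which is why it is a natural candidate for the hue argument.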

Regression Plots

lmplot displays regression plots. It can also split those plots by other features and color the points by a categorical feature using hue.

Let's explore how this works:

In [62]:
tips.head() # again we are using the tips data 
Out[62]:
   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4

lmplot()

In [63]:
sns.lmplot(x='total_bill',y='tip',data=tips)
Out[63]:
<seaborn.axisgrid.FacetGrid at 0x1171c7eb8>

The above plot shows a scatterplot between total_bill and tip and also displays the regression line.
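
The regression line is, by default, an ordinary least-squares straight line. A minimal sketch of that fit with NumPy's polyfit is shown here; the four data points are made up, and the shaded confidence band that lmplot also draws is not reproduced.

```python
import numpy as np

# Hypothetical (total_bill, tip) pairs.
total_bill = np.array([10.0, 20.0, 30.0, 40.0])
tip = np.array([1.0, 3.0, 5.0, 5.0])

# Degree-1 polynomial fit: tip is approximately slope * total_bill + intercept.
slope, intercept = np.polyfit(total_bill, tip, deg=1)

# Predicted tip for a hypothetical 25 dollar bill.
predicted_tip = slope * 25.0 + intercept
print(slope, intercept, predicted_tip)
```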

In [65]:
sns.lmplot(x='total_bill',y='tip',data=tips,hue='sex',palette="Set1")
Out[65]:
<seaborn.axisgrid.FacetGrid at 0x1187eb7f0>

The above plot is the same as the previous one, except that it draws a separate regression line for each level of the sex feature.

Using a Grid

We can add more variable separation through columns and rows with the use of a grid. Just indicate this with the col or row arguments:

In [66]:
sns.lmplot(x='total_bill',y='tip',data=tips,col='sex')
Out[66]:
<seaborn.axisgrid.FacetGrid at 0x118aa6630>

So in this plot we separated the scatterplot into two columns based on the sex feature. You can do this by specifying the col parameter.

In [67]:
sns.lmplot(x="total_bill", y="tip", row="sex", col="time",data=tips)
Out[67]:
<seaborn.axisgrid.FacetGrid at 0x118b842e8>

We can split the scatter plot based on rows and columns as shown in this example.

In [68]:
sns.lmplot(x='total_bill',y='tip',data=tips,col='day',hue='sex',palette='coolwarm')
Out[68]:
<seaborn.axisgrid.FacetGrid at 0x118d39828>

In this case we split as per the day feature (4 columns) and color is based on the sex feature.
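
Conceptually, every panel of the grid receives the subset of rows that matches its row and column values. The sketch below mimics that split with a pandas groupby on a made-up frame; it only illustrates the idea and is not seaborn's internal code.

```python
import pandas as pd

# Hypothetical data with two faceting variables.
toy = pd.DataFrame({
    "sex": ["Male", "Male", "Female", "Female", "Male"],
    "time": ["Lunch", "Dinner", "Dinner", "Dinner", "Dinner"],
    "tip": [1.0, 2.0, 3.0, 2.5, 4.0],
})

# One group per (row value, column value) combination that occurs in the data.
facet_sizes = {key: len(subset) for key, subset in toy.groupby(["sex", "time"])}
print(facet_sizes)
```

A combination that never occurs in the data (here Female at Lunch) has no rows, so its panel would be empty.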

Aspect and Size

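Seaborn figures can have their size and aspect ratio adjusted with the size and aspect parameters (seaborn 0.9 and later call the size parameter height). size is the height of each facet in inches, and aspect is the ratio of its width to its height: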

In [69]:
sns.lmplot(x='total_bill',y='tip',data=tips,col='day',hue='sex',palette='coolwarm', aspect=0.6,size=8)
Out[69]:
<seaborn.axisgrid.FacetGrid at 0x119a944e0>
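
As a quick check of what these two numbers mean, the width of each facet is the facet height multiplied by aspect. The arithmetic below uses the values from the cell above; the total figure width is only approximate, since seaborn may reserve extra room for the legend.

```python
# Values taken from the lmplot call above.
size = 8        # height of each facet in inches
aspect = 0.6    # width / height of each facet
n_cols = 4      # one column per day

facet_width = size * aspect               # inches per facet
approx_figure_width = n_cols * facet_width  # rough total, ignoring the legend
print(facet_width, approx_figure_width)
```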

Matrix Plots

Matrix plots allow you to plot data as color-encoded matrices and can also be used to indicate clusters within the data.

Let's begin by exploring seaborn's heatmap and clustermap:

In [70]:
flights = sns.load_dataset('flights') # in addition to the tips data we will also use the flights data.
In [71]:
tips.head()  # tips data
Out[71]:
   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4
In [72]:
flights.head()   # flights data
Out[72]:
   year     month  passengers
0  1949   January         112
1  1949  February         118
2  1949     March         132
3  1949     April         129
4  1949       May         121

Heatmap

For a heatmap to work properly, your data should already be in matrix form; the sns.heatmap function basically just colors it in for you. For example:

In [75]:
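# corr() returns the matrix of correlation coefficients between the numerical features of the data.
# On pandas 2.0 and later, call tips.corr(numeric_only=True), because non-numeric columns are no longer dropped automatically.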

tips.corr() 
Out[75]:
            total_bill       tip      size
total_bill    1.000000  0.675734  0.598315
tip           0.675734  1.000000  0.489299
size          0.598315  0.489299  1.000000
In [83]:
sns.heatmap(tips.corr(),cmap="coolwarm",annot=True)
Out[83]:
<matplotlib.axes._subplots.AxesSubplot at 0x11c1e6a90>
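
A version-independent way to build the correlation matrix is to select the numeric columns explicitly before calling corr(). The sketch below does this on a small frame built from the five rows shown by tips.head(); with so few rows the coefficients will not match the full-data values above.

```python
import numpy as np
import pandas as pd

# The five rows shown by tips.head(), typed in by hand.
toy = pd.DataFrame({
    "total_bill": [16.99, 10.34, 21.01, 23.68, 24.59],
    "tip": [1.01, 1.66, 3.50, 3.31, 3.61],
    "sex": ["Female", "Male", "Male", "Male", "Female"],
    "size": [2, 3, 3, 2, 4],
})

# Keep only numeric columns so corr() behaves the same across pandas versions.
numeric = toy.select_dtypes(include="number")
corr = numeric.corr()

# A correlation matrix is symmetric with ones on the diagonal.
is_symmetric = bool(np.allclose(corr.to_numpy(), corr.to_numpy().T))
print(corr)
```

The resulting square frame is exactly the kind of matrix-shaped input that sns.heatmap expects.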

Let's now consider the flights data. First we pivot the dataframe from long form (one row per month and year) into a month-by-year matrix, which is the shape a heatmap needs.

In [84]:
flights.pivot_table(values='passengers',index='month',columns='year')
Out[84]:
year       1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960
month
January     112   115   145   171   196   204   242   284   315   340   360   417
February    118   126   150   180   196   188   233   277   301   318   342   391
March       132   141   178   193   236   235   267   317   356   362   406   419
April       129   135   163   181   235   227   269   313   348   348   396   461
May         121   125   172   183   229   234   270   318   355   363   420   472
June        135   149   178   218   243   264   315   374   422   435   472   535
July        148   170   199   230   264   302   364   413   465   491   548   622
August      148   170   199   242   272   293   347   405   467   505   559   606
September   136   158   184   209   237   259   312   355   404   404   463   508
October     119   133   162   191   211   229   274   306   347   359   407   461
November    104   114   146   172   180   203   237   271   305   310   362   390
December    118   140   166   194   201   229   278   306   336   337   405   432
In [88]:
pvflights = flights.pivot_table(values='passengers',index='month',columns='year')
sns.heatmap(pvflights,cmap='coolwarm',linecolor='white',linewidths=1)
Out[88]:
<matplotlib.axes._subplots.AxesSubplot at 0x11c896ac8>
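
The pivot step is easier to follow on a tiny example. The sketch below reshapes four long-form rows (values copied from the table above) into a 2 × 2 matrix. Because the month column here holds plain strings, the rows come out in alphabetical order, unlike the calendar order in the real flights data.

```python
import pandas as pd

# Four long-form rows: one row per (year, month) observation.
long_form = pd.DataFrame({
    "year": [1949, 1949, 1950, 1950],
    "month": ["January", "February", "January", "February"],
    "passengers": [112, 118, 115, 126],
})

# Rows become months, columns become years, cells hold the passenger counts.
wide = long_form.pivot_table(values="passengers", index="month", columns="year")
print(wide)
```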

clustermap

The clustermap uses hierarchical clustering to produce a clustered version of the heatmap. For example:

In [92]:
sns.clustermap(pvflights,cmap="coolwarm",standard_scale=1)
Out[92]:
<seaborn.matrix.ClusterGrid at 0x11d646160>

Notice how the years and months are no longer in order; instead they are grouped by similarity in value (passenger count). That means we can begin to infer things from this plot, such as August and July being similar (which makes sense, since they are both summer travel months).
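
The standard_scale=1 argument used above rescales every column (here, every year) to the range 0 to 1 before clustering, so that years with more passengers overall do not dominate the grouping. A rough sketch of that rescaling in pandas, using a few values from the pivoted table, is given below; the clustering step itself is not reproduced.

```python
import pandas as pd

# Three months and two years taken from the pivoted flights table.
wide = pd.DataFrame(
    {1949: [112, 118, 132], 1960: [417, 391, 419]},
    index=["January", "February", "March"],
)

# Per column: subtract the minimum, then divide by the remaining range.
scaled = (wide - wide.min()) / (wide.max() - wide.min())
print(scaled)
```

After scaling, each year runs from 0 (its quietest month) to 1 (its busiest month), so the comparison is about seasonal shape rather than absolute volume.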

I hope you liked this detailed overview of the seaborn module in Python for data visualization. I also encourage you to look at the official seaborn documentation, which has many examples and use cases that you can build on.

Seaborn

Great Job!