Distribution Plots


Load Data

Seaborn comes with built-in datasets:

> import matplotlib.pyplot as plt
+ import seaborn as sns
+ tips = sns.load_dataset('tips')
+ tips.head()
   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4

Distplot

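The distplot shows the distribution of a univariate set of observations. Note that distplot() has been deprecated since seaborn 0.11 in favor of histplot() and displot(), so newer versions emit a deprecation warning for the calls below.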

> sns.set_style(style="darkgrid")
+ plt.figure(figsize=(8, 6))
+ sns.distplot(tips['total_bill'])

> sns.distplot(tips['total_bill'],
+              kde=False, bins=30)
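With kde=False the plot is just a histogram. As a rough sketch of what bins=30 means, the bin counts can be reproduced with NumPy. The bill values below are synthetic stand-ins (an assumption), not the tips data:

```python
import numpy as np

# Synthetic stand-in for tips['total_bill'] (assumption: not the real data)
rng = np.random.default_rng(0)
bills = rng.gamma(shape=4.0, scale=5.0, size=244)

# bins=30 splits the observed range into 30 equal-width intervals
counts, edges = np.histogram(bills, bins=30)

print(len(counts), len(edges))   # 30 bins need 31 edges
print(counts.sum())              # every observation falls in exactly one bin
```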

Jointplot

jointplot() draws a bivariate plot in the center with the univariate distribution of each variable along the margins. The kind parameter selects the center plot:

  • “scatter”
  • “reg”
  • “resid”
  • “kde”
  • “hex”

> sns.jointplot(x='total_bill',y='tip',
+               data=tips,kind='scatter');
+ plt.show()

> sns.jointplot(x='total_bill',y='tip',
+               data=tips,kind='hex');
+ plt.show()

> sns.jointplot(x='total_bill',y='tip',
+               data=tips,kind='reg');
+ plt.show()

Pairplot

pairplot() will plot pairwise relationships across an entire dataframe (for the numerical columns) and supports a color hue argument (for categorical columns).

> sns.pairplot(tips);
+ plt.show()

> sns.pairplot(tips,hue='sex',
+              palette='coolwarm');
+ plt.show()
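To see which columns end up in the grid, the numeric columns can be selected with pandas. This is only a sketch of the selection step, using the first rows of the tips table shown earlier; it does not reproduce seaborn's internals:

```python
import pandas as pd

# First rows of the tips table shown earlier (a subset of its columns)
df = pd.DataFrame({
    "total_bill": [16.99, 10.34, 21.01],
    "tip": [1.01, 1.66, 3.50],
    "sex": ["Female", "Male", "Male"],
    "size": [2, 3, 3],
})

# Only numeric columns get a row and a column in the pair grid
numeric_cols = df.select_dtypes(include="number").columns.tolist()
n_panels = len(numeric_cols) ** 2

print(numeric_cols)
print(n_panels)
```

A categorical column such as sex is left out of the grid itself, which is why it is a natural candidate for hue.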

Rugplot

A rugplot is a very simple concept: it draws a dash mark for every point of a univariate distribution. Rugplots are the building block of a KDE plot:

> plt.figure(figsize=(8, 6))
+ sns.rugplot(tips['total_bill']);
+ plt.show()

kdeplot

kdeplots are Kernel Density Estimation plots. A KDE plot replaces every single observation with a Gaussian (Normal) distribution centered on that value, then sums these curves to estimate the overall density.

http://en.wikipedia.org/wiki/Kernel_density_estimation

For example:

> import numpy as np
+ from scipy import stats
+ 
+ #Create dataset
+ dataset = np.random.randn(25)
+ 
+ # Create another rugplot
+ sns.rugplot(dataset);
+ 
+ # Set up the x-axis for the plot
+ x_min = dataset.min() - 2
+ x_max = dataset.max() + 2
+ 
+ # 100 equally spaced points from x_min to x_max
+ x_axis = np.linspace(x_min,x_max,100)
+ 
+ # Set up the bandwidth with the rule of thumb
+ # for Gaussian kernels: (4*sigma^5 / (3*n))^(1/5)
+ bandwidth = ((4*dataset.std()**5)/
+              (3*len(dataset)))**.2
+ 
+ # Create an empty kernel list
+ kernel_list = []
+ 
+ # Plot each basis function
+ for data_point in dataset:
+     
+     # Create a kernel for each point and append to list
+     kernel = stats.norm(data_point,
+                 bandwidth).pdf(x_axis)
+     kernel_list.append(kernel)
+     
+     #Scale for plotting
+     kernel = kernel / kernel.max()
+     kernel = kernel * .4
+     plt.plot(x_axis,kernel,
+              color = 'grey',alpha=0.5)
+ 
+ plt.ylim(0,1);
+ plt.show()

> # To get the kde plot we can sum these basis functions.
+ 
+ # Plot the sum of the basis function
+ sum_of_kde = np.sum(kernel_list,axis=0)
+ 
+ # Plot figure
+ fig = plt.plot(x_axis,sum_of_kde,color='indianred')
+ 
+ # Add the initial rugplot
+ sns.rugplot(dataset,color='indianred')
+ 
+ # Get rid of y-tick marks
+ plt.yticks([]);
+ 
+ # Set title
+ plt.suptitle("Sum of the Basis Functions")
+ plt.show()
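The plain sum of the kernels has the right shape but not the right scale. Averaging the kernels instead of summing them gives a curve whose area is 1, which is what a density estimate should have. A minimal sketch with NumPy only, assuming the same bandwidth formula as above and a seeded random sample (the helper gaussian_pdf is introduced here for illustration):

```python
import numpy as np

rng = np.random.default_rng(42)
dataset = rng.standard_normal(25)

# Same bandwidth formula as in the example above
bandwidth = ((4 * dataset.std() ** 5) / (3 * len(dataset))) ** 0.2

# Evaluate on a grid that extends well past the data
x_axis = np.linspace(dataset.min() - 5 * bandwidth,
                     dataset.max() + 5 * bandwidth, 1000)

def gaussian_pdf(x, mean, sd):
    """Normal density, written out so that only NumPy is needed."""
    return np.exp(-0.5 * ((x - mean) / sd) ** 2) / (sd * np.sqrt(2 * np.pi))

# One kernel per observation, then average instead of sum
kernels = np.array([gaussian_pdf(x_axis, p, bandwidth) for p in dataset])
density = kernels.mean(axis=0)

# Trapezoidal rule: the area under a density should be close to 1
area = np.sum((density[1:] + density[:-1]) / 2 * np.diff(x_axis))
print(round(area, 3))
```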

With the tips dataset:

> sns.kdeplot(tips['total_bill']);
+ sns.rugplot(tips['total_bill']);
+ plt.show()

> sns.kdeplot(tips['tip']);
+ sns.rugplot(tips['tip']);
+ plt.show()

Categorical Plots


Categorical plots show how a numeric variable is distributed across the levels of a categorical variable. The main plot types are:

  • catplot
  • boxplot
  • violinplot
  • stripplot
  • swarmplot
  • barplot
  • countplot

Barplot

barplot() is a general plot that aggregates a numeric variable within each category using some function. The default is the mean:

> sns.set_style(style="white")
+ plt.figure(figsize=(8, 6))
+ sns.barplot(x='sex',y='total_bill',data=tips);
+ plt.show()

You can change the estimator to your own function, as long as it converts a vector to a scalar:

> sns.barplot(x='sex',y='total_bill',
+             data=tips,estimator=np.std)
+ plt.show()
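The bar heights correspond roughly to a group-by aggregation. The sketch below uses pandas on the first five rows of the tips table shown earlier; it illustrates the aggregation idea only and does not reproduce seaborn's own implementation or its error bars:

```python
import numpy as np
import pandas as pd

# First five rows of the tips table shown earlier
df = pd.DataFrame({
    "sex": ["Female", "Male", "Male", "Male", "Female"],
    "total_bill": [16.99, 10.34, 21.01, 23.68, 24.59],
})

grouped = df.groupby("sex")["total_bill"]

# Default estimator: one mean per category
means = grouped.mean()

# Custom estimator: any function that maps a vector to a scalar.
# np.std on a plain array is the population standard deviation (ddof=0).
stds = grouped.apply(lambda s: float(np.std(s.to_numpy())))

print(means)
print(stds)
```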

Countplot

countplot() is essentially the same as barplot(), except that the estimator explicitly counts the number of occurrences, which is why we only pass the x value:

> plt.figure(figsize=(8, 6))
+ sns.countplot(x='sex',data=tips);
+ plt.show()
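The numbers behind the bars are plain category counts. A minimal pandas sketch, using the sex column from the first five rows of the tips table shown earlier:

```python
import pandas as pd

# sex column of the first five tips rows
sex = pd.Series(["Female", "Male", "Male", "Male", "Female"], name="sex")

# One count per distinct category, which is what each bar height shows
counts = sex.value_counts()
print(counts)
```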

Boxplot

A box plot (or box-and-whisker plot) shows the distribution of quantitative data in a way that facilitates comparisons between variables or across levels of a categorical variable. The box shows the quartiles of the dataset while the whiskers extend to show the rest of the distribution, except for points that are determined to be “outliers” using a method that is a function of the inter-quartile range.

> plt.figure(figsize=(8, 6))
+ sns.boxplot(x="day", y="total_bill", 
+             data=tips,palette='rainbow');
+ plt.show()            
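The quartiles, whiskers and outliers described above can be computed by hand. The sketch below assumes the common 1.5 * IQR rule and uses hypothetical bill amounts, not the tips data:

```python
import numpy as np

# Hypothetical bill amounts with one unusually large value
values = np.array([10, 12, 14, 15, 16, 18, 20, 22, 50], dtype=float)

q1, median, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1

# Points beyond 1.5 * IQR from the box are drawn as outliers
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr
inside = values[(values >= lower_fence) & (values <= upper_fence)]
outliers = values[(values < lower_fence) | (values > upper_fence)]

# Whiskers stop at the most extreme points that are not outliers
whisker_low, whisker_high = inside.min(), inside.max()

print(q1, median, q3)             # 14.0 16.0 20.0
print(whisker_low, whisker_high)  # 10.0 22.0
print(outliers)                   # [50.]
```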

> # Can do the entire dataframe with orient='h'
+ sns.boxplot(data=tips,
+             palette='rainbow',orient='h');
+ plt.show()

> sns.boxplot(x="day", y="total_bill", 
+     hue="smoker",data=tips, palette="coolwarm");
+ plt.show()

Violinplot

A violin plot plays a similar role to a box-and-whisker plot. It shows the distribution of quantitative data across several levels of one (or more) categorical variables such that those distributions can be compared. Unlike a box plot, in which all of the plot components correspond to actual datapoints, the violin plot features a kernel density estimation of the underlying distribution.

> plt.figure(figsize=(8, 6))
+ sns.violinplot(x="day", y="total_bill", 
+         data=tips,palette='rainbow');
+ plt.show()

> sns.violinplot(x="day", y="total_bill", 
+     data=tips,hue='sex',palette='Set1')
+ plt.show()

> sns.violinplot(x="day", y="total_bill", 
+     data=tips,hue='sex',split=True,palette='Set1')
+ plt.show()

Stripplot

The stripplot will draw a scatterplot where one variable is categorical. A strip plot can be drawn on its own, but it is also a good complement to a box or violin plot in cases where you want to show all observations along with some representation of the underlying distribution.

> plt.figure(figsize=(8, 6))
+ sns.stripplot(x="day", y="total_bill", 
+               data=tips, jitter=False);
+ plt.show()

> sns.stripplot(x="day", y="total_bill", 
+               data=tips,jitter=True);
+ plt.show()
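Jitter only changes where points are drawn along the categorical axis; the values themselves are untouched. A small NumPy sketch of the idea, where the category position and the half-width of 0.1 are assumed values for illustration rather than seaborn's documented behavior:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical: 50 observations that all belong to the category drawn at x = 2
category_position = 2
n_points = 50

# Without jitter every point sits exactly on the category position
x_plain = np.full(n_points, float(category_position))

# With jitter each point is shifted sideways by a small random amount.
# The half-width of 0.1 is an assumed value for illustration.
half_width = 0.1
x_jittered = x_plain + rng.uniform(-half_width, half_width, size=n_points)

print(np.unique(x_plain).size)      # 1 distinct x position
print(np.unique(x_jittered).size)   # points are now spread out
```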

> sns.stripplot(x="day", y="total_bill", 
+     data=tips,jitter=True,hue='sex',
+               palette='Set1');
+ plt.show()

> sns.stripplot(x="day", y="total_bill", 
+     data=tips,jitter=True,hue='sex',
+             palette='Set1',dodge=True);
+ plt.show()

Swarmplot

The swarmplot is similar to stripplot(), but the points are adjusted (only along the categorical axis) so that they don’t overlap. This gives a better representation of the distribution of values, although it does not scale as well to large numbers of observations (both in terms of the ability to show all the points and in terms of the computation needed to arrange them).

> plt.figure(figsize=(8, 6))
+ sns.swarmplot(x="day", y="total_bill", data=tips);
+ plt.show()

> sns.swarmplot(x="day", y="total_bill",
+         hue='sex',data=tips, palette="Set1",
+               dodge=True);
+ plt.show()

Combining Categorical Plots

> sns.violinplot(x="tip", y="day", data=tips,
+                palette='rainbow')
+ sns.swarmplot(x="tip", y="day", data=tips,
+               color='black',size=3);
+ plt.show()

Catplot

catplot is the most general form of a categorical plot. It can take in a kind parameter to adjust the plot type:

> sns.catplot(x='day',y='total_bill',
+                data=tips,kind='bar');
+ plt.show()

Matrix Plots


Matrix plots allow you to plot data as color-encoded matrices and can also be used to indicate clusters within the data.

Datasets

The flights and tips datasets.

> flights = sns.load_dataset('flights')
> tips.head()
   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4
> flights.head()
   year     month  passengers
0  1949   January         112
1  1949  February         118
2  1949     March         132
3  1949     April         129
4  1949       May         121

Heatmap

For a heatmap to work properly, your data should already be in matrix form, with meaningful labels on both the rows and the columns. sns.heatmap() then colors each cell by its value. For example:

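> # Matrix form for correlation data
+ # (numeric_only=True skips the categorical columns;
+ # pandas 2.0 and later raise an error without it)
+ tips.corr(numeric_only=True)
            total_bill       tip      size
total_bill    1.000000  0.675734  0.598315
tip           0.675734  1.000000  0.489299
size          0.598315  0.489299  1.000000
> plt.figure(figsize=(8, 6))
+ sns.heatmap(tips.corr(numeric_only=True));
+ plt.show()

> sns.heatmap(tips.corr(numeric_only=True),
+         cmap='coolwarm',annot=True);
+ plt.show()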
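"Matrix form" here means that the row labels and the column labels are both variables, so that every cell has a meaning. A short sketch on the first five rows of the tips table shown earlier (so the coefficients differ from the full-data output above):

```python
import numpy as np
import pandas as pd

# First five rows of the tips table shown earlier
df = pd.DataFrame({
    "total_bill": [16.99, 10.34, 21.01, 23.68, 24.59],
    "tip": [1.01, 1.66, 3.50, 3.31, 3.61],
    "size": [2, 3, 3, 2, 4],
    "sex": ["Female", "Male", "Male", "Male", "Female"],
})

# Keep the numeric columns only, then correlate every pair of them
corr = df.select_dtypes(include="number").corr()
values = corr.to_numpy()

print(list(corr.index) == list(corr.columns))   # same labels on both axes
print(np.allclose(np.diag(values), 1.0))        # each variable vs. itself
print(np.allclose(values, values.T))            # symmetric matrix
```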

Or for the flights data:

> flights.pivot_table(values='passengers',
+             index='month',columns='year')
year       1949  1950  1951  1952  1953  ...  1956  1957  1958  1959  1960
month                                    ...                              
January     112   115   145   171   196  ...   284   315   340   360   417
February    118   126   150   180   196  ...   277   301   318   342   391
March       132   141   178   193   236  ...   317   356   362   406   419
April       129   135   163   181   235  ...   313   348   348   396   461
May         121   125   172   183   229  ...   318   355   363   420   472
June        135   149   178   218   243  ...   374   422   435   472   535
July        148   170   199   230   264  ...   413   465   491   548   622
August      148   170   199   242   272  ...   405   467   505   559   606
September   136   158   184   209   237  ...   355   404   404   463   508
October     119   133   162   191   211  ...   306   347   359   407   461
November    104   114   146   172   180  ...   271   305   310   362   390
December    118   140   166   194   201  ...   306   336   337   405   432

[12 rows x 12 columns]
> pvflights = flights.pivot_table(values='passengers',
+                     index='month',columns='year')
+ sns.heatmap(pvflights, cmap='Spectral');
+ plt.show()
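The pivot step turns the long (year, month, passengers) table into a month-by-year matrix. A small sketch on six rows whose values are copied from the table above:

```python
import pandas as pd

# A few (year, month, passengers) rows in the same long format as flights
long_df = pd.DataFrame({
    "year": [1949, 1949, 1949, 1950, 1950, 1950],
    "month": ["January", "February", "March"] * 2,
    "passengers": [112, 118, 132, 115, 126, 141],
})

# One row per month, one column per year, passengers as cell values
wide = long_df.pivot_table(values="passengers", index="month", columns="year")

# Plain strings sort alphabetically, so restore calendar order by hand
wide = wide.reindex(["January", "February", "March"])

print(wide)
```

In this sketch month is a plain string column, hence the explicit reindex to keep calendar order.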

> sns.heatmap(pvflights,cmap='magma',
+             linecolor='white',linewidths=1);
+ plt.show()

Clustermap

The clustermap uses hierarchical clustering to produce a clustered version of the heatmap. For example:

> sns.clustermap(pvflights, cmap="GnBu");
+ plt.show()

Notice how the years and months are no longer in order; instead, they are grouped by similarity in value (passenger count). That means we can begin to infer things from this plot, such as August and July being similar (which makes sense, since they are both summer travel months).
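The grouping can be checked on a handful of rows with SciPy's hierarchical clustering. This is a sketch under the assumption that average linkage on Euclidean distances matches clustermap's defaults; the counts are copied from the pivot table above:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage

# Passenger counts for 1949-1951, copied from the pivot table above
months = ["January", "July", "August", "November"]
counts = np.array([
    [112, 115, 145],   # January
    [148, 170, 199],   # July
    [148, 170, 199],   # August
    [104, 114, 146],   # November
], dtype=float)

# Average linkage on Euclidean distances (assumed here to match
# clustermap's defaults)
merges = linkage(counts, method="average", metric="euclidean")

# Each row of `merges` joins two clusters; the first row is the closest pair
first_pair = sorted(int(i) for i in merges[0, :2])
print([months[i] for i in first_pair])   # ['July', 'August']
```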

> # Options such as normalization can make
+ # the information a little clearer
+ sns.clustermap(pvflights,cmap='coolwarm',
+                standard_scale=1);
+ plt.show()
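According to the seaborn documentation, standard_scale subtracts the minimum and divides by the maximum along the chosen dimension (1 means columns). A rough pandas equivalent for columns, given as a sketch and not as seaborn's own code, on a small slice of the pivoted table:

```python
import pandas as pd

# Small slice of the pivoted flights table (months x years)
wide = pd.DataFrame(
    {1949: [112, 148, 104], 1950: [115, 170, 114]},
    index=["January", "July", "November"],
)

# Min-max scaling applied to each column separately
scaled = (wide - wide.min()) / (wide.max() - wide.min())
print(scaled)
```

After scaling, every year runs from 0 to 1, so the steady growth in passengers over the years no longer dominates the colors and the seasonal pattern is easier to see.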

Grids


Grids are general types of plots that allow you to map plot types to rows and columns of a grid. This helps you create similar plots separated by features.

> iris = sns.load_dataset('iris')
+ iris.head()
   sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa

PairGrid

PairGrid is a subplot grid for plotting pairwise relationships in a dataset.

> # Just the Grid
+ sns.set_style(style="darkgrid")
+ sns.PairGrid(iris);
+ plt.show()

> # Then you map to the grid
+ g = sns.PairGrid(iris);
+ g.map(plt.scatter);
+ plt.show()

> # Map to upper,lower, and diagonal
+ g = sns.PairGrid(iris);
+ g.map_diag(sns.distplot);
+ g.map_upper(plt.scatter);
+ g.map_lower(sns.kdeplot);
+ plt.show()
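The three map_ methods address different cells of the same square grid. A small NumPy sketch of which (row, column) positions fall into each region for the four numeric iris columns; this illustrates the layout only and does not call seaborn:

```python
import numpy as np

variables = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
n = len(variables)

# (row, column) index pairs for each region of the n x n grid
upper = list(zip(*np.triu_indices(n, k=1)))    # above the diagonal
lower = list(zip(*np.tril_indices(n, k=-1)))   # below the diagonal
diag = [(i, i) for i in range(n)]

print(len(upper), len(lower), len(diag))   # 6 6 4
```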

Pairplot

pairplot() is a simpler, higher-level interface to PairGrid:

> plt.rcParams.update({'figure.max_open_warning': 0})
+ sns.pairplot(iris);
+ plt.show()

> sns.pairplot(iris,hue='species',
+              palette='rainbow');
+ plt.show()

Facet Grid

FacetGrid is the general way to create grids of plots based on a feature:

> # Just the Grid
+ g = sns.FacetGrid(tips, col="time", row="smoker");
+ plt.show()

> g = sns.FacetGrid(tips, col="time",  row="smoker");
+ g = g.map(plt.hist, "total_bill");
+ plt.show()

> g = sns.FacetGrid(tips, col="time",  
+             row="smoker",hue='sex');
+ # Notice how the arguments come 
+ # after plt.scatter call
+ g = g.map(plt.scatter, "total_bill", 
+           "tip").add_legend();
+ plt.show()
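Each facet is drawn from the subset of rows that matches one (row value, column value) combination. The pandas sketch below shows that split on a few hypothetical rows; it is an illustration of the idea, not FacetGrid's implementation:

```python
import pandas as pd

# Hypothetical rows with the two faceting columns
df = pd.DataFrame({
    "time": ["Dinner", "Dinner", "Lunch", "Lunch", "Dinner"],
    "smoker": ["No", "Yes", "No", "No", "Yes"],
    "total_bill": [16.99, 38.01, 27.20, 22.76, 11.24],
})

# One facet per (row value, column value) combination
for (smoker, time), subset in df.groupby(["smoker", "time"]):
    print(f"smoker={smoker}, time={time}: {len(subset)} rows")

facet_sizes = df.groupby(["smoker", "time"]).size()
```

In this toy data no row has smoker "Yes" at "Lunch", so that combination has no subset; in a grid, the corresponding facet would simply have nothing to draw.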

JointGrid

JointGrid is the general version of jointplot()-type grids. For a quick example:

> g = sns.JointGrid(x="total_bill", 
+                   y="tip", data=tips);
+ plt.show()

> g = sns.JointGrid(x="total_bill", 
+           y="tip", data=tips);
+ g = g.plot(sns.regplot, sns.distplot);
+ plt.show()

Regression Plots


Lmplot

lmplot() displays linear model fits. It can split those plots into separate facets based on features and color the points and fitted lines by hue.

> sns.set_style(style="darkgrid")
+ sns.lmplot(x='total_bill',y='tip',data=tips);
+ plt.show()

> sns.lmplot(x='total_bill',y='tip',
+            data=tips,hue='sex');
+ plt.show()

> sns.lmplot(x='total_bill',y='tip',data=tips,
+            hue='sex',palette='spring');
+ plt.show()
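With hue, a separate line is fitted for each level. As a sketch of what "one linear fit per group" means, the example below fits a least-squares line to two hypothetical groups with NumPy. The data is made up so that the expected slopes are known in advance:

```python
import numpy as np

# Hypothetical bills with tips that follow an exact line per group
total_bill = np.array([10.0, 20.0, 30.0, 40.0])
tip_a = 0.10 * total_bill + 1.0   # group A: 10% plus a fixed amount
tip_b = 0.20 * total_bill + 0.5   # group B: 20% plus a fixed amount

# A degree-1 polynomial fit is an ordinary least-squares line
slope_a, intercept_a = np.polyfit(total_bill, tip_a, deg=1)
slope_b, intercept_b = np.polyfit(total_bill, tip_b, deg=1)

print(round(slope_a, 3), round(intercept_a, 3))
print(round(slope_b, 3), round(intercept_b, 3))
```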

Working with Markers

lmplot() kwargs get passed through to regplot(), the axes-level function that lmplot() builds on. regplot() has a scatter_kws parameter that gets passed to plt.scatter, so to change the marker size you set the s key in that dictionary. Somewhat confusingly, s is the marker area in points squared, that is, the square of the markersize. In other words, you end up passing a dictionary of base matplotlib arguments; in this case, s for the size of the scatter markers.

http://matplotlib.org/api/markers_api.html

> sns.lmplot(x='total_bill',y='tip',
+     data=tips,hue='sex',palette='winter',
+     markers=['o','v'],scatter_kws={'s':100});
+ plt.show()
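Because s is an area in points squared, the matching markersize is its square root. A tiny sketch of the conversion:

```python
import math

s = 100                       # value passed as scatter_kws={'s': 100}
markersize = math.sqrt(s)     # matching markersize in points
print(markersize)             # 10.0
```

Doubling the marker width therefore means multiplying s by four, not by two.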

Using a Grid

We can add more variable separation through columns and rows with the use of a grid. Just indicate this with the col or row arguments:

> sns.lmplot(x='total_bill',y='tip',
+            data=tips,col='sex');
+ plt.show()

> sns.lmplot(x="total_bill", y="tip", 
+         row="sex", col="time",data=tips);
+ plt.show()

> sns.lmplot(x='total_bill',y='tip',
+            data=tips,col='day',hue='sex',
+            palette='hot');
+ plt.show()

Aspect and Size

Seaborn grid figures can have their size and aspect ratio adjusted with the height and aspect parameters:

> sns.lmplot(x='total_bill',y='tip',data=tips,
+            col='day',hue='sex',palette='hot',
+           aspect=0.6,height=8);
+ plt.show()
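height is the height of each facet in inches, and aspect is the ratio of facet width to facet height. A quick arithmetic sketch for the call above, where the total width is only approximate because it ignores the legend and padding:

```python
height = 8      # height of each facet in inches
aspect = 0.6    # width / height of each facet
n_cols = 4      # one facet per day in tips: Thur, Fri, Sat, Sun

facet_width = height * aspect
approx_total_width = n_cols * facet_width   # ignores legend and padding
print(facet_width, approx_total_width)
```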

Style and Color


Styles

You can set particular styles:

> sns.set_style(style='darkgrid')
+ plt.figure(figsize=(8,6))
+ sns.countplot(x='sex',data=tips);
+ plt.show()

> sns.set_style('white')
+ sns.countplot(x='sex',data=tips);
+ plt.show()

> sns.set_style('ticks')
+ sns.countplot(x='sex',data=tips,
+               palette='deep');
+ plt.show()

Spine Removal

> sns.countplot(x='sex',data=tips);
+ sns.despine()
+ plt.show()

> sns.countplot(x='sex',data=tips);
+ sns.despine(left=True, bottom=True)
+ plt.show()

Size and Aspect

You can use matplotlib’s plt.figure(figsize=(width, height)) to change the size of most seaborn plots.

You can control the size and aspect ratio of most seaborn grid plots by passing in the parameters height and aspect. For example:

> # Non Grid Plot
+ plt.figure(figsize=(12,3))
+ sns.countplot(x='sex',data=tips);
+ plt.show()

> # Grid Type Plot
+ sns.lmplot(x='total_bill',y='tip',height=2,
+            aspect=4,data=tips);
+ plt.show()

Scale and Context

set_context() lets you override the default scaling of plot elements such as fonts and line widths:

> plt.figure(figsize=(12,12))
+ sns.set_context('poster',font_scale=4)
+ sns.countplot(x='sex',data=tips,
+       palette='coolwarm');
+ plt.tight_layout()
+ plt.show()