Seaborn comes with built-in data sets
> import matplotlib.pyplot as plt
+ import seaborn as sns
+ tips = sns.load_dataset('tips')
+ tips.head()
total_bill tip sex smoker day time size
0 16.99 1.01 Female No Sun Dinner 2
1 10.34 1.66 Male No Sun Dinner 3
2 21.01 3.50 Male No Sun Dinner 3
3 23.68 3.31 Male No Sun Dinner 2
4 24.59 3.61 Female No Sun Dinner 4
The distplot shows the distribution of a univariate set of observations.
jointplot() pairs two distplots (one per variable, drawn on the margins) with a bivariate plot of your choice in the center:
pairplot() plots pairwise relationships across an entire dataframe (for the numerical columns) and supports a color hue argument (for categorical columns).
rugplots are a very simple concept: they draw a dash mark for every point of a univariate distribution. They are the building block of a KDE plot.
kdeplots are Kernel Density Estimation plots. A KDE plot replaces every single observation with a Gaussian (Normal) distribution centered on that value, then sums those curves into one smooth estimate.
http://en.wikipedia.org/wiki/Kernel_density_estimation
For example:
> import numpy as np
+ from scipy import stats
+
+ #Create dataset
+ dataset = np.random.randn(25)
+
+ # Create another rugplot
+ sns.rugplot(dataset);
+
+ # Set up the x-axis for the plot
+ x_min = dataset.min() - 2
+ x_max = dataset.max() + 2
+
+ # 100 equally spaced points from x_min to x_max
+ x_axis = np.linspace(x_min,x_max,100)
+
+ # Set up the bandwidth
+
+ bandwidth = ((4*dataset.std()**5)/
+ (3*len(dataset)))**.2
+
+ # Create an empty kernel list
+ kernel_list = []
+
+ # Plot each basis function
+ for data_point in dataset:
+
+ # Create a kernel for each point and append to list
+ kernel = stats.norm(data_point,
+ bandwidth).pdf(x_axis)
+ kernel_list.append(kernel)
+
+ #Scale for plotting
+ kernel = kernel / kernel.max()
+ kernel = kernel * .4
+ plt.plot(x_axis,kernel,
+ color = 'grey',alpha=0.5)
+
+ plt.ylim(0,1);
+ plt.show()
> # To get the kde plot we can sum these basis functions.
+
+ # Plot the sum of the basis function
+ sum_of_kde = np.sum(kernel_list,axis=0)
+
+ # Plot figure
+ fig = plt.plot(x_axis,sum_of_kde,color='indianred')
+
+ # Add the initial rugplot
+ sns.rugplot(dataset,c = 'indianred')
+
+ # Get rid of y-tick marks
+ plt.yticks([]);
+
+ # Set title
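The summed curve above is not yet a probability density: each kernel integrates to 1, so the sum integrates to the number of observations. A small NumPy-only sketch (the function name gaussian_kde_curve is made up for this illustration) averages the kernels instead and checks that the area under the curve is close to 1:

```python
import numpy as np

def gaussian_kde_curve(samples, grid):
    """Average one Gaussian kernel per sample, evaluated on grid."""
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    # Same rule-of-thumb bandwidth as in the walkthrough above
    bandwidth = ((4 * samples.std() ** 5) / (3 * n)) ** 0.2
    # Rows are samples, columns are grid points
    z = (grid[None, :] - samples[:, None]) / bandwidth
    kernels = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2 * np.pi))
    # Mean instead of sum, so the result integrates to about 1
    return kernels.mean(axis=0), bandwidth

rng = np.random.default_rng(42)
dataset = rng.standard_normal(25)
x_axis = np.linspace(dataset.min() - 4, dataset.max() + 4, 2000)

density, bandwidth = gaussian_kde_curve(dataset, x_axis)

# Simple Riemann sum as a numerical integral over the grid
area = density.sum() * (x_axis[1] - x_axis[0])
```

Plotting density against x_axis gives the same shape as the summed curve, only rescaled by 1/len(dataset).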
With the tips dataset:
There are a few main plot types for categorical data:
barplot() is a general plot that lets you aggregate a numeric variable within each category using some function. The default is the mean:
> sns.set_style(style="white")
+ plt.figure(figsize=(8, 6))
+ sns.barplot(x='sex',y='total_bill',data=tips);
+ plt.show()
You can change the estimator to your own function, as long as it converts a vector to a scalar:
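A minimal sketch of passing a different estimator, here np.median, on a tiny hand-made frame (the values are invented so that the expected medians are easy to check by eye):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame({
    "sex": ["Female", "Female", "Female", "Male", "Male", "Male", "Male"],
    "total_bill": [10.0, 20.0, 30.0, 5.0, 15.0, 40.0, 50.0],
})

# Any function that reduces a vector to a single number can be used here
ax = sns.barplot(x="sex", y="total_bill", data=df, estimator=np.median)

# Bar heights now show the per-group median instead of the mean
heights = [p.get_height() for p in ax.patches]
plt.close("all")
```

For this frame the bars should sit at 20.0 (Female) and 27.5 (Male), the group medians.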
countplot() is essentially the same as barplot(), except that the estimator explicitly counts the number of occurrences, which is why we only pass the x value:
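A short sketch on a hand-made column (invented values) shows that the bar heights are simply the category counts:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame({"sex": ["Female", "Male", "Male", "Female", "Male"]})

# Only x is needed: the height of each bar is the number of rows in that category
ax = sns.countplot(x="sex", data=df)

counts = [int(round(p.get_height())) for p in ax.patches]
plt.close("all")
```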
A box plot (or box-and-whisker plot) shows the distribution of quantitative data in a way that facilitates comparisons between variables or across levels of a categorical variable. The box shows the quartiles of the dataset while the whiskers extend to show the rest of the distribution, except for points that are determined to be “outliers” using a method that is a function of the inter-quartile range.
> plt.figure(figsize=(8, 6))
+ sns.boxplot(x="day", y="total_bill",
+ data=tips,palette='rainbow');
+ plt.show()
> # Can do entire dataframe with orient='h'
+ sns.boxplot(data=tips
+ ,palette='rainbow',orient='h');
+ plt.show()
A violin plot plays a similar role to a box-and-whisker plot. It shows the distribution of quantitative data across several levels of one (or more) categorical variables so that those distributions can be compared. Unlike a box plot, in which all of the plot components correspond to actual datapoints, the violin plot features a kernel density estimation of the underlying distribution.
> plt.figure(figsize=(8, 6))
+ sns.violinplot(x="day", y="total_bill",
+ data=tips,palette='rainbow');
+ plt.show()
> sns.violinplot(x="day", y="total_bill",
+ data=tips,hue='sex',split=True,palette='Set1')
+ plt.show()
The stripplot will draw a scatterplot where one variable is categorical. A strip plot can be drawn on its own, but it is also a good complement to a box or violin plot in cases where you want to show all observations along with some representation of the underlying distribution.
> plt.figure(figsize=(8, 6))
+ sns.stripplot(x="day", y="total_bill",
+ data=tips, jitter=False);
+ plt.show()
> sns.stripplot(x="day", y="total_bill",
+ data=tips,jitter=True,hue='sex',
+ palette='Set1');
+ plt.show()
> sns.stripplot(x="day", y="total_bill",
+ data=tips,jitter=True,hue='sex',
+ palette='Set1',dodge=True);
+ plt.show()
The swarmplot is similar to stripplot(), but the points are adjusted (only along the categorical axis) so that they don’t overlap. This gives a better representation of the distribution of values, although it does not scale as well to large numbers of observations (both in terms of the ability to show all the points and in terms of the computation needed to arrange them).
> sns.swarmplot(x="day", y="total_bill",
+ hue='sex',data=tips, palette="Set1",
+ dodge=True);
+ plt.show()
Matrix plots allow you to plot data as color-encoded matrices and can also be used to indicate clusters within the data.
The examples below use the flights and tips datasets.
total_bill tip sex smoker day time size
0 16.99 1.01 Female No Sun Dinner 2
1 10.34 1.66 Male No Sun Dinner 3
2 21.01 3.50 Male No Sun Dinner 3
3 23.68 3.31 Male No Sun Dinner 2
4 24.59 3.61 Female No Sun Dinner 4
year month passengers
0 1949 January 112
1 1949 February 118
2 1949 March 132
3 1949 April 129
4 1949 May 121
For a heatmap to work properly, your data should already be in matrix form. The sns.heatmap() function then colors it in for you. For example:
total_bill tip size
total_bill 1.000000 0.675734 0.598315
tip 0.675734 1.000000 0.489299
size 0.598315 0.489299 1.000000
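A correlation matrix like the one above can be built and drawn roughly as follows. This is a sketch on synthetic data shaped like the tips dataset; select_dtypes is used here because recent pandas versions no longer silently skip non-numeric columns in corr():

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(4)
n = 100
total_bill = rng.gamma(4.0, 5.0, size=n)
df = pd.DataFrame({
    "total_bill": total_bill,
    "tip": 0.15 * total_bill + rng.normal(0.0, 1.0, size=n),
    "size": rng.integers(1, 6, size=n),
    "sex": rng.choice(["Female", "Male"], size=n),
})

# Keep only numeric columns, then compute the pairwise correlation matrix
corr = df.select_dtypes("number").corr()

# annot=True writes the value of each cell on top of its color
ax = sns.heatmap(corr, cmap="coolwarm", annot=True)
plt.close("all")
```

The matrix form matters: rows and columns are both labeled by variables, so every cell has a meaning.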
Or for the flights data:
year 1949 1950 1951 1952 1953 ... 1956 1957 1958 1959 1960
month ...
January 112 115 145 171 196 ... 284 315 340 360 417
February 118 126 150 180 196 ... 277 301 318 342 391
March 132 141 178 193 236 ... 317 356 362 406 419
April 129 135 163 181 235 ... 313 348 348 396 461
May 121 125 172 183 229 ... 318 355 363 420 472
June 135 149 178 218 243 ... 374 422 435 472 535
July 148 170 199 230 264 ... 413 465 491 548 622
August 148 170 199 242 272 ... 405 467 505 559 606
September 136 158 184 209 237 ... 355 404 404 463 508
October 119 133 162 191 211 ... 306 347 359 407 461
November 104 114 146 172 180 ... 271 305 310 362 390
December 118 140 166 194 201 ... 306 336 337 405 432
[12 rows x 12 columns]
> pvflights = flights.pivot_table(values='passengers',
+ index='month',columns='year')
+ sns.heatmap(pvflights, cmap='Spectral');
+ plt.show()
The clustermap uses hierarchical clustering to produce a clustered version of the heatmap. For example:
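A hedged sketch of the call, using a small synthetic month-by-year matrix as a stand-in for the pivoted flights table (the numbers are invented; clustermap also needs SciPy installed):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic passenger-like counts: a seasonal factor times a yearly trend
months = ["January", "February", "July", "August", "November", "December"]
years = [1949, 1950, 1951, 1952]
seasonal = np.array([0.90, 0.95, 1.30, 1.30, 0.80, 0.90])
trend = np.array([110.0, 130.0, 160.0, 185.0])
matrix = pd.DataFrame(np.outer(seasonal, trend), index=months, columns=years)

# Rows and columns are reordered so that similar ones end up next to each other
g = sns.clustermap(matrix, cmap="coolwarm")

row_order = list(g.dendrogram_row.reordered_ind)
col_order = list(g.dendrogram_col.reordered_ind)
plt.close("all")
```

With the pivoted flights table from the heatmap example, the call would presumably be sns.clustermap(pvflights).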
Notice now how the years and months are no longer in order, instead they are grouped by similarity in value (passenger count). That means we can begin to infer things from this plot, such as August and July being similar (makes sense, since they are both summer travel months).
> # More options to get the information
+ # a little clearer like normalization
+ sns.clustermap(pvflights,cmap='coolwarm',
+ standard_scale=1);
+ plt.show()
Grids are general types of plots that allow you to map plot types to rows and columns of a grid. This helps you create similar plots separated by features.
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
PairGrid is a subplot grid for plotting pairwise relationships in a dataset.
> # Map to upper,lower, and diagonal
+ g = sns.PairGrid(iris);
+ g.map_diag(sns.distplot);
+ g.map_upper(plt.scatter);
+ g.map_lower(sns.kdeplot);
+ plt.show()
pairplot is a simpler version of PairGrid.
FacetGrid is the general way to create grids of plots based on a feature:
> g = sns.FacetGrid(tips, col="time", row="smoker");
+ g = g.map(plt.hist, "total_bill");
+ plt.show()
> g = sns.FacetGrid(tips, col="time",
+ row="smoker",hue='sex');
+ # Notice how the arguments come
+ # after plt.scatter call
+ g = g.map(plt.scatter, "total_bill",
+ "tip").add_legend();
+ plt.show()
lmplot() lets you display linear models, split those plots up by features, and color them with a hue based on a feature.
lmplot() kwargs get passed through to regplot(), the axes-level function that lmplot() builds on. regplot() has a scatter_kws parameter that gets passed to plt.scatter, so you set the s parameter in that dictionary, which corresponds (a bit confusingly) to the squared marker size, that is, the marker area in points squared. In other words, you end up passing a dictionary of base matplotlib arguments, in this case s for the size of the scatter-plot markers.
http://matplotlib.org/api/markers_api.html
> sns.lmplot(x='total_bill',y='tip',
+ data=tips,hue='sex',palette='winter',
+ markers=['o','v'],scatter_kws={'s':100});
+ plt.show()
You can set particular styles:
> sns.set_style(style='darkgrid')
+ plt.figure(figsize=(8,6))
+ sns.countplot(x='sex',data=tips);
+ plt.show()
You can use matplotlib’s plt.figure(figsize=(width, height)) to change the size of most seaborn plots.
You can control the size and aspect ratio of most seaborn grid plots by passing in the parameters height and aspect (height was called size before seaborn 0.9). For example:
> # Grid Type Plot
+ sns.lmplot(x='total_bill',y='tip',height=2,
+ aspect=4,data=tips);
+ plt.show()