PCA for dimensionality reduction and for better understanding a dataset

# https://towardsdatascience.com/pca-using-python-scikit-learn-e653f8989e60
# We can also use PCA to speed up the fitting of ML algorithms
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# %matplotlib inline
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
# The raw file has no header row (the first line is already data, see below),
# so name the columns explicitly: four measurements followed by the class label.
iris_columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'target']
df = pd.read_csv(url, names=iris_columns)

df.head()
##    sepal length  sepal width  petal length  petal width       target
## 0           5.1          3.5           1.4          0.2  Iris-setosa
## 1           4.9          3.0           1.4          0.2  Iris-setosa
## 2           4.7          3.2           1.3          0.2  Iris-setosa
## 3           4.6          3.1           1.5          0.2  Iris-setosa
## 4           5.0          3.6           1.4          0.2  Iris-setosa

Standardise data

features = ['sepal length', 'sepal width', 'petal length', 'petal width']
# Species labels as an (n, 1) array (selected with a one-element column list).
y = df.loc[:, ['target']].values
# Rescale each measurement column to zero mean and unit variance before PCA.
x = StandardScaler().fit_transform(df.loc[:, features].values)

pd.DataFrame(x, columns=features).head()
##    sepal length  sepal width  petal length  petal width
## 0     -0.900681     1.032057     -1.341272    -1.312977
## 1     -1.143017    -0.124958     -1.341272    -1.312977
## 2     -1.385353     0.337848     -1.398138    -1.312977
## 3     -1.506521     0.106445     -1.284407    -1.312977
## 4     -1.021849     1.263460     -1.341272    -1.312977
# Project the standardised features onto their first two principal components.
pca = PCA(n_components=2)
PCs = pca.fit_transform(x)

PC_df = pd.DataFrame(data=PCs, columns=["PC1", "PC2"])
PC_df.head()
##         PC1       PC2
## 0 -2.264542  0.505704
## 1 -2.086426 -0.655405
## 2 -2.367950 -0.318477
## 3 -2.304197 -0.575368
## 4 -2.388777  0.674767
df['target'].head()
## 0    Iris-setosa
## 1    Iris-setosa
## 2    Iris-setosa
## 3    Iris-setosa
## 4    Iris-setosa
## Name: target, dtype: object

# Attach each row's species label to its projected coordinates; both frames
# carry the same default RangeIndex, so rows line up one-to-one.
final_df = PC_df.assign(target=df['target'])
final_df
##           PC1       PC2          target
## 0   -2.264542  0.505704     Iris-setosa
## 1   -2.086426 -0.655405     Iris-setosa
## 2   -2.367950 -0.318477     Iris-setosa
## 3   -2.304197 -0.575368     Iris-setosa
## 4   -2.388777  0.674767     Iris-setosa
## ..        ...       ...             ...
## 145  1.870522  0.382822  Iris-virginica
## 146  1.558492 -0.905314  Iris-virginica
## 147  1.520845  0.266795  Iris-virginica
## 148  1.376391  1.016362  Iris-virginica
## 149  0.959299 -0.022284  Iris-virginica
## 
## [150 rows x 3 columns]
# Scatter the iris samples in the plane of the first two principal components,
# one colour per species.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(1, 1, 1)
ax.set_xlabel('Principal Component 1', fontsize=15)
ax.set_ylabel('Principal Component 2', fontsize=15)
ax.set_title('2 Component PCA', fontsize=20)


targets = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
colors = ['r', 'g', 'b']
for target, color in zip(targets, colors):
    indicesToKeep = final_df['target'] == target
    # Give each scatter its own label so the legend entry is bound to the artist,
    # instead of being matched to a list of names purely by call order.
    ax.scatter(final_df.loc[indicesToKeep, 'PC1'],
               final_df.loc[indicesToKeep, 'PC2'],
               c=color,
               s=50,
               label=target)
ax.legend()
ax.grid()
plt.show()

How much variance is explained?

# Fraction of the total variance captured by each of the two kept components.
explained = pca.explained_variance_ratio_
explained
## array([0.72770452, 0.23030523])

PCA for speeding up ML algorithms

from sklearn.datasets import fetch_openml
# Pin the dataset version (the output below shows version '1') so the download is
# reproducible, and ask for NumPy arrays: newer scikit-learn defaults to
# as_frame='auto', which returns pandas objects rather than the arrays shown below.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist
## {'data': array([[0., 0., 0., ..., 0., 0., 0.],
##        [0., 0., 0., ..., 0., 0., 0.],
##        [0., 0., 0., ..., 0., 0., 0.],
##        ...,
##        [0., 0., 0., ..., 0., 0., 0.],
##        [0., 0., 0., ..., 0., 0., 0.],
##        [0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 
'pixel164', 'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 
'pixel330', 'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 
'pixel496', 'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 
'pixel662', 'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'format': 'ARFF', 'upload_date': '2014-09-29T03:28:38', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2018-10-03 21:23:30', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
# Images: one row per sample, one column per pixel feature (pixel1..pixel784)
mnist.data.shape
## (70000, 784)

# Labels: one digit string per sample
mnist.target.shape
## (70000,)
from sklearn.model_selection import train_test_split
# Hold out 1/7 of the 70,000 samples (10,000) as the test set; the fixed seed
# makes the split reproducible.
train_img, test_img, train_lbl, test_lbl = train_test_split(
    mnist.data,
    mnist.target,
    test_size=1/7.0,
    random_state=0,
)
for split in (train_img, train_lbl, test_img, test_lbl):
    print(split.shape)
## (60000, 784)
## (60000,)
## (10000, 784)
## (10000,)

Standardise Data


from sklearn.preprocessing import StandardScaler
# Learn per-pixel mean/std from the training split only, then apply the same
# scaling to both splits so no test statistics leak into training.
scaler = StandardScaler().fit(train_img)
train_img, test_img = scaler.transform(train_img), scaler.transform(test_img)

Use PCA to reduce dimensionality

from sklearn.decomposition import PCA
# A float n_components between 0 and 1 makes PCA keep enough components to
# explain that fraction of the variance (95% here).
pca = PCA(n_components=0.95)
# Fit on the training split only.
pca.fit(train_img)
## PCA(copy=True, iterated_power='auto', n_components=0.95, random_state=None,
##     svd_solver='auto', tol=0.0, whiten=False)
pca.n_components_  # components kept, out of the original 784 pixels
## 327

# Project both splits with the mapping learned from the training data.
train_img = pca.transform(train_img)
test_img = pca.transform(test_img)
from sklearn.linear_model import LogisticRegression

# The solver in use is lbfgs (see the repr below). max_iter is raised far above
# the default of 100, presumably so lbfgs can converge on this data — confirm.
logisticRegr = LogisticRegression(max_iter=12000)
logisticRegr.fit(train_img, train_lbl)
## LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
##                    intercept_scaling=1, l1_ratio=None, max_iter=12000,
##                    multi_class='auto', n_jobs=None, penalty='l2',
##                    random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
##                    warm_start=False)

Prediction

# A single sample must be passed as a 2-D array with one row, hence the reshape.
logisticRegr.predict(test_img[0].reshape(1, -1))
## array(['0'], dtype=object)
# The first ten test samples at once
logisticRegr.predict(test_img[:10])
## array(['0', '4', '1', '2', '4', '7', '7', '1', '1', '7'], dtype=object)

# Predictions for the whole test split, reused by the metrics below
lbl_predict = logisticRegr.predict(test_img)
lbl_predict
## array(['0', '4', '1', ..., '1', '3', '0'], dtype=object)

Measuring Model Performance

import sklearn.metrics as metrics
from sklearn.metrics import confusion_matrix

# Mean accuracy on the held-out split
logisticRegr.score(test_img, test_lbl)
## 0.9184

# Rows are the true digits, columns the predicted digits.
cnf_matrix = confusion_matrix(test_lbl, lbl_predict)
pd.DataFrame(cnf_matrix)
##      0     1    2    3    4    5    6    7    8    9
## 0  965     0    2    2    1   10    9    1    5    1
## 1    0  1105   13    1    1    6    0    4    9    2
## 2    3    15  929   19   13    4   14   12   27    4
## 3    1     7   39  889    1   29    1   13   20   13
## 4    1     3    8    0  901    0   11    7    4   27
## 5    7     2    9   29    7  759   15    3   27    5
## 6    8     2    9    0   13   14  935    1    5    2
## 7    4     4   16    2   11    5    0  977    6   39
## 8    3    19    8   20    7   25    7    2  858   14
## 9    4     4    3   11   30   10    1   34    6  866

The model performs well overall, but it misclassifies an actual 7 as 9 and an actual 3 as 2 in 39 instances each.

Visualise the confusion matrix

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Display labels for the ten digit classes (the matrix rows/columns are in
# sorted label order '0'..'9').
class_names = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
fig, ax = plt.subplots()
# Let seaborn label its own cell-centred ticks: it resets the tick positions,
# so any plt.xticks/plt.yticks call made before it is discarded.
sns.heatmap(pd.DataFrame(cnf_matrix), annot=True, cmap="YlGnBu", fmt='g',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.xaxis.set_label_position("top")
ax.set_title('Confusion matrix', y=1.1)
ax.set_ylabel('Actual label')
ax.set_xlabel('Predicted label')
# Lay out only after the title and labels exist so none of them is clipped.
plt.tight_layout()
plt.show()

Model evaluation metrics

Calculate Accuracy, Precision and Recall.

Accuracy: the number of correct predictions (the true positives summed over all classes) divided by the total number of cases. Precision (per class): true positives divided by the sum of true positives and false positives. Recall (per class): true positives divided by the sum of true positives and false negatives.

print("Accuracy:", metrics.accuracy_score(y_true=test_lbl, y_pred=lbl_predict))
## Accuracy: 0.9184

# precision_score/recall_score default to average='binary', which raises for this
# 10-class target; 'macro' reports the unweighted mean over the digit classes.
print("Precision:", metrics.precision_score(y_true=test_lbl, y_pred=lbl_predict, average='macro'))
print("Recall:", metrics.recall_score(y_true=test_lbl, y_pred=lbl_predict, average='macro'))
# Per-class precision, recall, F1 and support, plus accuracy and the
# macro/weighted averages over the ten digits.
report = metrics.classification_report(test_lbl, lbl_predict)
print(report)
##               precision    recall  f1-score   support
## 
##            0       0.97      0.97      0.97       996
##            1       0.95      0.97      0.96      1141
##            2       0.90      0.89      0.89      1040
##            3       0.91      0.88      0.90      1013
##            4       0.91      0.94      0.93       962
##            5       0.88      0.88      0.88       863
##            6       0.94      0.95      0.94       989
##            7       0.93      0.92      0.92      1064
##            8       0.89      0.89      0.89       963
##            9       0.89      0.89      0.89       969
## 
##     accuracy                           0.92     10000
##    macro avg       0.92      0.92      0.92     10000
## weighted avg       0.92      0.92      0.92     10000