{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Generating Deep Neural Network Model Explanations via ChemML's Explain Module" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The `chemml.explain` module provides three eXplainable AI (XAI) methods: DeepSHAP, LRP, and LIME. It supports both local (for a single instance) and global (aggregated over multiple instances) explanations. Each explanation takes the form of a relevance score attributed to each feature used to build the DNN model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use a sample dataset from the ChemML library that contains the SMILES codes and 200 Dragon molecular descriptors (features) for 500 small organic molecules, along with their densities in $kg/m^3$. We split the dataset into training and testing subsets and scale them with `StandardScaler`. We then build and train a PyTorch DNN on the training subset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", " (0): Linear(in_features=200, out_features=100, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=100, out_features=100, bias=True)\n", " (3): ReLU()\n", " (4): Linear(in_features=100, out_features=100, bias=True)\n", " (5): ReLU()\n", " (6): Linear(in_features=100, out_features=1, bias=True)\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "import shap\n", "from chemml.models import MLP\n", "from chemml.datasets import load_organic_density\n", "\n", "from sklearn.preprocessing import StandardScaler\n", "from chemml.explain import Explain\n", "\n", "_, y, X = load_organic_density()\n", "columns = list(X.columns)\n", "\n", "y = y.values.reshape(y.shape[0], 1).astype('float32')\n", "X = X.values.reshape(X.shape[0], X.shape[1]).astype('float32')\n", "\n", "# split 0.9 train / 0.1 test\n", "ytr = y[:450, :]\n", "yte = y[450:, :]\n", "Xtr = X[:450, :]\n", "Xte = X[450:, :]\n", "\n", "scale = StandardScaler()\n", "scale_y = StandardScaler()\n", "Xtr = scale.fit_transform(Xtr)\n", "Xte = scale.transform(Xte)\n", "ytr = scale_y.fit_transform(ytr)\n", "\n", "# PYTORCH\n", "r1_pytorch = MLP(engine='pytorch',nfeatures=Xtr.shape[1], nneurons=[100,100,100], activations=['ReLU','ReLU','ReLU'],\n", " learning_rate=0.001, alpha=0.0001, nepochs=100, batch_size=100, loss='mean_squared_error', \n", " is_regression=True, nclasses=None, layer_config_file=None, opt_config='Adam')\n", "\n", "r1_pytorch.fit(Xtr, ytr)\n", "engine_model = r1_pytorch.get_model()\n", "engine_model.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DeepSHAP Explanations\n", "\n", "We instantiate the `Explain` class from `chemml.explain` with the instance to be explained, the PyTorch DNN object, and the feature names (columns). We then call the DeepSHAP method with a set of background (reference) samples, as required by the [SHAP library](https://shap-lrjball.readthedocs.io/en/latest/generated/shap.DeepExplainer.html)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | MW | AMW | Sv | Se | Sp | Si | Mv | Me | Mp | Mi | ... | X4Av | X5Av | X0sol | X1sol | X2sol | X3sol | X4sol | X5sol | XMOD | RDCHI |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.031939 | -0.05199 | -0.006925 | -0.006857 | 0.027388 | -0.020872 | -0.057089 | -0.115971 | -0.00581 | -0.000805 | ... | -0.087464 | -0.240233 | 0.002718 | 0.003733 | 0.008075 | 0.011319 | 0.007988 | -0.000032 | 0.01838 | 0.002904 |

1 rows × 200 columns

| | MW | AMW | Sv | Se | Sp | Si | Mv | Me | Mp | Mi | ... | X4Av | X5Av | X0sol | X1sol | X2sol | X3sol | X4sol | X5sol | XMOD | RDCHI |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.031939 | -0.051990 | -0.006925 | -0.006857 | 0.027388 | -0.020872 | -0.057089 | -0.115971 | -0.005810 | -0.000805 | ... | -0.087464 | -0.240233 | 0.002718 | 0.003733 | 0.008075 | 0.011319 | 0.007988 | -0.000032 | 0.018380 | 0.002904 |
| 1 | -0.030329 | 0.045286 | 0.003471 | 0.009392 | -0.033369 | 0.005686 | 0.015329 | 0.132120 | -0.011706 | 0.006712 | ... | 0.013431 | 0.027834 | 0.004760 | -0.020271 | -0.006235 | -0.019104 | -0.005018 | -0.010347 | -0.040244 | -0.002089 |
| 2 | 0.073318 | -0.011014 | -0.010239 | -0.012253 | 0.034129 | -0.015133 | -0.017885 | 0.005359 | -0.016987 | 0.026442 | ... | 0.006582 | 0.013818 | -0.008897 | 0.031285 | 0.015545 | 0.022423 | 0.014267 | -0.002429 | 0.078428 | -0.039955 |
| 3 | -0.051775 | 0.060634 | 0.028966 | 0.011126 | -0.046096 | 0.022779 | -0.021990 | 0.079879 | -0.005687 | -0.001833 | ... | -0.055763 | -0.149791 | 0.011571 | -0.056075 | -0.012510 | -0.014471 | -0.010389 | -0.003207 | -0.077971 | 0.076956 |
| 4 | 0.078189 | 0.049846 | -0.006229 | -0.007896 | 0.024672 | -0.011916 | -0.001073 | 0.035578 | -0.000805 | 0.019776 | ... | 0.003993 | 0.007286 | -0.010335 | 0.039825 | 0.012199 | 0.021159 | 0.010633 | 0.004272 | 0.090355 | -0.020396 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 445 | 0.008847 | 0.394204 | 0.027389 | 0.021292 | -0.046195 | 0.024054 | 0.122414 | 0.310809 | 0.111767 | 0.002920 | ... | 0.001368 | 0.002888 | 0.011813 | -0.006262 | -0.004366 | 0.002759 | 0.004123 | -0.014602 | 0.027206 | 0.010870 |
| 446 | 0.064879 | 0.029331 | -0.023800 | -0.014171 | 0.023401 | -0.014327 | 0.026467 | -0.096276 | 0.037964 | -0.028652 | ... | 0.000230 | 0.005973 | 0.003281 | 0.038597 | 0.030991 | 0.038983 | 0.024876 | 0.010109 | 0.078109 | -0.032479 |
| 447 | 0.086556 | -0.004422 | -0.007464 | -0.019316 | 0.037914 | -0.011631 | -0.008085 | 0.005726 | -0.011518 | 0.041259 | ... | 0.008781 | 0.014713 | -0.020406 | 0.043650 | 0.018309 | 0.026390 | 0.009535 | 0.016264 | 0.104574 | -0.033671 |
| 448 | 0.090357 | 0.031684 | -0.036333 | -0.023286 | 0.013101 | -0.025404 | 0.024115 | -0.069291 | 0.041967 | -0.028402 | ... | 0.004694 | 0.006852 | 0.001424 | 0.085513 | 0.071626 | 0.091616 | 0.028419 | 0.041313 | 0.184755 | -0.036952 |
| 449 | 0.110965 | 0.032688 | -0.040615 | -0.028315 | 0.032846 | -0.037154 | 0.013515 | -0.028467 | 0.022227 | -0.025841 | ... | -0.010460 | -0.014772 | -0.005115 | 0.075755 | 0.068138 | 0.088936 | 0.046433 | 0.050386 | 0.161519 | -0.073746 |

450 rows × 200 columns

| | MW | AMW | Sv | Se | Sp | Si | Mv | Me | Mp | Mi | ... | X4Av | X5Av | X0sol | X1sol | X2sol | X3sol | X4sol | X5sol | XMOD | RDCHI |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0357 | 0.060859 | -0.030067 | -0.000306 | 0.011614 | -0.00287 | 0.038521 | 0.075271 | 0.012589 | -0.001871 | ... | 0.087779 | 0.240536 | -0.009204 | 0.0284 | 0.03282 | 0.028609 | 0.015671 | -0.003383 | 0.059679 | -0.03687 |

1 rows × 200 columns

| | MW | AMW | Sv | Se | Sp | Si | Mv | Me | Mp | Mi | ... | X4Av | X5Av | X0sol | X1sol | X2sol | X3sol | X4sol | X5sol | XMOD | RDCHI |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.002375 | -0.002310 | 0.003538 | 0.000393 | 0.002531 | 0.000539 | -0.003484 | 0.005217 | -0.006279 | 0.010001 | ... | -0.000975 | 0.000208 | -0.000257 | -0.006816 | -0.006439 | -0.011138 | 0.000251 | -0.004060 | -0.012186 | 0.000308 |
| 1 | 0.031408 | 0.003845 | -0.003586 | -0.000377 | 0.027020 | -0.001780 | -0.002846 | -0.015842 | -0.002071 | -0.001520 | ... | -0.003107 | -0.015789 | -0.003154 | 0.009662 | 0.004628 | 0.018693 | 0.008019 | 0.004508 | 0.030692 | -0.002962 |
| 2 | -0.003544 | 0.010888 | 0.010702 | 0.010161 | 0.009560 | 0.011691 | -0.010409 | 0.024565 | -0.007086 | 0.010840 | ... | -0.008347 | -0.008716 | -0.009222 | -0.013408 | -0.038706 | -0.029874 | -0.009861 | -0.006868 | -0.024265 | 0.010711 |
| 3 | 0.336607 | 0.222507 | -0.021707 | 0.050956 | 0.121637 | -0.015981 | 0.095192 | 0.031510 | 0.073009 | -0.109299 | ... | -0.013377 | -0.021679 | -0.057243 | 0.103574 | -0.009558 | 0.117196 | 0.026707 | 0.029786 | 0.304746 | -0.031534 |
| 4 | -0.008907 | -0.007431 | 0.001667 | 0.000218 | -0.007712 | 0.000166 | -0.007769 | 0.010181 | -0.007833 | 0.010821 | ... | 0.004072 | 0.007154 | 0.001378 | -0.005021 | -0.004586 | -0.005019 | -0.004799 | -0.008497 | -0.007774 | -0.002996 |

5 rows × 200 columns

| | Mean Absolute Relevance Score | Mean Relevance Score |
|---|---|---|
| nHet | 0.083375 | 0.080595 |
| ON1V | 0.082356 | 0.054230 |
| nS | 0.082029 | 0.073870 |
| nHM | 0.081260 | 0.072965 |
| nCsp3 | 0.081243 | 0.026852 |
| ... | ... | ... |
| SRW10 | 0.004217 | 0.000566 |
| MWC01 | 0.004144 | 0.001311 |
| Psi_i_0d | 0.003676 | -0.003646 |
| LPRS | 0.002950 | 0.000132 |
| Psi_i_1d | 0.001783 | 0.001676 |

200 rows × 2 columns
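
The two-column summary above ranks features by an aggregate of their per-instance relevance scores. As a minimal sketch of how such a table could be derived from a frame of local scores, the snippet below takes column-wise means of the raw and absolute values and sorts by the latter. This is an assumption about the aggregation, not ChemML's confirmed implementation; the helper name `summarize_global_relevance` and the toy feature names are hypothetical.

```python
import pandas as pd


def summarize_global_relevance(local_scores: pd.DataFrame) -> pd.DataFrame:
    """Aggregate local relevance scores into a global summary.

    Assumes rows are explained instances and columns are features.
    The column-wise mean aggregation is an assumption, not taken from ChemML.
    """
    summary = pd.DataFrame(
        {
            # Average magnitude of a feature's relevance, ignoring sign.
            "Mean Absolute Relevance Score": local_scores.abs().mean(axis=0),
            # Signed average; opposing contributions can cancel out.
            "Mean Relevance Score": local_scores.mean(axis=0),
        }
    )
    # Most influential features (by magnitude) first.
    return summary.sort_values("Mean Absolute Relevance Score", ascending=False)


# Toy local scores for three instances and three hypothetical features.
toy_scores = pd.DataFrame(
    {
        "feat_a": [0.2, -0.2, 0.4],
        "feat_b": [0.1, 0.1, 0.1],
        "feat_c": [-0.5, 0.5, 0.0],
    }
)

summary = summarize_global_relevance(toy_scores)
print(summary)
```

In the toy data, `feat_c` ranks first by mean absolute relevance even though its signed mean is zero, which is why the two columns are worth reading together: a large gap between them (as for `nCsp3` in the table above) suggests the feature pushes predictions in different directions for different molecules.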