{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random Graph Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### This tutorial will introduce the following random graph models: \n",
    "- Erdős-Rényi (ER)\n",
    "- Degree-corrected Erdős-Rényi (DCER)\n",
    "- Stochastic block model (SBM)\n",
    "- Degree-corrected stochastic block model (DCSBM)\n",
    "- Random dot product graph (RDPG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load some data from _graspologic_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this example we will use the _Drosophila melanogaster_ larva right mushroom body connectome from [Eichler et al. 2017](https://www.ncbi.nlm.nih.gov/pubmed/28796202). Here we will consider a binarized and directed version of the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from graspologic.datasets import load_drosophila_right\n",
    "from graspologic.plot import heatmap\n",
    "from graspologic.utils import binarize\n",
    "%matplotlib inline\n",
    "\n",
    "# load the adjacency matrix and node labels, then binarize the edge weights\n",
    "adj, labels = load_drosophila_right(return_labels=True)\n",
    "adj = binarize(adj)\n",
    "_ = heatmap(adj,\n",
    "        inner_hier_labels=labels,\n",
    "        title='Drosophila right MB',\n",
    "        font_scale=1.5,\n",
    "        sort_nodes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preliminaries\n",
    "\n",
    "$n$ - the number of nodes in the graph\n",
    "\n",
    "$A$ - $n \\times n$ adjacency matrix\n",
    "\n",
    "$P$ - $n \\times n$ matrix of probabilities\n",
    "\n",
    "\n",
    "For the class of models we will consider here, a graph (adjacency matrix) $A$ is sampled as follows:\n",
    "\n",
    "$$A \\sim \\mathrm{Bernoulli}(P)$$\n",
    "\n",
    "While each model we will discuss follows this formulation, they differ in how the matrix $P$ is constructed. So, for each model, we will consider how to model $P_{ij}$, the probability of an edge from node $i$ to node $j$, with $i \\neq j$ (i.e. no \"loops\" are allowed for the sake of this tutorial).\n",
    "\n",
    "\n",
    "------\n",
    "\n",
    "\n",
    "### For each graph model we will show: \n",
    "- how the model is formulated \n",
    "- how to fit the model using graspologic\n",
    "- the P matrix that was fit by the model \n",
    "- a single sample from the fit model"
   ]
  },
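  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written entrywise, and assuming (as is standard for this class of models, though not spelled out above) that edges are sampled independently given $P$, the sampling step reads:\n",
    "\n",
    "$$A_{ij} \\sim \\mathrm{Bernoulli}(P_{ij}) \\text{ independently for } i \\neq j, \\qquad A_{ii} = 0$$\n",
    "\n",
    "so the expected number of edges in a directed graph without loops is\n",
    "\n",
    "$$\\mathbb{E}\\left[\\sum_{i \\neq j} A_{ij}\\right] = \\sum_{i \\neq j} P_{ij}$$"
   ]
  },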
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Erdős-Rényi (ER)\n",
    "The Erdős-Rényi (ER) model is the simplest random graph model. We are interested in modeling the probability of an edge existing between any two nodes, $i$ and $j$. We denote this probability $P_{ij}$. For the ER model:\n",
    "\n",
    "$$P_{ij} = p$$\n",
    "\n",
    "for any combination of $i$ and $j$.  \n",
    "This means that the one parameter $p$ is the overall probability of connection for any two nodes. "
   ]
  },
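  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of how $p$ could be estimated (this is the standard maximum-likelihood argument, not a description of graspologic's internals): a directed graph without loops has $n(n-1)$ possible edges, and the maximum-likelihood estimate of $p$ is the fraction of them that are present:\n",
    "\n",
    "$$\\hat{p} = \\frac{1}{n(n-1)} \\sum_{i \\neq j} A_{ij}$$"
   ]
  },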
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from graspologic.models import EREstimator\n",
    "# fit an ER model to the directed graph, excluding self-loops\n",
    "er = EREstimator(directed=True, loops=False)\n",
    "er.fit(adj)\n",
    "print(f\"ER \\\"p\\\" parameter: {er.p_}\")\n",
    "heatmap(er.p_mat_,\n",
    "        inner_hier_labels=labels,\n",
    "        font_scale=1.5,\n",
    "        title=\"ER probability matrix\",\n",
    "        vmin=0, vmax=1, \n",
    "        sort_nodes=True)\n",
    "_ = heatmap(er.sample()[0],\n",
    "        inner_hier_labels=labels,\n",
    "        font_scale=1.5,\n",
    "        title=\"ER sample\",\n",
    "        sort_nodes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Degree-corrected Erdős-Rényi (DCER)\n",
    "\n",
    "A slightly more complicated variant of the ER model is the degree-corrected Erdős-Rényi (DCER) model. Here, there is still a global parameter $p$ that sets the overall scale of connection probability across all pairs of nodes. However, we add a promiscuity parameter $\\theta_i$ for each node $i$, which specifies its expected degree relative to other nodes:\n",
    "\n",
    "$$P_{ij} = \\theta_i \\theta_j p$$\n",
    "\n",
    "so the probability of an edge from $i$ to $j$ is a function of the two nodes' degree-correction parameters, and the overall probability of an edge in the graph. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from graspologic.models import DCEREstimator\n",
    "# fit a DCER model to the directed graph, excluding self-loops\n",
    "dcer = DCEREstimator(directed=True, loops=False)\n",
    "dcer.fit(adj)\n",
    "print(f\"DCER \\\"p\\\" parameter: {dcer.p_}\")\n",
    "heatmap(dcer.p_mat_,\n",
    "        inner_hier_labels=labels,\n",
    "        vmin=0,\n",
    "        vmax=1,\n",
    "        font_scale=1.5,\n",
    "        title=\"DCER probability matrix\", \n",
    "        sort_nodes=True)\n",
    "_ = heatmap(dcer.sample()[0],\n",
    "        inner_hier_labels=labels,\n",
    "        font_scale=1.5,\n",
    "        title=\"DCER sample\", \n",
    "        sort_nodes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stochastic block model (SBM)\n",
    "Under the stochastic block model (SBM), each node is modeled as belonging to a block (sometimes called a community or group). The probability of node $i$ connecting to node $j$ is simply a function of the block membership of the two nodes. Let $n$ be the number of nodes in the graph; then $\\tau$ is a length-$n$ vector which indicates the block membership of each node. Let $K$ be the number of blocks; then $B$ is a $K \\times K$ matrix of block-block connection probabilities. \n",
    "\n",
    "$$P_{ij} = B_{\\tau_i \\tau_j}$$"
   ]
  },
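  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustrative example (the numbers here are made up), take $n = 3$ nodes, $K = 2$ blocks, $\\tau = (1, 1, 2)$, and\n",
    "\n",
    "$$B = \\begin{pmatrix} 0.8 & 0.1 \\\\ 0.2 & 0.5 \\end{pmatrix}$$\n",
    "\n",
    "Then $P_{12} = B_{11} = 0.8$, $P_{13} = B_{12} = 0.1$, and $P_{31} = B_{21} = 0.2$. Because $B$ need not be symmetric for a directed graph, $P_{13} \\neq P_{31}$ in this example."
   ]
  },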
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from graspologic.models import SBMEstimator\n",
    "# fit an SBM, passing the known node labels as block assignments\n",
    "sbme = SBMEstimator(directed=True, loops=False)\n",
    "sbme.fit(adj, y=labels)\n",
    "print(\"SBM \\\"B\\\" matrix:\")\n",
    "print(sbme.block_p_)\n",
    "heatmap(sbme.p_mat_,\n",
    "        inner_hier_labels=labels,\n",
    "        vmin=0,\n",
    "        vmax=1,\n",
    "        font_scale=1.5,\n",
    "        title=\"SBM probability matrix\", \n",
    "        sort_nodes=True)\n",
    "_ = heatmap(sbme.sample()[0],\n",
    "        inner_hier_labels=labels,\n",
    "        font_scale=1.5,\n",
    "        title=\"SBM sample\", \n",
    "        sort_nodes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Degree-corrected stochastic block model (DCSBM)\n",
    "Just as we could add a degree-correction term to the ER model, so too can we modify the stochastic block model to allow for heterogeneous expected degrees. Again, we let $\\theta$ be a length $n$ vector of degree correction parameters, and all other parameters remain as they were defined above for the SBM: \n",
    "\n",
    "$$P_{ij} = \\theta_i \\theta_j B_{\\tau_i \\tau_j}$$\n",
    "\n",
    "Note that the matrix $B$ may no longer represent true probabilities, because the addition of the $\\theta$ vectors introduces a multiplicative constant that can be absorbed into the elements of $\\theta$."
   ]
  },
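  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the absorbed constant concrete, here is a sketch of the standard rescaling argument. For any positive constants $c_1, \\dots, c_K$, define $\\theta'_i = c_{\\tau_i} \\theta_i$ and $B'_{kl} = B_{kl} / (c_k c_l)$. Then\n",
    "\n",
    "$$\\theta'_i \\theta'_j B'_{\\tau_i \\tau_j} = c_{\\tau_i} \\theta_i \\, c_{\\tau_j} \\theta_j \\, \\frac{B_{\\tau_i \\tau_j}}{c_{\\tau_i} c_{\\tau_j}} = \\theta_i \\theta_j B_{\\tau_i \\tau_j} = P_{ij}$$\n",
    "\n",
    "so $P$ is unchanged. Fixing a normalization of $\\theta$ (for example, requiring it to sum to one within each block) is one common way to pin down the scale, in which case the entries of $B$ are generally not probabilities."
   ]
  },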
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from graspologic.models import DCSBMEstimator\n",
    "# fit a DCSBM, passing the known node labels as block assignments\n",
    "dcsbme = DCSBMEstimator(directed=True, loops=False)\n",
    "dcsbme.fit(adj, y=labels)\n",
    "print(\"DCSBM \\\"B\\\" matrix:\")\n",
    "print(dcsbme.block_p_)\n",
    "heatmap(dcsbme.p_mat_,\n",
    "        inner_hier_labels=labels,\n",
    "        font_scale=1.5, \n",
    "        title=\"DCSBM probability matrix\",\n",
    "        vmin=0,\n",
    "        vmax=1,\n",
    "        sort_nodes=True)\n",
    "_ = heatmap(dcsbme.sample()[0],\n",
    "        inner_hier_labels=labels,\n",
    "        title=\"DCSBM sample\",\n",
    "        font_scale=1.5,\n",
    "        sort_nodes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random dot product graph (RDPG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the random dot product graph model, each node is assumed to have a \"latent position\" in some $d$-dimensional Euclidean space. This vector dictates that node's probability of connection to other nodes. For a given pair of nodes $i$ and $j$, the probability of connection is the dot product between their latent positions. If $x_i$ and $x_j$ are the latent positions of nodes $i$ and $j$, respectively:\n",
    "\n",
    "$$P_{ij} = \\langle x_i, x_j \\rangle$$\n",
    "\n"
   ]
  },
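  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In matrix form, as a sketch: stacking the latent positions as the rows of an $n \\times d$ matrix $X$ gives\n",
    "\n",
    "$$P = X X^T$$\n",
    "\n",
    "off the diagonal (loops are excluded here). This $P$ is symmetric, so it describes an undirected graph. For a directed graph such as the one used in this tutorial, a common generalization (an assumption on our part about what is being fit, not something stated above) gives each node an \"out\" position $x_i$ and an \"in\" position $y_i$, stacked into matrices $X$ and $Y$:\n",
    "\n",
    "$$P_{ij} = \\langle x_i, y_j \\rangle, \\qquad P = X Y^T$$"
   ]
  },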
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.models import RDPGEstimator\n",
    "rdpge = RDPGEstimator(loops=False)\n",
    "rdpge.fit(adj, y=labels)\n",
    "heatmap(rdpge.p_mat_,\n",
    "        inner_hier_labels=labels,\n",
    "        vmin=0,\n",
    "        vmax=1,\n",
    "        font_scale=1.5,\n",
    "        title=\"RDPG probability matrix\",\n",
    "        sort_nodes=True\n",
    "       )\n",
    "_ = heatmap(rdpge.sample()[0],\n",
    "        inner_hier_labels=labels,\n",
    "        font_scale=1.5,\n",
    "        title=\"RDPG sample\",\n",
    "        sort_nodes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a figure combining all of the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, axs = plt.subplots(2, 3, figsize=(20, 16))\n",
    "\n",
    "# colormapping: use only the upper half of RdBu_r, since all values lie in [0, 1]\n",
    "cmap = plt.get_cmap(\"RdBu_r\")\n",
    "norm = mpl.colors.Normalize(vmin=0, vmax=1)\n",
    "cc = np.linspace(0.5, 1, 256)\n",
    "cmap = mpl.colors.ListedColormap(cmap(cc))\n",
    "\n",
    "# heatmapping\n",
    "heatmap_kws = dict(\n",
    "    inner_hier_labels=labels,\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    "    cbar=False,\n",
    "    cmap=cmap,\n",
    "    center=None,\n",
    "    hier_label_fontsize=20,\n",
    "    title_pad=45,\n",
    "    font_scale=1.6,\n",
    "    sort_nodes=True\n",
    ")\n",
    "\n",
    "models = [rdpge, dcsbme, dcer, sbme, er]\n",
    "model_names = [\"RDPG\", \"DCSBM\",\"DCER\", \"SBM\", \"ER\"]\n",
    "\n",
    "heatmap(adj, ax=axs[0][0], title=\"IER\", **heatmap_kws)\n",
    "heatmap(models[0].p_mat_, ax=axs[0][1], title=model_names[0], **heatmap_kws)\n",
    "heatmap(models[1].p_mat_, ax=axs[0][2], title=model_names[1], **heatmap_kws)\n",
    "heatmap(models[2].p_mat_, ax=axs[1][0], title=model_names[2], **heatmap_kws)\n",
    "heatmap(models[3].p_mat_, ax=axs[1][1], title=model_names[3], **heatmap_kws)\n",
    "heatmap(models[4].p_mat_, ax=axs[1][2], title=model_names[4], **heatmap_kws)\n",
    "\n",
    "\n",
    "# add colorbar\n",
    "sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)\n",
    "sm.set_array(dcsbme.p_mat_)\n",
    "_ = fig.colorbar(sm,\n",
    "             ax=axs,\n",
    "             orientation=\"horizontal\",\n",
    "             pad=0.04,\n",
    "             shrink=0.8,\n",
    "             fraction=0.08,\n",
    "             drawedges=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}