Source code for graspologic.layouts.render

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any

import matplotlib.pyplot as plt
import networkx as nx

from graspologic.layouts.classes import NodePosition
from graspologic.types import Dict, List, Tuple


def _calculate_x_y_domain(
    positions: List[NodePosition],
) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """calculate the overall x/y domain, converting to a square
    so we can have a consistent scale
    """
    min_x = min_y = float("inf")
    max_x = max_y = float("-inf")
    for node_position in positions:
        min_x = min(min_x, node_position.x - node_position.size)
        max_x = max(max_x, node_position.x + node_position.size)
        min_y = min(min_y, node_position.y - node_position.size)
        max_y = max(max_y, node_position.y + node_position.size)

    x_delta = max_x - min_x
    y_delta = max_y - min_y
    max_delta = max(x_delta, y_delta)

    if max_delta == x_delta:
        difference = (max_delta - y_delta) / 2
        min_y = min_y - difference
        max_y = max_y + difference
    elif max_delta == y_delta:
        difference = (max_delta - x_delta) / 2
        min_x = min_x - difference
        max_x = max_x + difference

    return (min_x, max_x), (min_y, max_y)


def _scale_value(
    domain: Tuple[float, float], data_range: Tuple[float, float], value: float
) -> float:
    return data_range[0] + (data_range[1] - data_range[0]) * (
        (value - domain[0]) / (domain[1] - domain[0])
    )


def _scale_node_sizes_for_rendering(
    sizes: List[float],
    spatial_domain: Tuple[float, float],
    spatial_range: Tuple[float, float],
    dpi: float,
) -> List[float]:
    """
    Scale the size again to match the rendered pixel range
    we would expect this to be handled by the underlying viz framework, but it isn't, size is specified
    as the bounding box in points of the rendered output, so we need to transform our size to match.

    There are 72 points per inch. Multiplying by 72 / dpi converts from pixels to points.
    """
    spatial_domain = (0, spatial_domain[1] - spatial_domain[0])
    return [
        _scale_value(spatial_domain, spatial_range, s * 2 * 72.0 / dpi) ** 2
        for s in sizes
    ]


def _draw_graph(
    graph: nx.Graph,
    positions: List[NodePosition],
    node_colors: Dict[Any, str],
    vertex_alpha: float,
    edge_line_width: float,
    edge_alpha: float,
    figure_width: float,
    figure_height: float,
    vertex_line_width: float = 0.01,
    vertex_shape: str = "o",
    arrows: bool = False,
    dpi: int = 100,
) -> None:
    if len(positions) != len(graph.nodes()):
        raise ValueError(
            f"The number of positions provided {len(positions)} is not the same as the "
            f"number of nodes in the graph {len(graph.nodes())}"
        )
    for position in positions:
        if position.node_id not in graph:
            raise ValueError(
                f"The node position provided for {position.node_id} references a node "
                f"not found in our graph"
            )

    plt.rcParams["figure.dpi"] = dpi  # TODO, test at different dpi

    plt.clf()
    figure = plt.gcf()
    ax = plt.gca()
    ax.set_axis_off()
    figure.set_size_inches(figure_width, figure_height)
    window_extent_width = ax.get_window_extent().width

    x_domain, y_domain = _calculate_x_y_domain(positions)

    position_map = {position.node_id: position for position in positions}
    node_positions = {
        position.node_id: (position.x, position.y) for position in positions
    }

    vertices = []
    vertex_sizes = []
    node_color_list = []
    edge_color_list = []

    for node in graph.nodes():
        vertices.append(node)
        vertex_sizes.append(position_map[node].size)
        node_color_list.append(node_colors[node])

    vertex_sizes = _scale_node_sizes_for_rendering(
        vertex_sizes, x_domain, (0, window_extent_width), dpi
    )

    for source, target in graph.edges():
        edge_color_list.append(node_colors[source])

    ax.set_xbound(*x_domain)
    ax.set_xlim(x_domain)
    ax.set_ybound(*y_domain)
    ax.set_ylim(y_domain)

    nx.draw_networkx_edges(
        graph,
        pos=node_positions,
        alpha=edge_alpha,
        width=edge_line_width,
        edge_color=edge_color_list,
        arrows=arrows,
        ax=ax,
    )

    nx.draw_networkx_nodes(
        graph,
        pos=node_positions,
        nodelist=vertices,
        node_color=node_color_list,
        alpha=vertex_alpha,
        linewidths=vertex_line_width,
        node_size=vertex_sizes,
        node_shape=vertex_shape,
        ax=ax,
    )


[docs] def show_graph( graph: nx.Graph, positions: List[NodePosition], node_colors: Dict[Any, str], vertex_line_width: float = 0.01, vertex_alpha: float = 0.55, edge_line_width: float = 0.5, edge_alpha: float = 0.02, figure_width: float = 15.0, figure_height: float = 15.0, light_background: bool = True, vertex_shape: str = "o", arrows: bool = False, dpi: int = 500, ) -> None: """ Renders and displays a graph. Attempts to display it via the platform-specific display library such as TkInter Edges will be displayed with the same color as the source node. Parameters ---------- graph : nx.Graph The graph to be displayed. If the networkx Graph contains only nodes, no edges will be displayed. positions : List[:class:`graspologic.layouts.NodePosition`] The positionsfor every node in the graph. node_colors : Dict[Any, str] A mapping of node id to colors. Must contain an entry for every node in the graph. vertex_line_width : float Line width of vertex outline. Default is``0.01``. vertex_alpha : float Alpha (transparency) of vertices in visualization. Default is``0.55``. edge_line_width : float Line width of edge. Default is``0.5``. edge_alpha : float Alpha (transparency) of edges in visualization. Default is``0.02``. figure_width : float Width of figure. Default is ``15.0``. figure_height : float eight of figure. Default is``15.0``. light_background : bool Light background or dark background. Default is``True``. vertex_shape : str Matplotlib Marker for the vertex shape. See `https://matplotlib.org/api/markers_api.html <https://matplotlib.org/api/markers_api.html>`_ for a list of allowed values . Default is ``o`` (i.e: a circle) arrows : bool For directed graphs, if ``True``, draw arrow heads. Default is ``False`` dpi : float Dots per inch of the figure. Default is ``500``. """ ax = plt.gca() if light_background: facecolor = ax.get_facecolor() else: facecolor = "#030303" _draw_graph( graph=graph, positions=positions, node_colors=node_colors, vertex_line_width=vertex_line_width, vertex_alpha=vertex_alpha, edge_line_width=edge_line_width, edge_alpha=edge_alpha, figure_width=figure_width, figure_height=figure_height, vertex_shape=vertex_shape, arrows=arrows, dpi=dpi, ) plt.gcf().set_facecolor(facecolor) plt.show() plt.close("all")
[docs] def save_graph( output_path: str, graph: nx.Graph, positions: List[NodePosition], node_colors: Dict[Any, str], vertex_line_width: float = 0.01, vertex_alpha: float = 0.55, edge_line_width: float = 0.5, edge_alpha: float = 0.02, figure_width: float = 15.0, figure_height: float = 15.0, light_background: bool = True, vertex_shape: str = "o", arrows: bool = False, dpi: int = 100, ) -> None: """ Renders a graph to file. Edges will be displayed with the same color as the source node. Parameters ---------- output_path : str The output path to write the rendered graph to. Suggested file extension is ``.png``. graph : nx.Graph The graph to be displayed. If the networkx Graph contains only nodes, no edges will be displayed. positions : List[:class:`graspologic.layouts.NodePosition`] The positionsfor every node in the graph. node_colors : Dict[Any, str] A mapping of node id to colors. Must contain an entry for every node in the graph. vertex_line_width : float Line width of vertex outline. Default is``0.01``. vertex_alpha : float Alpha (transparency) of vertices in visualization. Default is``0.55``. edge_line_width : float Line width of edge. Default is``0.5``. edge_alpha : float Alpha (transparency) of edges in visualization. Default is``0.02``. figure_width : float Width of figure. Default is ``15.0``. figure_height : float eight of figure. Default is``15.0``. light_background : bool Light background or dark background. Default is``True``. vertex_shape : str Matplotlib Marker for the vertex shape. See `https://matplotlib.org/api/markers_api.html <https://matplotlib.org/api/markers_api.html>`_ for a list of allowed values . Default is ``o`` (i.e: a circle) arrows : bool For directed graphs, if ``True``, draw arrow heads. Default is ``False`` dpi : float Dots per inch of the figure. Default is ``100``. Returns ------- """ _draw_graph( graph=graph, positions=positions, node_colors=node_colors, vertex_line_width=vertex_line_width, vertex_alpha=vertex_alpha, edge_line_width=edge_line_width, edge_alpha=edge_alpha, figure_width=figure_width, figure_height=figure_height, vertex_shape=vertex_shape, arrows=arrows, dpi=dpi, ) ax = plt.gca() if light_background: facecolor = ax.get_facecolor() else: facecolor = "#030303" plt.savefig(output_path, facecolor=facecolor) plt.close("all")