Source code for chaski.utils.viz

"""
==============
Visualizations
==============

This module provides functions for visualizing the network graph and latency heatmap of ChaskiNode objects.
It uses `networkx` and `matplotlib` to render the network graph, which displays nodes, connections, and
latency values. A heatmap of latencies between nodes is created using `seaborn` to help in identifying
patterns and high-latency connections.
"""

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from typing import List
from chaski.node import ChaskiNode
import seaborn as sns


[docs]def display_graph( nodes: List[ChaskiNode], layout=nx.circular_layout, show_latencies=False ) -> None: """ Display the network graph and latency statistics. Parameters ---------- nodes : List[Node] A list of nodes, each containing name and server pairs. """ G = nx.Graph() # Prepare nodes for graph representation nodes_ = [ { "name": node.name, "paired": all( [node.paired_event[sub].is_set() for sub in node.subscriptions] ), "subscriptions": f"{{{''.join(node.subscriptions)}}}", "server_pairs": {v.name: v.latency for v in node.edges}, } for node in nodes ] # Add edges to the graph for node in nodes_: for neighbor, latency in node["server_pairs"].items(): G.add_edge(node["name"], neighbor, weight=latency) pos = layout(G) # Graph display options options = { "with_labels": False, "node_color": "#6DA58A", "edgecolors": ["#008080" if node["paired"] else "#6DA58A" for node in nodes_], "linewidths": 5, "edge_color": "#B3B3BD", "width": 3, "node_size": 1500, "font_color": "#ffffff", "font_family": "Noto Sans", "font_size": 11, "pos": pos, } # Create the plot plt.figure(figsize=(16, 9), dpi=90) if show_latencies: ax1 = plt.subplot(111) else: ax1 = plt.subplot2grid((3, 5), (0, 0), colspan=4, rowspan=3) nx.draw(G, ax=ax1, **options) labels = {node["name"]: node["name"] for node in nodes_} nx.draw_networkx_labels( G, {k: pos[k] + np.array([0.0, 0.03]) for k in pos}, labels, ax=ax1, font_color="#ffffff", font_size=11, ) labels = {node["name"]: node["subscriptions"] for node in nodes_} nx.draw_networkx_labels( G, {k: pos[k] + np.array([0.0, -0.02]) for k in pos}, labels, ax=ax1, font_color="#ffffff", font_size=8, ) # Collect latencies for statistics # latencies = [peer.latency for node in nodes for peer in node['server_pairs']] latencies = [] for node in nodes: for edge in node.edges: latencies.append(edge.latency) # Log statistics log = f""" nodes: {len(nodes)} connections: {0.5 * sum(len(node.edges) for node in nodes): .0f} max(latency): {np.max(latencies): .3f} ms min(latency): {np.min(latencies): .3f} ms mean(latency): {np.mean(latencies): .3f} ms std(latency): {np.std(latencies): .3f} ms """ # Display log statistics font_options = { "fontsize": 12, "fontfamily": "Noto Sans Mono", "fontweight": "normal", "ha": "left", "color": "#0B5D37", } if show_latencies: ax2 = plt.subplot2grid((3, 5), (2, 4), colspan=1) ax2.text(0, 1, log, **font_options) ax2.axis("off") plt.show()
[docs]def display_heatmap(nodes: List[ChaskiNode], show=True) -> None: """ Display a heatmap of latencies between nodes. Parameters ---------- nodes : List[Node] A list of Node objects, each containing server pairs with latency information. Returns ------- None """ # Create a graph dictionary from nodes graph = {} for node in nodes: row = {} for peer in node.edges: row[peer.name] = peer.latency graph[node.name] = row # Extract node names and initialize a latency matrix nodes_ = list(graph.keys()) n = len(nodes_) latency_matrix = np.zeros((n, n)) # Map node names to matrix indices node_index = {node: idx for idx, node in enumerate(nodes_)} # Fill the latency matrix with values from the graph dictionary for node, neighbors in graph.items(): for neighbor, latency in neighbors.items(): i, j = node_index[node], node_index[neighbor] latency_matrix[i, j] = latency if not show: return latency_matrix, nodes_ # Plot the heatmap plt.figure(figsize=(8, 6), dpi=90) sns.heatmap( latency_matrix, xticklabels=nodes_, yticklabels=nodes_, annot=True, fmt=".2f", cmap="GnBu", cbar=True, linewidths=0.5, ) plt.title("Latency Heatmap") # plt.xlabel('Nodes') # plt.ylabel('Nodes') plt.show()
[docs]def display_subscriptions_graph( nodes: List[ChaskiNode], layout=nx.spring_layout, show_latencies=False, chain="_CHAIN_", ) -> None: """ Display the network graph and latency statistics. Parameters ---------- nodes : List[Node] A list of nodes, each containing name and server pairs. """ G = nx.Graph() # Prepare nodes for graph representation nodes_ = [ { "name": node.name, "paired": all( [node.paired_event[sub].is_set() for sub in node.subscriptions] ), "group": sorted(list(node.subscriptions))[0], "subscriptions": f"{{{','.join(node.subscriptions.difference({chain}))}}}", "server_pairs": {v.name: v.latency for v in node.edges}, } for node in nodes ] # Add edges to the graph for node in nodes_: for neighbor, latency in node["server_pairs"].items(): G.add_edge(node["name"], neighbor, weight=latency) pos = layout(G) def generate_subscription_colors(): """ Generates a dictionary mapping uppercase letters (A-Z) to unique colors. """ letters = [chr(i) for i in range(65, 91)] # Generate A-Z colors = ( plt.cm.tab20.colors ) # Use a colormap from Matplotlib (tab20 has 20 distinct colors) colors = list(colors) * 2 # Repeat the colors to cover 26 letters return { letter: f"#{int(r*255):02X}{int(g*255):02X}{int(b*255):02X}" for letter, (r, g, b) in zip(letters, colors) } # Generate the dictionary SUBSCRIPTION_COLORS = generate_subscription_colors() def get_subscription_color(sub): return SUBSCRIPTION_COLORS.get(sub, "#D3D3D3") # Se asignan colores a cada nodo segĂșn su grupo node_colors_dict = { nd["name"]: get_subscription_color(nd["group"]) for nd in nodes_ } node_colors = [node_colors_dict[name] for name in G.nodes] # Graph display options options = { # "with_labels": True, "with_labels": False, "node_color": node_colors, # 'edgecolors': "#ffffff", # 'linewidths': 2, "edge_color": "#B3B3BD", "width": 3, # "node_size": 1100, "node_size": 1800, "font_color": "#ffffff", "font_family": "Noto Sans", "font_size": 11, "pos": pos, } # Create the plot plt.figure(figsize=(16, 9), dpi=90) if show_latencies: ax1 = plt.subplot(111) else: ax1 = plt.subplot2grid((3, 5), (0, 0), colspan=4, rowspan=3) nx.draw(G, ax=ax1, **options) labels = {node["name"]: node["name"] for node in nodes_} nx.draw_networkx_labels( G, {k: pos[k] + np.array([0.0, 0.03]) for k in pos}, labels, ax=ax1, font_color="#ffffff", font_size=11, ) labels = {node["name"]: node["subscriptions"] for node in nodes_} nx.draw_networkx_labels( G, {k: pos[k] + np.array([0.0, -0.02]) for k in pos}, labels, ax=ax1, font_color="#ffffff", font_size=8, ) # Collect latencies for statistics # latencies = [peer.latency for node in nodes for peer in node['server_pairs']] latencies = [] for node in nodes: for edge in node.edges: latencies.append(edge.latency) # Log statistics log = f""" nodes: {len(nodes)} connections: {0.5 * sum(len(node.edges) for node in nodes): .0f} max(latency): {np.max(latencies): .3f} ms min(latency): {np.min(latencies): .3f} ms mean(latency): {np.mean(latencies): .3f} ms std(latency): {np.std(latencies): .3f} ms """ # Display log statistics font_options = { "fontsize": 12, "fontfamily": "Noto Sans Mono", "fontweight": "normal", "ha": "left", "color": "#0B5D37", } if show_latencies: ax2 = plt.subplot2grid((3, 5), (2, 4), colspan=1) ax2.text(0, 1, log, **font_options) ax2.axis("off") plt.show()