diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 6afdac759c7..52182bc9caa 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -10,6 +10,7 @@ import itertools import warnings from collections.abc import Callable +from itertools import pairwise from typing import Any import networkx as nx @@ -346,7 +347,7 @@ def draw_hex_grid( def setup_hexmesh(width, height): """Helper function for creating the hexmesh with unique edges.""" - edges = set() + edges = [] size = 1.0 x_spacing = np.sqrt(3) * size y_spacing = 1.5 * size @@ -374,18 +375,15 @@ def get_hex_vertices( vertices = get_hex_vertices(x, y) # Edge logic, connecting each vertex to the next - for i in range(len(vertices)): - v1 = vertices[i] - v2 = vertices[(i + 1) % len(vertices)] - - # Sort vertices to ensure consistent edge representation + for v1, v2 in pairwise(vertices + [vertices[0]]): + # Sort vertices to ensure consistent edge representation and avoid duplicates. edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))])) - edges.add(edge) + if edge not in edges: + edges.append(edge) - # Convert to LineCollection format - edges_list = [np.array(edge) for edge in edges] + # Return LineCollection for hexmesh return LineCollection( - edges_list, linestyle=":", color="black", linewidth=1, alpha=1 + edges, linestyle=":", color="black", linewidth=1, alpha=1 ) if draw_grid: