diff --git a/src/maxplotlib/canvas/canvas.py b/src/maxplotlib/canvas/canvas.py index a64520e..962ab21 100644 --- a/src/maxplotlib/canvas/canvas.py +++ b/src/maxplotlib/canvas/canvas.py @@ -18,6 +18,141 @@ from maxplotlib.utils.options import Backends +def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None): + """ + Plot all nodes and paths on the provided axis using Matplotlib. + + Parameters: + - ax (matplotlib.axes.Axes): Axis on which to plot the figure. + """ + + # TODO: Specify which layers to retreive nodes from with layers=layers + nodes = tikzfigure.layers.get_nodes() + paths = tikzfigure.layers.get_paths() + + for path in paths: + x_coords = [node.x for node in path.nodes] + y_coords = [node.y for node in path.nodes] + + # Parse path color + path_color_spec = path.kwargs.get("color", "black") + try: + color = Color(path_color_spec).to_rgb() + except ValueError as e: + print(e) + color = "black" + + # Parse line width + line_width_spec = path.kwargs.get("line_width", 1) + if isinstance(line_width_spec, str): + match = re.match(r"([\d.]+)(pt)?", line_width_spec) + if match: + line_width = float(match.group(1)) + else: + print( + f"Invalid line width specification: '{line_width_spec}', defaulting to 1", + ) + line_width = 1 + else: + line_width = float(line_width_spec) + + # Parse line style using Linestyle class + style_spec = path.kwargs.get("style", "solid") + linestyle = Linestyle(style_spec).to_matplotlib() + + ax.plot( + x_coords, + y_coords, + color=color, + linewidth=line_width, + linestyle=linestyle, + zorder=1, # Lower z-order to place behind nodes + ) + + # Plot nodes after paths so they appear on top + for node in nodes: + # Determine shape and size + shape = node.kwargs.get("shape", "circle") + fill_color_spec = node.kwargs.get("fill", "white") + edge_color_spec = node.kwargs.get("draw", "black") + linewidth = float(node.kwargs.get("line_width", 1)) + size = float(node.kwargs.get("size", 1)) + + # Parse colors using the Color class + try: + facecolor = Color(fill_color_spec).to_rgb() + except ValueError as e: + print(e) + facecolor = "white" + + try: + edgecolor = Color(edge_color_spec).to_rgb() + except ValueError as e: + print(e) + edgecolor = "black" + + # Plot shapes + if shape == "circle": + radius = size / 2 + circle = patches.Circle( + (node.x, node.y), + radius, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + zorder=2, # Higher z-order to place on top of paths + ) + ax.add_patch(circle) + elif shape == "rectangle": + width = height = size + rect = patches.Rectangle( + (node.x - width / 2, node.y - height / 2), + width, + height, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + zorder=2, # Higher z-order + ) + ax.add_patch(rect) + else: + # Default to circle if shape is unknown + radius = size / 2 + circle = patches.Circle( + (node.x, node.y), + radius, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + zorder=2, + ) + ax.add_patch(circle) + + # Add text inside the shape + if node.content: + ax.text( + node.x, + node.y, + node.content, + fontsize=10, + ha="center", + va="center", + wrap=True, + zorder=3, # Even higher z-order for text + ) + + # Remove axes, ticks, and legend + ax.axis("off") + + # Adjust plot limits + all_x = [node.x for node in nodes] + all_y = [node.y for node in nodes] + padding = 1 # Adjust padding as needed + ax.set_xlim(min(all_x) - padding, max(all_x) + padding) + ax.set_ylim(min(all_y) - padding, max(all_y) + padding) + ax.set_aspect("equal", adjustable="datalim") + + class Canvas: def __init__( self, @@ -29,7 +164,7 @@ def __init__( label: str | None = None, fontsize: int = 14, dpi: int = 300, - width: str = "17cm", + width: str = "5cm", ratio: str = "golden", # TODO Add literal gridspec_kw: Dict = {"wspace": 0.08, "hspace": 0.1}, ): @@ -62,6 +197,8 @@ def __init__( self._ratio = ratio self._gridspec_kw = gridspec_kw self._plotted = False + self._matplotlib_fig = None + self._matplotlib_axes = None # Dictionary to store lines for each subplot # Key: (row, col), Value: list of lines with their data and kwargs @@ -106,7 +243,6 @@ def add_line( subplot: LinePlot | None = None, row: int | None = None, col: int | None = None, - plot_type="plot", **kwargs, ): if row is not None and col is not None: @@ -126,7 +262,6 @@ def add_line( x_data=x_data, y_data=y_data, layer=layer, - plot_type=plot_type, **kwargs, ) @@ -304,7 +439,7 @@ def show( elif backend == "plotly": self.plot_plotly(savefig=False) elif backend == "tikzpics": - fig = self.plot_tikzpics(savefig=False) + fig = self.plot_tikzpics(savefig=False, verbose=verbose) fig.show() else: raise ValueError("Invalid backend") @@ -374,8 +509,8 @@ def plot_matplotlib( def plot_tikzpics( self, - savefig=None, - verbose=False, + savefig: str | None = None, + verbose: bool = False, ) -> TikzFigure: if len(self.subplots) > 1: raise NotImplementedError( @@ -507,13 +642,6 @@ def label(self, value): def figsize(self, value): self._figsize = value - # Magic methods - def __str__(self): - return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, figsize={self.figsize})" - - def __repr__(self): - return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, caption={self.caption}, label={self.label})" - def __getitem__(self, key): """Allows accessing subplots by tuple index.""" row, col = key @@ -528,140 +656,12 @@ def __setitem__(self, key, value): raise IndexError("Subplot index out of range") self._subplot_matrix[row][col] = value + def __repr__(self): + return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, caption={self.caption}, label={self.label})" -def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None): - """ - Plot all nodes and paths on the provided axis using Matplotlib. - - Parameters: - - ax (matplotlib.axes.Axes): Axis on which to plot the figure. - """ - - # TODO: Specify which layers to retreive nodes from with layers=layers - nodes = tikzfigure.layers.get_nodes() - paths = tikzfigure.layers.get_paths() - - for path in paths: - x_coords = [node.x for node in path.nodes] - y_coords = [node.y for node in path.nodes] - - # Parse path color - path_color_spec = path.kwargs.get("color", "black") - try: - color = Color(path_color_spec).to_rgb() - except ValueError as e: - print(e) - color = "black" - - # Parse line width - line_width_spec = path.kwargs.get("line_width", 1) - if isinstance(line_width_spec, str): - match = re.match(r"([\d.]+)(pt)?", line_width_spec) - if match: - line_width = float(match.group(1)) - else: - print( - f"Invalid line width specification: '{line_width_spec}', defaulting to 1", - ) - line_width = 1 - else: - line_width = float(line_width_spec) - - # Parse line style using Linestyle class - style_spec = path.kwargs.get("style", "solid") - linestyle = Linestyle(style_spec).to_matplotlib() - - ax.plot( - x_coords, - y_coords, - color=color, - linewidth=line_width, - linestyle=linestyle, - zorder=1, # Lower z-order to place behind nodes - ) - - # Plot nodes after paths so they appear on top - for node in nodes: - # Determine shape and size - shape = node.kwargs.get("shape", "circle") - fill_color_spec = node.kwargs.get("fill", "white") - edge_color_spec = node.kwargs.get("draw", "black") - linewidth = float(node.kwargs.get("line_width", 1)) - size = float(node.kwargs.get("size", 1)) - - # Parse colors using the Color class - try: - facecolor = Color(fill_color_spec).to_rgb() - except ValueError as e: - print(e) - facecolor = "white" - - try: - edgecolor = Color(edge_color_spec).to_rgb() - except ValueError as e: - print(e) - edgecolor = "black" - - # Plot shapes - if shape == "circle": - radius = size / 2 - circle = patches.Circle( - (node.x, node.y), - radius, - facecolor=facecolor, - edgecolor=edgecolor, - linewidth=linewidth, - zorder=2, # Higher z-order to place on top of paths - ) - ax.add_patch(circle) - elif shape == "rectangle": - width = height = size - rect = patches.Rectangle( - (node.x - width / 2, node.y - height / 2), - width, - height, - facecolor=facecolor, - edgecolor=edgecolor, - linewidth=linewidth, - zorder=2, # Higher z-order - ) - ax.add_patch(rect) - else: - # Default to circle if shape is unknown - radius = size / 2 - circle = patches.Circle( - (node.x, node.y), - radius, - facecolor=facecolor, - edgecolor=edgecolor, - linewidth=linewidth, - zorder=2, - ) - ax.add_patch(circle) - - # Add text inside the shape - if node.content: - ax.text( - node.x, - node.y, - node.content, - fontsize=10, - ha="center", - va="center", - wrap=True, - zorder=3, # Even higher z-order for text - ) - - # Remove axes, ticks, and legend - ax.axis("off") - - # Adjust plot limits - all_x = [node.x for node in nodes] - all_y = [node.y for node in nodes] - padding = 1 # Adjust padding as needed - ax.set_xlim(min(all_x) - padding, max(all_x) + padding) - ax.set_ylim(min(all_y) - padding, max(all_y) + padding) - ax.set_aspect("equal", adjustable="datalim") + # Magic methods + def __str__(self): + return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, figsize={self.figsize})" if __name__ == "__main__": diff --git a/src/maxplotlib/colors/colors.py b/src/maxplotlib/colors/colors.py index fdb117e..4d04287 100644 --- a/src/maxplotlib/colors/colors.py +++ b/src/maxplotlib/colors/colors.py @@ -5,16 +5,6 @@ class Color: - def __init__(self, color_spec): - """ - Initialize the Color object by parsing the color specification. - - Parameters: - - color_spec: Can be a TikZ color string (e.g., 'blue!20'), a standard color name, - an RGB tuple, a hex code, etc. - """ - self.color_spec = color_spec - self.rgb = self._parse_color(color_spec) def _parse_color(self, color_spec): """ @@ -53,6 +43,17 @@ def _parse_color(self, color_spec): except ValueError: raise ValueError(f"Invalid color specification: '{color_spec}'") + def __init__(self, color_spec): + """ + Initialize the Color object by parsing the color specification. + + Parameters: + - color_spec: Can be a TikZ color string (e.g., 'blue!20'), a standard color name, + an RGB tuple, a hex code, etc. + """ + self.color_spec = color_spec + self.rgb = self._parse_color(color_spec) + def to_rgb(self): """ Return the color as an RGB tuple. diff --git a/src/maxplotlib/linestyle/linestyle.py b/src/maxplotlib/linestyle/linestyle.py index 0ba1f04..27e0758 100644 --- a/src/maxplotlib/linestyle/linestyle.py +++ b/src/maxplotlib/linestyle/linestyle.py @@ -2,16 +2,6 @@ class Linestyle: - def __init__(self, style_spec): - """ - Initialize the Linestyle object by parsing the style specification. - - Parameters: - - style_spec: Can be a TikZ-style line style string (e.g., 'dashed', 'dotted', 'solid', 'dashdot'), - or a custom dash pattern. - """ - self.style_spec = style_spec - self.matplotlib_style = self._parse_style(style_spec) def _parse_style(self, style_spec): """ @@ -48,6 +38,17 @@ def _parse_style(self, style_spec): print(f"Unknown line style: '{style_spec}', defaulting to 'solid'") return "solid" + def __init__(self, style_spec): + """ + Initialize the Linestyle object by parsing the style specification. + + Parameters: + - style_spec: Can be a TikZ-style line style string (e.g., 'dashed', 'dotted', 'solid', 'dashdot'), + or a custom dash pattern. + """ + self.style_spec = style_spec + self.matplotlib_style = self._parse_style(style_spec) + def to_matplotlib(self): """ Return the line style in Matplotlib format. diff --git a/src/maxplotlib/subfigure/line_plot.py b/src/maxplotlib/subfigure/line_plot.py index 8a50cfc..f314c3d 100644 --- a/src/maxplotlib/subfigure/line_plot.py +++ b/src/maxplotlib/subfigure/line_plot.py @@ -106,7 +106,6 @@ def add_line( x_data, y_data, layer=0, - plot_type="plot", **kwargs, ): """ @@ -122,34 +121,34 @@ def add_line( "x": np.array(x_data), "y": np.array(y_data), "layer": layer, - "plot_type": plot_type, + "plot_type": "plot", "kwargs": kwargs, } self._add(ld, layer) - def add_imshow(self, data, layer=0, plot_type="imshow", **kwargs): + def add_imshow(self, data, layer=0, **kwargs): ld = { "data": np.array(data), "layer": layer, - "plot_type": plot_type, + "plot_type": "imshow", "kwargs": kwargs, } self._add(ld, layer) - def add_patch(self, patch, layer=0, plot_type="patch", **kwargs): + def add_patch(self, patch, layer=0, **kwargs): ld = { "patch": patch, "layer": layer, - "plot_type": plot_type, + "plot_type": "patch", "kwargs": kwargs, } self._add(ld, layer) - def add_colorbar(self, label="", layer=0, plot_type="colorbar", **kwargs): + def add_colorbar(self, label="", layer=0, **kwargs): cb = { "label": label, "layer": layer, - "plot_type": plot_type, + "plot_type": "colorbar", "kwargs": kwargs, } self._add(cb, layer) @@ -235,6 +234,9 @@ def plot_tikzpics(self, layers=None, verbose: bool = False) -> TikzFigure: nodes = [[xi, yi] for xi, yi in zip(x, y)] tikz_figure.draw(nodes=nodes, **line["kwargs"]) + if verbose: + print("Generated TikZ figure:") + print(tikz_figure.generate_tikz()) return tikz_figure def plot_plotly(self): diff --git a/tutorials/tutorial_01.ipynb b/tutorials/tutorial_01.ipynb index 2b110e4..5a054f8 100644 --- a/tutorials/tutorial_01.ipynb +++ b/tutorials/tutorial_01.ipynb @@ -29,7 +29,7 @@ "metadata": {}, "outputs": [], "source": [ - "c = Canvas(width=\"17mm\", ratio=0.5, fontsize=12)\n", + "c = Canvas(ratio=0.5, fontsize=12)\n", "c.add_line([0, 1, 2, 3], [0, 1, 4, 9], label=\"Line 1\")\n", "c.add_line([0, 1, 2, 3], [0, 2, 3, 4], linestyle=\"dashed\", color=\"red\", label=\"Line 2\")\n", "c.show()" @@ -44,7 +44,7 @@ "source": [ "# You can also explicitly create a subplot and add lines to it\n", "\n", - "c = Canvas(width=\"17cm\", ratio=0.5, fontsize=12)\n", + "c = Canvas(ratio=0.5, fontsize=12)\n", "sp = c.add_subplot(\n", " grid=True, xlabel=\"(x - 10) * 0.1\", ylabel=\"10y\", yscale=10, xshift=-10, xscale=0.1\n", ")\n", @@ -63,7 +63,7 @@ "source": [ "# Example with multiple subplots\n", "\n", - "c = Canvas(width=\"17cm\", ncols=2, nrows=2, ratio=0.5)\n", + "c = Canvas(width=\"10cm\", ncols=2, nrows=2, ratio=0.5)\n", "sp = c.add_subplot(grid=True)\n", "c.add_subplot(row=1)\n", "sp2 = c.add_subplot(row=1, legend=False)\n", @@ -82,7 +82,7 @@ "outputs": [], "source": [ "# Test with plotly backend\n", - "c = Canvas(width=\"17cm\", ratio=0.5)\n", + "c = Canvas(ratio=0.5)\n", "sp = c.add_subplot(\n", " grid=True, xlabel=\"x (mm)\", ylabel=\"10y\", yscale=10, xshift=-10, xscale=0.1\n", ")\n", diff --git a/tutorials/tutorial_02.ipynb b/tutorials/tutorial_02.ipynb index f2e3528..89ec85b 100644 --- a/tutorials/tutorial_02.ipynb +++ b/tutorials/tutorial_02.ipynb @@ -61,7 +61,7 @@ "metadata": {}, "outputs": [], "source": [ - "c = Canvas(ncols=2, width=\"20cm\", ratio=0.5)\n", + "c = Canvas(width=\"10cm\", ncols=2, ratio=0.5)\n", "tikz = c.add_tikzfigure(grid=False)\n", "\n", "# Add nodes\n",