diff --git a/news/plot-recipe.rst b/news/plot-recipe.rst new file mode 100644 index 00000000..347b958b --- /dev/null +++ b/news/plot-recipe.rst @@ -0,0 +1,23 @@ +**Added:** + +* Add ``plot_recipe`` method to ``FitRecipe``. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/requirements/conda.txt b/requirements/conda.txt index 5306b5d1..f728df41 100644 --- a/requirements/conda.txt +++ b/requirements/conda.txt @@ -1,3 +1,4 @@ matplotlib-base numpy scipy +bg-mpl-stylesheets diff --git a/requirements/pip.txt b/requirements/pip.txt index 74fa65e6..a9d7948e 100644 --- a/requirements/pip.txt +++ b/requirements/pip.txt @@ -1,3 +1,4 @@ matplotlib numpy scipy +bg-mpl-stylesheets diff --git a/src/diffpy/srfit/fitbase/fitrecipe.py b/src/diffpy/srfit/fitbase/fitrecipe.py index ffc1d595..4d7be2c9 100644 --- a/src/diffpy/srfit/fitbase/fitrecipe.py +++ b/src/diffpy/srfit/fitbase/fitrecipe.py @@ -36,7 +36,8 @@ from collections import OrderedDict -import six +import matplotlib.pyplot as plt +from bg_mpl_stylesheets.styles import all_styles from numpy import array, concatenate, dot, sqrt from diffpy.srfit.fitbase.fithook import PrintFitHook @@ -45,6 +46,8 @@ from diffpy.srfit.interface import _fitrecipe_interface from diffpy.srfit.util.tagmanager import TagManager +plt.style.use(all_styles["bg-style"]) + class FitRecipe(_fitrecipe_interface, RecipeOrganizer): """FitRecipe class. @@ -151,6 +154,34 @@ def __init__(self, name="fit"): self._contributions = OrderedDict() self._manage(self._contributions) + self.plot_options = { + "show_observed": True, + "show_fit": True, + "show_diff": True, + "offset_scale": 1.0, + "xmin": None, + "xmax": None, + "figsize": (8, 6), + "data_style": "o", + "fit_style": "-", + "diff_style": "-", + "data_color": None, + "fit_color": None, + "diff_color": None, + "data_label": "Observed", + "fit_label": "Calculated", + "diff_label": "Difference", + "xlabel": None, + "ylabel": None, + "title": None, + "legend": True, + "legend_loc": "best", + "grid": False, + "markersize": None, + "linewidth": None, + "alpha": 1.0, + "show": True, + } return def pushFitHook(self, fithook, index=None): @@ -641,7 +672,7 @@ def __get_var_and_check(self, var): Returns the variable or None if the variable cannot be found in the _parameters list. """ - if isinstance(var, six.string_types): + if isinstance(var, str): var = self._parameters.get(var) if var not in self._parameters.values(): @@ -661,9 +692,7 @@ def __get_vars_from_args(self, *args, **kw): or if a tag is passed in a keyword. """ # Process args. Each variable is tagged with its name, so this is easy. - strargs = set( - [arg for arg in args if isinstance(arg, six.string_types)] - ) + strargs = set([arg for arg in args if isinstance(arg, str)]) varargs = set(args) - strargs # Check that the tags are valid alltags = set(self._tagmanager.alltags()) @@ -764,7 +793,7 @@ def unconstrain(self, *pars): """ update = False for par in pars: - if isinstance(par, six.string_types): + if isinstance(par, str): name = par par = self.get(name) @@ -815,7 +844,7 @@ def constrain(self, par, con, ns={}): the FitRecipe and that is not defined in ns. Raises ValueError if par is marked as constant. """ - if isinstance(par, six.string_types): + if isinstance(par, str): name = par par = self.get(name) if par is None: @@ -871,6 +900,290 @@ def getBounds2(self): ub = array([b[1] for b in bounds]) return lb, ub + def set_plot_defaults(self, **kwargs): + """Set default plotting options for all future plots. 
+
+        Any keyword argument accepted by plot_recipe() can be set here.
+        The only exceptions are ``ax`` and ``return_fig``, which are
+        accepted by plot_recipe() directly but cannot be stored as
+        defaults.
+
+        Parameters
+        ----------
+        show_observed : bool, optional
+            The observed data is plotted if True. Default is True.
+        show_fit : bool, optional
+            The fit to the data is plotted if True. Default is True.
+        show_diff : bool, optional
+            The difference curve (observed - calculated) is plotted if True.
+            Default is True.
+        offset_scale : float, optional
+            The scaling factor for the difference curve offset. The
+            difference curve is offset below the data by
+            (min_y - 0.1*range) * offset_scale. Default is 1.0.
+        xmin : float or None, optional
+            The minimum x value to plot. If None, uses the minimum x value
+            of the data. Default is None.
+        xmax : float or None, optional
+            The maximum x value to plot. If None, uses the maximum x value
+            of the data. Default is None.
+        figsize : tuple, optional
+            The figure size as (width, height). Default is (8, 6).
+        data_style : str, optional
+            The matplotlib line/marker style for data points. Default is "o".
+        fit_style : str, optional
+            The matplotlib line/marker style for the calculated fit.
+            Default is "-".
+        diff_style : str, optional
+            The matplotlib line/marker style for the difference curve.
+            Default is "-".
+        data_color : str or None, optional
+            The color for the data plot. If None, uses default matplotlib
+            colors.
+        fit_color : str or None, optional
+            The color for the fit plot. If None, uses default matplotlib
+            colors.
+        diff_color : str or None, optional
+            The color for the difference plot. If None, uses default
+            matplotlib colors.
+        data_label : str, optional
+            The legend label for the observed data. Default is "Observed".
+        fit_label : str, optional
+            The legend label for the calculated fit. Default is "Calculated".
+        diff_label : str, optional
+            The legend label for the difference curve.
+            Default is "Difference".
+        xlabel : str or None, optional
+            The label for the x-axis. Default is None.
+        ylabel : str or None, optional
+            The label for the y-axis. Default is None.
+        title : str or None, optional
+            The plot title. Default is None (no title).
+        legend : bool, optional
+            The legend is shown if True. Default is True.
+        legend_loc : str, optional
+            The legend location. Default is "best".
+        grid : bool, optional
+            The grid is shown if True. Default is False.
+        markersize : float or None, optional
+            The size of the data point markers. Default is None.
+        linewidth : float or None, optional
+            The width of the fit and difference lines. Default is None.
+        alpha : float, optional
+            The transparency of all plot elements (0=transparent, 1=opaque).
+            Default is 1.0.
+        show : bool, optional
+            The plot is displayed using `plt.show()` if True. Default is True.
+        ax : matplotlib.axes.Axes or None, optional
+            The axes object to plot on. If None, creates a new figure.
+            Accepted by plot_recipe() only; not stored as a default.
+        return_fig : bool, optional
+            The figure and axes objects are returned if True. Default is
+            False. Accepted by plot_recipe() only; not stored as a default.
+
+        Examples
+        --------
+        >>> recipe.set_plot_defaults(
+        ...     xlabel='r (Å)',
+        ...     ylabel='G(r) (Å⁻²)',
+        ...     data_color='black',
+        ...     fit_color='red',
+        ... )
+        """
+
+        for key in kwargs:
+            if key not in self.plot_options:
+                print(
+                    f"Warning: '{key}' is not a valid "
+                    "plot_recipe option and will be ignored."
+                )
+        # Store only the recognized options so that unknown keys really are
+        # ignored, as the warning above states.
+        valid_options = {
+            k: v for k, v in kwargs.items() if k in self.plot_options
+        }
+        self.plot_options.update(valid_options)
+
+    def _set_axes_labels_from_metadata(self, meta, plot_params):
+        """Set axes labels based on the filename suffix in the profile
+        metadata, if they are not already set."""
+        if isinstance(meta, dict):
+            filename = meta.get("filename")
+            if filename:
+                suffix = filename.rsplit(".", 1)[-1].lower()
+                if "gr" in suffix:
+                    if plot_params.get("xlabel") is None:
+                        plot_params["xlabel"] = r"r ($\mathrm{\AA}$)"
+                    if plot_params.get("ylabel") is None:
+                        plot_params["ylabel"] = r"G ($\mathrm{\AA}^{-2}$)"
+        return
+
+    def plot_recipe(self, ax=None, return_fig=False, **kwargs):
+        """Plot the observed, fit, and difference curves for each
+        contribution of the fit recipe.
+
+        If the recipe has multiple contributions, a separate figure is
+        created for each contribution, unless an existing ``ax`` is
+        supplied, in which case all contributions are drawn on that axes.
+
+        Parameters
+        ----------
+        ax : matplotlib.axes.Axes or None, optional
+            The axes object to plot on. If None, creates a new figure.
+            Default is None.
+        return_fig : bool, optional
+            The figure and axes objects are returned if True. Default is
+            False.
+        **kwargs : dict
+            Any plotting option can be passed to override the defaults in
+            ``FitRecipe.plot_options``. See the ``set_plot_defaults()``
+            method for the available keyword arguments.
+
+        Returns
+        -------
+        fig, axes : tuple
+            Returned only if return_fig=True. For a single contribution this
+            is (matplotlib.figure.Figure, matplotlib.axes.Axes); for multiple
+            contributions it is a pair of lists with one figure and one axes
+            object per contribution.
+
+        Examples
+        --------
+        Plot with default settings:
+
+        >>> recipe.plot_recipe()
+
+        Override defaults for one plot:
+
+        >>> recipe.plot_recipe(show_diff=False, title='My Custom Title')
+
+        Set defaults once, use everywhere:
+
+        >>> recipe.set_plot_defaults(xlabel='r (Å)', ylabel='G(r)')
+        >>> recipe.plot_recipe()  # Uses xlabel and ylabel
+        >>> recipe.plot_recipe()  # Still uses them
+
+        Override a default for one plot:
+
+        >>> recipe.set_plot_defaults(figsize=(10, 7))
+        >>> recipe.plot_recipe()  # Uses (10, 7)
+        >>> recipe.plot_recipe(figsize=(12, 8))  # Temporarily uses (12, 8)
+        >>> recipe.plot_recipe()  # Back to (10, 7)
+
+        Notes
+        -----
+        The default values are taken from recipe.plot_options. You can modify
+        these defaults in three ways:
+
+        1. Using set_plot_defaults():
+           recipe.set_plot_defaults(xlabel='r (Å)')
+
+        2. Direct access to the plot_options dictionary:
+           recipe.plot_options['xlabel'] = 'r (Å)'
+
+        3. Using update():
+           recipe.plot_options.update({'xlabel': 'r (Å)', 'ylabel': 'G(r)'})
+        """
+        plot_params = self.plot_options.copy()
+        plot_params.update(kwargs)
+
+        if not any(
+            [
+                plot_params["show_observed"],
+                plot_params["show_fit"],
+                plot_params["show_diff"],
+            ]
+        ):
+            raise ValueError(
+                "At least one of show_observed, show_fit, "
+                "or show_diff must be True"
+            )
+
+        if not self._contributions:
+            raise ValueError(
+                "No contributions found in recipe. "
+                "Add contributions before plotting."
+            )
+        figures = []
+        axes_list = []
+        for name, contrib in self._contributions.items():
+            profile = contrib.profile
+            x = profile.x
+            yobs = profile.y
+            ycalc = profile.ycalc
+            if ycalc is None:
+                if plot_params["show_fit"] or plot_params["show_diff"]:
+                    print(
+                        f"Contribution '{name}' has no calculated values "
+                        "(ycalc is None). "
+                        "Only observed data will be plotted."
+ ) + plot_params["show_fit"] = False + plot_params["show_diff"] = False + else: + diff = yobs - ycalc + y_min = min(yobs.min(), ycalc.min()) + y_max = max(yobs.max(), ycalc.max()) + y_range = y_max - y_min + base_offset = y_min - 0.1 * y_range + offset = base_offset * plot_params["offset_scale"] + if ax is None: + fig = plt.figure(figsize=plot_params["figsize"]) + current_ax = fig.add_subplot(111) + else: + current_ax = ax + fig = current_ax.figure + if plot_params["show_observed"]: + current_ax.plot( + x, + yobs, + plot_params["data_style"], + label=plot_params["data_label"], + color=plot_params["data_color"], + markersize=plot_params["markersize"], + alpha=plot_params["alpha"], + ) + if plot_params["show_fit"]: + current_ax.plot( + x, + ycalc, + plot_params["fit_style"], + label=plot_params["fit_label"], + color=plot_params["fit_color"], + linewidth=plot_params["linewidth"], + alpha=plot_params["alpha"], + ) + if plot_params["show_diff"]: + current_ax.plot( + x, + diff + offset, + plot_params["diff_style"], + label=plot_params["diff_label"], + color=plot_params["diff_color"], + linewidth=plot_params["linewidth"], + alpha=plot_params["alpha"], + ) + current_ax.axhline( + offset, + color="black", + ) + meta = getattr(profile, "meta", None) + if meta: + self._set_axes_labels_from_metadata(meta, plot_params) + if plot_params["xlabel"] is not None: + current_ax.set_xlabel(plot_params["xlabel"]) + if plot_params["ylabel"] is not None: + current_ax.set_ylabel(plot_params["ylabel"]) + if plot_params["title"] is not None: + current_ax.set_title(plot_params["title"]) + if plot_params["legend"]: + current_ax.legend(loc=plot_params["legend_loc"], frameon=True) + if plot_params["grid"]: + current_ax.grid(True) + if ( + plot_params["xmin"] is not None + or plot_params["xmax"] is not None + ): + current_ax.set_xlim( + left=plot_params["xmin"], right=plot_params["xmax"] + ) + fig.tight_layout() + figures.append(fig) + axes_list.append(current_ax) + if plot_params["show"] and ax is None: + plt.show() + if return_fig: + if len(figures) == 1: + return figures[0], axes_list[0] + else: + return figures, axes_list + def boundsToRestraints(self, sig=1, scaled=False): """Turn all bounded parameters into restraints. 
diff --git a/tests/conftest.py b/tests/conftest.py
index 0bb50a2c..250bccf3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,8 +5,10 @@
 import pytest
 import six
+from numpy import linspace, pi, sin
 
 import diffpy.srfit.equation.literals as literals
+from diffpy.srfit.fitbase import FitContribution, FitRecipe, Profile
 
 logger = logging.getLogger(__name__)
@@ -142,3 +144,70 @@ def _capturestdout(f, *args, **kwargs):
         return fp.getvalue()
 
     return _capturestdout
+
+
+@pytest.fixture(scope="session")
+def build_recipe_one_contribution():
+    "Helper to build a simple recipe with a single sine contribution."
+    profile = Profile()
+    x = linspace(0, pi, 10)
+    y = sin(x)
+    profile.setObservedProfile(x, y)
+    contribution = FitContribution("c1")
+    contribution.setProfile(profile)
+    contribution.setEquation("A*sin(k*x + c)")
+    recipe = FitRecipe()
+    recipe.addContribution(contribution)
+    recipe.addVar(contribution.A, 1)
+    recipe.addVar(contribution.k, 1)
+    recipe.addVar(contribution.c, 1)
+    return recipe
+
+
+@pytest.fixture(scope="session")
+def build_recipe_two_contributions():
+    "Helper to build a recipe with two sine contributions."
+    profile1 = Profile()
+    x = linspace(0, pi, 10)
+    y1 = sin(x)
+    profile1.setObservedProfile(x, y1)
+    contribution1 = FitContribution("c1")
+    contribution1.setProfile(profile1)
+    contribution1.setEquation("A*sin(k*x + c)")
+
+    profile2 = Profile()
+    y2 = 0.5 * sin(2 * x)
+    profile2.setObservedProfile(x, y2)
+    contribution2 = FitContribution("c2")
+    contribution2.setProfile(profile2)
+    contribution2.setEquation("B*sin(m*x + d)")
+    recipe = FitRecipe()
+    recipe.addContribution(contribution1)
+    recipe.addContribution(contribution2)
+    recipe.addVar(contribution1.A, 1)
+    recipe.addVar(contribution1.k, 1)
+    recipe.addVar(contribution1.c, 1)
+    recipe.addVar(contribution2.B, 0.5)
+    recipe.addVar(contribution2.m, 2)
+    recipe.addVar(contribution2.d, 0)
+
+    return recipe
+
+
+@pytest.fixture
+def temp_data_files(tmp_path):
+    """
+    Temporary directory containing three small data files:
+    - gr_file.gr
+    - dat_file.dat
+    - cgr_file.cgr
+    Each file contains a few lines of two-column x-y data.
+ """ + file_with_meta = tmp_path / "gr_file.gr" + file_with_meta.write_text("1.0 2.0\n" "1.1 2.1\n" "1.2 2.2\n") + + dat_file = tmp_path / "dat_file.dat" + dat_file.write_text("1.0 2.0\n" "1.1 2.1\n" "1.2 2.2\n") + + cgr_file = tmp_path / "cgr_file.cgr" + cgr_file.write_text("1.0 2.0\n" "1.1 2.1\n" "1.2 2.2\n") + yield tmp_path diff --git a/tests/test_fitrecipe.py b/tests/test_fitrecipe.py index 1a2b2368..7b40f229 100644 --- a/tests/test_fitrecipe.py +++ b/tests/test_fitrecipe.py @@ -16,12 +16,19 @@ import unittest +import matplotlib +import matplotlib.pyplot as plt +import pytest from numpy import array_equal, dot, linspace, pi, sin +from scipy.optimize import leastsq from diffpy.srfit.fitbase.fitcontribution import FitContribution from diffpy.srfit.fitbase.fitrecipe import FitRecipe from diffpy.srfit.fitbase.parameter import Parameter from diffpy.srfit.fitbase.profile import Profile +from diffpy.srfit.pdf import PDFParser + +matplotlib.use("Agg") class TestFitRecipe(unittest.TestCase): @@ -284,5 +291,401 @@ def testPrintFitHook(capturestdout): return +def optimize_recipe(recipe): + recipe.fithooks[0].verbose = 0 + residuals = recipe.residual + values = recipe.values + leastsq(residuals, values) + + +def get_labels_and_linecount(ax): + """Helper to get line labels and count from a matplotlib Axes.""" + labels = [ + line.get_label() + for line in ax.get_lines() + if not line.get_label().startswith("_") + ] + line_count = len( + [ + line + for line in ax.get_lines() + if not line.get_label().startswith("_") + ] + ) + return labels, line_count + + +def build_recipe_from_datafile(datafile): + """Helper to build a FitRecipe from a datafile using PDFParser and + PDFGenerator.""" + profile = Profile() + parser = PDFParser() + parser.parseFile(str(datafile)) + profile.loadParsedData(parser) + + contribution = FitContribution("c") + contribution.setProfile(profile) + contribution.setEquation("m*x + b") + recipe = FitRecipe() + recipe.addContribution(contribution) + recipe.addVar(contribution.m, 1) + recipe.addVar(contribution.b, 0) + return recipe + + +def test_plot_recipe_bad_display(build_recipe_one_contribution): + recipe = build_recipe_one_contribution + # Case: All plots are disabled + # expected: raised ValueError with message + plt.close("all") + msg = "At least one of show_observed, show_fit, or show_diff must be True" + with pytest.raises(ValueError, match=msg): + recipe.plot_recipe( + show_observed=False, + show_diff=False, + show_fit=False, + ) + + +def test_plot_recipe_no_contribution(): + recipe = FitRecipe() + # Case: No contributions in recipe + # expected: raised ValueError with message + plt.close("all") + msg = ( + "No contributions found in recipe. " + "Add contributions before plotting." + ) + with pytest.raises(ValueError, match=msg): + recipe.plot_recipe() + + +def test_plot_recipe_before_refinement(capsys, build_recipe_one_contribution): + # Case: User tries to plot recipe before refinement + # expected: Data plotted without fit line or difference curve + # and warning message printed + recipe = build_recipe_one_contribution + plt.close("all") + before = set(plt.get_fignums()) + # include fit_label="nothing" to make sure fit line is not plotted + fig, ax = recipe.plot_recipe( + show=False, data_label="my data", fit_label="nothing", return_fig=True + ) + after = set(plt.get_fignums()) + new_figs = after - before + captured = capsys.readouterr() + actual = captured.out.strip() + expected = ( + "Contribution 'c1' has no calculated values (ycalc is None). 
" + "Only observed data will be plotted." + ) + # get labels from the plotted line + actual_label, actual_line_count = get_labels_and_linecount(ax) + expected_line_count = 1 + expected_label = ["my data"] + assert actual_line_count == expected_line_count + assert actual_label == expected_label + assert len(new_figs) == 1 + assert actual == expected + + +def test_plot_recipe_after_refinement(build_recipe_one_contribution): + # Case: User refines recipe and then plots + # expected: Plot generates with no problem + recipe = build_recipe_one_contribution + optimize_recipe(recipe) + plt.close("all") + before = set(plt.get_fignums()) + fig, ax = recipe.plot_recipe(show=False, return_fig=True) + after = set(plt.get_fignums()) + new_figs = after - before + actual_label, actual_line_count = get_labels_and_linecount(ax) + expected_label = ["Observed", "Calculated", "Difference"] + expected_line_count = 3 + assert actual_line_count == expected_line_count + assert actual_label == expected_label + assert len(new_figs) == 1 + + +def test_plot_recipe_two_contributions(build_recipe_two_contributions): + # Case: Two contributions in recipe + # expected: two figures created + recipe = build_recipe_two_contributions + optimize_recipe(recipe) + plt.close("all") + before = set(plt.get_fignums()) + figs, axes = recipe.plot_recipe(show=False, return_fig=True) + for ax in axes: + actual_label, actual_line_count = get_labels_and_linecount(ax) + expected_label = ["Observed", "Calculated", "Difference"] + expected_line_count = 3 + assert actual_line_count == expected_line_count + assert actual_label == expected_label + after = set(plt.get_fignums()) + new_figs = after - before + assert len(new_figs) == 2 + + +def test_plot_recipe_on_existing_plot(build_recipe_one_contribution): + # Case: User passes axes to plot_recipe to plot on existing figure + # expected: User modifications are present in the final figure + recipe = build_recipe_one_contribution + optimize_recipe(recipe) + plt.close("all") + fig, ax = plt.subplots() + ax.set_title("User Title") + ax.plot([0, 1], [0, 1], label="New Data") + recipe.plot_recipe(ax=ax, show=False) + actual_title = ax.get_title() + expected_title = "User Title" + actual_labels, actual_line_count = get_labels_and_linecount(ax) + expected_line_count = 4 + expected_labels = ["Calculated", "Difference", "New Data", "Observed"] + assert actual_line_count == expected_line_count + assert sorted(actual_labels) == sorted(expected_labels) + assert actual_title == expected_title + + +def test_plot_recipe_add_new_data(build_recipe_one_contribution): + # Case: User wants to add data to figure generated by plot_recipe + # Expected: New data is added to existing figure (check with labels) + recipe = build_recipe_one_contribution + optimize_recipe(recipe) + plt.close("all") + before = set(plt.get_fignums()) + fig, ax = recipe.plot_recipe(return_fig=True, show=False) + after = set(plt.get_fignums()) + new_figs = after - before + # add new data to existing plot + ax.plot([0, pi], [0, 0], label="New Data") + ax.legend() + actual_labels, actual_line_count = get_labels_and_linecount(ax) + expected_labels = ["Observed", "Calculated", "Difference", "New Data"] + expected_line_count = 4 + assert len(new_figs) == 1 + assert actual_line_count == expected_line_count + assert sorted(actual_labels) == sorted(expected_labels) + + +def test_plot_recipe_add_new_data_two_figs(build_recipe_two_contributions): + # Case: User wants to add data to figure generated by plot_recipe + # with two contributions + # Expected: 
New data is added to existing figure (check with labels) + recipe = build_recipe_two_contributions + optimize_recipe(recipe) + plt.close("all") + before = set(plt.get_fignums()) + figure, axes = recipe.plot_recipe(return_fig=True, show=False) + after = set(plt.get_fignums()) + new_figs = after - before + # add new data to existing plots + for ax in axes: + ax.plot([0, pi], [0, 0], label="New Data") + ax.legend() + actual_labels, actual_line_count = get_labels_and_linecount(ax) + expected_labels = ["Observed", "Calculated", "Difference", "New Data"] + expected_line_count = 4 + assert actual_line_count == expected_line_count + assert sorted(actual_labels) == sorted(expected_labels) + assert len(new_figs) == 2 + + +def test_plot_recipe_set_title(build_recipe_one_contribution): + # Case: User sets title via plot_recipe + # Expected: Title is set correctly + recipe = build_recipe_one_contribution + optimize_recipe(recipe) + plt.close("all") + expected_title = "Custom Recipe Title" + figure, ax = recipe.plot_recipe( + title=expected_title, return_fig=True, show=False + ) + actual_title = ax.get_title() + assert actual_title == expected_title + + +def test_plot_recipe_set_defaults(build_recipe_one_contribution): + # Case: user sets default plot options with set_plot_defaults + # Expected: plot_recipe uses the default options for all calls + recipe = build_recipe_one_contribution + optimize_recipe(recipe) + plt.close("all") + # set new defaults + recipe.set_plot_defaults( + show_observed=False, + show_fit=True, + show_diff=False, + data_label="Data Default", + fit_label="Fit Default", + diff_label="Diff Default", + title="Default Title", + ) + # call plot_recipe without any arguments + figure, ax = recipe.plot_recipe(return_fig=True, show=False) + actual_title = ax.get_title() + actual_labels, actual_line_count = get_labels_and_linecount(ax) + expected_title = "Default Title" + expected_labels = ["Fit Default"] + expected_line_count = 1 + assert actual_title == expected_title + assert actual_line_count == expected_line_count + assert actual_labels == expected_labels + + +def test_plot_recipe_set_defaults_bad(capsys, build_recipe_one_contribution): + # Case: user tries to set kwargs that are not valid plot_recipe options + # Expected: Plot is shown and warning is printed + recipe = build_recipe_one_contribution + optimize_recipe(recipe) + plt.close("all") + recipe.set_plot_defaults( + invalid_option="blah", + ) + captured = capsys.readouterr() + actual_msg = captured.out.strip() + expected_msg = ( + "Warning: 'invalid_option' is not a valid " + "plot_recipe option and will be ignored." 
+ ) + assert actual_msg == expected_msg + before = set(plt.get_fignums()) + figure, ax = recipe.plot_recipe(return_fig=True, show=False) + after = set(plt.get_fignums()) + new_figs = after - before + assert len(new_figs) == 1 + + +@pytest.mark.parametrize( + "input,expected", + [ + # case1: .gr file + # expected: labels are inferred from file + ("gr_file.gr", [r"r ($\mathrm{\AA}$)", r"G ($\mathrm{\AA}^{-2}$)"]), + # case2: .dat file + # expected: default empty labels + ("dat_file.dat", ["", ""]), + # case3: .cgr file + # expected: labels are inferred from file + ("cgr_file.cgr", [r"r ($\mathrm{\AA}$)", r"G ($\mathrm{\AA}^{-2}$)"]), + ], +) +def test_plot_recipe_labels_from_gr_file(temp_data_files, input, expected): + gr_file = temp_data_files / input + recipe = build_recipe_from_datafile(gr_file) + optimize_recipe(recipe) + plt.close("all") + fig, ax = recipe.plot_recipe(return_fig=True, show=False) + actual_xlabel = ax.get_xlabel() + actual_ylabel = ax.get_ylabel() + expected_xlabel = expected[0] + expected_ylabel = expected[1] + assert actual_xlabel == expected_xlabel + assert actual_ylabel == expected_ylabel + + +def test_plot_recipe_labels_from_gr_file_overwrite(temp_data_files): + gr_file = temp_data_files / "gr_file.gr" + recipe = build_recipe_from_datafile(gr_file) + optimize_recipe(recipe) + plt.close("all") + fig, ax = recipe.plot_recipe( + return_fig=True, show=False, xlabel="My X", ylabel="My Y" + ) + actual_xlabel = ax.get_xlabel() + actual_ylabel = ax.get_ylabel() + expected_xlabel = "My X" + expected_ylabel = "My Y" + assert actual_xlabel == expected_xlabel + assert actual_ylabel == expected_ylabel + + +def test_plot_recipe_reset_all_defaults(build_recipe_one_contribution): + expected_defaults = { + "show_observed": True, + "show_fit": True, + "show_diff": True, + "offset_scale": 0.5, + "xmin": 1, + "xmax": 10, + "figsize": (9, 10), + "data_style": "-", + "fit_style": "o", + "diff_style": "o", + "data_color": "blue", + "fit_color": "purple", + "diff_color": "orange", + "data_label": "my data", + "fit_label": "my fit", + "diff_label": "my diff", + "xlabel": "my x label", + "ylabel": "my y label", + "title": "my title", + "legend": False, + "legend_loc": "upper right", + "markersize": 5, + "linewidth": 3, + "alpha": 0.5, + "show": True, + } + + recipe = build_recipe_one_contribution + optimize_recipe(recipe) + plt.close("all") + + recipe.set_plot_defaults(**expected_defaults) + fig, ax = recipe.plot_recipe(return_fig=True, show=False) + + actual_title = ax.get_title() + actual_xlabel = ax.get_xlabel() + actual_ylabel = ax.get_ylabel() + + expected_title = expected_defaults["title"] + expected_xlabel = expected_defaults["xlabel"] + expected_ylabel = expected_defaults["ylabel"] + + assert actual_title == expected_title + assert actual_xlabel == expected_xlabel + assert actual_ylabel == expected_ylabel + + actual_labels, actual_line_count = get_labels_and_linecount(ax) + + expected_labels = [ + expected_defaults["data_label"], + expected_defaults["fit_label"], + expected_defaults["diff_label"], + ] + expected_line_count = 3 + + assert actual_line_count == expected_line_count + assert actual_labels == expected_labels + + lines_by_label = {line.get_label(): line for line in ax.get_lines()} + + data_line = lines_by_label[expected_defaults["data_label"]] + fit_line = lines_by_label[expected_defaults["fit_label"]] + diff_line = lines_by_label[expected_defaults["diff_label"]] + + assert data_line.get_color() == expected_defaults["data_color"] + assert fit_line.get_color() == 
expected_defaults["fit_color"] + assert diff_line.get_color() == expected_defaults["diff_color"] + + assert data_line.get_linestyle() == expected_defaults["data_style"] + assert fit_line.get_marker() == expected_defaults["fit_style"] + assert diff_line.get_marker() == expected_defaults["diff_style"] + + assert data_line.get_markersize() == expected_defaults["markersize"] + assert data_line.get_alpha() == expected_defaults["alpha"] + + actual_xlim = ax.get_xlim() + expected_xlim = (expected_defaults["xmin"], expected_defaults["xmax"]) + assert actual_xlim == expected_xlim + + # no legend + actual_legend = ax.get_legend() is not None + expected_legend = expected_defaults["legend"] + + assert actual_legend == expected_legend + + if __name__ == "__main__": unittest.main()