diff --git a/src/sas/qtgui/Plotting/SlicerModel.py b/src/sas/qtgui/Plotting/SlicerModel.py index 6844acf045..815376c52b 100644 --- a/src/sas/qtgui/Plotting/SlicerModel.py +++ b/src/sas/qtgui/Plotting/SlicerModel.py @@ -1,8 +1,21 @@ +from contextlib import contextmanager + from PySide6 import QtCore, QtGui import sas.qtgui.Utilities.GuiUtils as GuiUtils +@contextmanager +def temporary_flag(obj, attribute_name, value): + """Temporarily set an attribute value and always restore the previous value.""" + old_value = getattr(obj, attribute_name) + setattr(obj, attribute_name, value) + try: + yield + finally: + setattr(obj, attribute_name, old_value) + + class SlicerModel: def __init__(self): # Model representation of local parameters @@ -44,9 +57,8 @@ def setParamsFromModel(self): else: params[param_name] = float(self._model.item(row_index, 1).text()) - self.update_model = False - self.setParams(params) - self.update_model = True + with temporary_flag(self, "update_model", False): + self.setParams(params) def setParamsFromModelItem(self, item): """ @@ -61,9 +73,8 @@ def setParamsFromModelItem(self, item): else: params[param_name] = float(self._model.item(row_index, 1).text()) - self.update_model = False - self.setParams(params) - self.update_model = True + with temporary_flag(self, "update_model", False): + self.setParams(params) def model(self): '''getter for the model''' @@ -73,7 +84,7 @@ def getParams(self): ''' pure virtual ''' raise NotImplementedError("Parameter getter must be implemented in derived class.") - def setParams(self): + def setParams(self, params): ''' pure virtual ''' raise NotImplementedError("Parameter setter must be implemented in derived class.") diff --git a/src/sas/qtgui/Plotting/Slicers/BoxSum.py b/src/sas/qtgui/Plotting/Slicers/BoxSum.py index ed47933b54..efc58ed19d 100644 --- a/src/sas/qtgui/Plotting/Slicers/BoxSum.py +++ b/src/sas/qtgui/Plotting/Slicers/BoxSum.py @@ -5,6 +5,7 @@ from sasdata.data_util.manipulations import Boxavg, Boxsum +from sas.qtgui.Plotting.SlicerModel import temporary_flag from sas.qtgui.Plotting.Slicers.BaseInteractor import BaseInteractor from sas.qtgui.Utilities.GuiUtils import formatNumber, toDouble @@ -150,10 +151,10 @@ def setParamsFromModel(self): params["Width"] = toDouble(self.model().item(0, 1).text()) params["center_x"] = toDouble(self.model().item(0, 2).text()) params["center_y"] = toDouble(self.model().item(0, 3).text()) - self.update_model = False - self.setParams(params) - self.setReadOnlyParametersFromModel() - self.update_model = True + + with temporary_flag(self, "update_model", False): + self.setParams(params) + self.setReadOnlyParametersFromModel() def setPanelName(self, name): """ diff --git a/src/sas/qtgui/Plotting/Slicers/MultiSlicerBase.py b/src/sas/qtgui/Plotting/Slicers/MultiSlicerBase.py index 71e34d589d..6e4062ebf7 100644 --- a/src/sas/qtgui/Plotting/Slicers/MultiSlicerBase.py +++ b/src/sas/qtgui/Plotting/Slicers/MultiSlicerBase.py @@ -8,7 +8,7 @@ import numpy as np -from sas.qtgui.Plotting.SlicerModel import SlicerModel +from sas.qtgui.Plotting.SlicerModel import SlicerModel, temporary_flag from sas.qtgui.Plotting.Slicers.BaseInteractor import BaseInteractor from sas.qtgui.Plotting.Slicers.SectorSlicer import SectorInteractor from sas.qtgui.Plotting.Slicers.SlicerUtils import StackableMixin @@ -336,15 +336,11 @@ def _synchronized_moveend(self, ev, interactor_name): # Now post data for all secondary slicers (with update_model temporarily enabled) for i, slicer in enumerate(self.slicers[1:], start=1): try: - # Temporarily enable model updates for this slicer - old_update_model = slicer.update_model - slicer.update_model = True - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - slicer._post_data() - - # Restore original state - slicer.update_model = old_update_model + # Temporarily enable model updates for this slicer. + with temporary_flag(slicer, "update_model", True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + slicer._post_data() except (ValueError, RuntimeError) as e: logger.warning(f"Failed to post data for slicer {i + 1}: {e}") diff --git a/src/sas/qtgui/Plotting/UnitTesting/SlicerModelTest.py b/src/sas/qtgui/Plotting/UnitTesting/SlicerModelTest.py index 0257c4f2bf..6170df2069 100644 --- a/src/sas/qtgui/Plotting/UnitTesting/SlicerModelTest.py +++ b/src/sas/qtgui/Plotting/UnitTesting/SlicerModelTest.py @@ -28,7 +28,7 @@ def testBaseClass(self, qapp): '''Assure that SlicerModel contains pure virtuals''' model = SlicerModel() with pytest.raises(NotImplementedError): - model.setParams() + model.setParams({}) with pytest.raises(NotImplementedError): model.setModelFromParams() @@ -64,3 +64,48 @@ def testSetParamsFromModel(self, model): # Check the new model. The update should be automatic assert model.model().rowCount() == 3 assert model.model().columnCount() == 2 + + def testSetParamsFromModel_restores_update_model_on_exception(self, qapp): + """update_model should be restored even if setParams raises.""" + class FailingModel(SlicerModel): + params = {"a": 1} + + def __init__(self): + SlicerModel.__init__(self) + + def getParams(self): + return self.params + + def setParams(self, par): + raise RuntimeError("boom") + + model = FailingModel() + model.setModelFromParams() + + with pytest.raises(RuntimeError): + model.setParamsFromModel() + + assert model.update_model is True + + def testSetParamsFromModelItem_restores_update_model_on_exception(self, qapp): + """update_model should be restored for single-item updates too.""" + class FailingModel(SlicerModel): + params = {"a": 1} + + def __init__(self): + SlicerModel.__init__(self) + + def getParams(self): + return self.params + + def setParams(self, par): + raise RuntimeError("boom") + + model = FailingModel() + model.setModelFromParams() + item = model.model().item(0, 1) + + with pytest.raises(RuntimeError): + model.setParamsFromModelItem(item) + + assert model.update_model is True