diff --git a/docs/src/api/api.rst b/docs/src/api/api.rst index 6c1d5b3b..ea8e2661 100644 --- a/docs/src/api/api.rst +++ b/docs/src/api/api.rst @@ -19,6 +19,15 @@ Sample is build from assemblies. sample +Project +======= +Project provides a higher-level interface for managing models, experiments, and ORSO import. + +.. toctree:: + :maxdepth: 1 + + project + Assemblies ========== Assemblies are collections of layers that are used to represent a specific physical setup. diff --git a/pyproject.toml b/pyproject.toml index 984758ed..630422ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,11 +29,11 @@ classifiers = [ requires-python = ">=3.11,<3.13" dependencies = [ - #"easyscience @ git+https://github.com/easyscience/corelib.git@dict_size_changed_bug", - "easyscience", + "easyscience @ git+https://github.com/easyscience/corelib.git@develop", + #"easyscience", "scipp", "refnx", - "refl1d>=1.0.0rc0", + "refl1d>=1.0.0", "orsopy", "svglib<1.6 ; platform_system=='Linux'", "xhtml2pdf", diff --git a/src/easyreflectometry/calculators/calculator_base.py b/src/easyreflectometry/calculators/calculator_base.py index 1be7b568..7d1314cd 100644 --- a/src/easyreflectometry/calculators/calculator_base.py +++ b/src/easyreflectometry/calculators/calculator_base.py @@ -7,7 +7,7 @@ from easyscience.fitting.calculators.interface_factory import ItemContainer from easyscience.io import SerializerComponent -#if TYPE_CHECKING: +# if TYPE_CHECKING: from easyreflectometry.model import Model from easyreflectometry.sample import BaseAssembly from easyreflectometry.sample import Layer diff --git a/src/easyreflectometry/data/data_store.py b/src/easyreflectometry/data/data_store.py index 8dde0bf7..42fd28b0 100644 --- a/src/easyreflectometry/data/data_store.py +++ b/src/easyreflectometry/data/data_store.py @@ -76,7 +76,7 @@ def __init__( y: Optional[Union[np.ndarray, list]] = None, ye: Optional[Union[np.ndarray, list]] = None, xe: Optional[Union[np.ndarray, list]] = None, - model: Optional['Model'] = None, # delay type checking until runtime (quotes) + model: Optional['Model'] = None, # delay type checking until runtime (quotes) x_label: str = 'x', y_label: str = 'y', ): @@ -117,7 +117,7 @@ def __init__( self._color = None @property - def model(self) -> 'Model': # delay type checking until runtime (quotes) + def model(self) -> 'Model': # delay type checking until runtime (quotes) return self._model @model.setter diff --git a/src/easyreflectometry/fitting.py b/src/easyreflectometry/fitting.py index 99b5e614..a33e6570 100644 --- a/src/easyreflectometry/fitting.py +++ b/src/easyreflectometry/fitting.py @@ -55,13 +55,13 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup: variances = data['data'][f'R_{i}'].variances # Find points with non-zero variance - zero_variance_mask = (variances == 0.0) + zero_variance_mask = variances == 0.0 num_zero_variance = np.sum(zero_variance_mask) if num_zero_variance > 0: warnings.warn( - f"Masked {num_zero_variance} data point(s) in reflectivity {i} due to zero variance during fitting.", - UserWarning + f'Masked {num_zero_variance} data point(s) in reflectivity {i} due to zero variance during fitting.', + UserWarning, ) # Keep only points with non-zero variances diff --git a/src/easyreflectometry/orso_utils.py b/src/easyreflectometry/orso_utils.py index 23756eba..3b33a14d 100644 --- a/src/easyreflectometry/orso_utils.py +++ b/src/easyreflectometry/orso_utils.py @@ -27,14 +27,16 @@ def LoadOrso(orso_str: str): return sample, data + def load_data_from_orso_file(fname: str) -> 
sc.DataGroup: """Load data from an ORSO file.""" try: orso_data = orso.load_orso(fname) except Exception as e: - raise ValueError(f"Error loading ORSO file: {e}") + raise ValueError(f'Error loading ORSO file: {e}') return load_orso_data(orso_data) + def load_orso_model(orso_str: str) -> Sample: """ Load a model from an ORSO file and return a Sample object. @@ -64,9 +66,9 @@ def load_orso_model(orso_str: str) -> Sample: # Handle case where layers are not resolved correctly if not orso_layers: - raise ValueError("Could not resolve ORSO layers.") + raise ValueError('Could not resolve ORSO layers.') - logger.debug(f"Resolved layers: {orso_layers}") + logger.debug(f'Resolved layers: {orso_layers}') # Convert ORSO layers to EasyReflectometry layers erl_layers = [] @@ -98,7 +100,7 @@ def _convert_orso_layer_to_erl(layer): material=Material(sld=m_sld, isld=m_isld, name=m_name), thickness=layer.thickness.magnitude if layer.thickness is not None else 0.0, roughness=layer.roughness.magnitude if layer.roughness is not None else 0.0, - name=layer.original_name if layer.original_name is not None else m_name + name=layer.original_name if layer.original_name is not None else m_name, ) @@ -107,10 +109,7 @@ def _get_sld_values(material, material_name): if material.sld is None and material.mass_density is not None: # Calculate SLD from mass density m_density = material.mass_density.magnitude - density = MaterialDensity( - chemical_structure=material_name, - density=m_density - ) + density = MaterialDensity(chemical_structure=material_name, density=m_density) m_sld = density.sld.value m_isld = density.isld.value else: @@ -123,6 +122,7 @@ def _get_sld_values(material, material_name): return m_sld, m_isld + def load_orso_data(orso_str: str) -> DataSet1D: data = {} coords = {} diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index 196b9af1..1d3b4815 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -18,6 +18,8 @@ from easyreflectometry.data import DataSet1D from easyreflectometry.data import load_as_dataset from easyreflectometry.fitting import MultiFitter + +# from easyreflectometry.model import LinearSpline from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm @@ -268,10 +270,53 @@ def load_orso_file(self, path: Union[Path, str]) -> None: self._with_experiments = True pass - def set_sample_from_orso(self, sample) -> None: + def set_sample_from_orso(self, sample: Sample) -> None: + """Replace the current project model collection with a single model built from an ORSO-parsed sample. + + This is a convenience helper for the ORSO import pipeline where a complete + :class:`~easyreflectometry.sample.Sample` is constructed elsewhere. + + :param sample: Sample to set as the project's (single) model. + :type sample: easyreflectometry.sample.Sample + :return: ``None``. + :rtype: None + """ model = Model(sample=sample) self.models = ModelCollection([model]) + def add_sample_from_orso(self, sample: Sample) -> None: + """Add a new model with the given sample to the existing model collection. + + The created model is appended to :attr:`models`, its calculator interface is + set to the project's current calculator, and any materials referenced in the + sample are added to the project's material collection. + + After adding the model, :attr:`current_model_index` is updated to point to + the newly added model. + + :param sample: Sample to add as a new model. 
+ :type sample: easyreflectometry.sample.Sample + :return: ``None``. + :rtype: None + """ + model = Model(sample=sample) + self.models.add_model(model) + # Set interface after adding to collection + model.interface = self._calculator + # Extract materials from the new model and add to project materials + self._materials.extend(self._get_materials_from_model(model)) + # Switch to the newly added model so its data is visible in the UI + self.current_model_index = len(self._models) - 1 + + def _get_materials_from_model(self, model: Model) -> 'MaterialCollection': + """Get all materials from a single model's sample.""" + materials_in_model = MaterialCollection(populate_if_none=False) + for assembly in model.sample: + for layer in assembly.layers: + if layer.material not in materials_in_model: + materials_in_model.append(layer.material) + return materials_in_model + def load_new_experiment(self, path: Union[Path, str]) -> None: new_experiment = load_as_dataset(str(path)) new_index = len(self._experiments) @@ -291,6 +336,10 @@ def load_new_experiment(self, path: Union[Path, str]) -> None: q_error = new_experiment.xe # TODO: set resolution function based on value of control in GUI resolution_function = Pointwise(q_data_points=[q, reflectivity, q_error]) + # resolution_function = LinearSpline( + # q_data_points=self._experiments[new_index].y, + # fwhm_values=np.sqrt(self._experiments[new_index].ye), + # ) self.models[model_index].resolution_function = resolution_function def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Optional[int] = 0) -> None: diff --git a/tests/calculators/refl1d/test_refl1d_calculator.py b/tests/calculators/refl1d/test_refl1d_calculator.py index 4dff8a9b..ba8c8d35 100644 --- a/tests/calculators/refl1d/test_refl1d_calculator.py +++ b/tests/calculators/refl1d/test_refl1d_calculator.py @@ -61,7 +61,7 @@ def test_reflectity_profile(self): 5.7605e-07, 2.3775e-07, 1.3093e-07, - 1.0520e-07 + 1.0520e-07, ] assert_almost_equal(p.reflectity_profile(q, 'MyModel'), expected, decimal=4) @@ -106,7 +106,7 @@ def test_calculate2(self): 1.0968e-06, 4.5635e-07, 3.4120e-07, - 2.7505e-07 + 2.7505e-07, ] assert_almost_equal(actual, expected, decimal=4) diff --git a/tests/calculators/refl1d/test_refl1d_wrapper.py b/tests/calculators/refl1d/test_refl1d_wrapper.py index 725aca6a..e19dfe42 100644 --- a/tests/calculators/refl1d/test_refl1d_wrapper.py +++ b/tests/calculators/refl1d/test_refl1d_wrapper.py @@ -232,7 +232,7 @@ def test_calculate(self): 5.7605e-07, 2.3775e-07, 1.3093e-07, - 1.0520e-07 + 1.0520e-07, ] assert_almost_equal(p.calculate(q, 'MyModel'), expected, decimal=4) @@ -276,7 +276,7 @@ def test_calculate_three_items(self): 1.0968e-06, 4.5635e-07, 3.4120e-07, - 2.7505e-07 + 2.7505e-07, ] assert_almost_equal(p.calculate(q, 'MyModel'), expected, decimal=4) @@ -396,7 +396,7 @@ def test_get_polarized_probe_oversampling(): probe = _get_polarized_probe(q_array=q, dq_array=dq, model_name=model_name, storage=storage, oversampling_factor=2) # Then - assert len(probe.xs[0].calc_Qo) == 2*len(q) + assert len(probe.xs[0].calc_Qo) == 2 * len(q) def test_get_polarized_probe_polarization(): diff --git a/tests/data/test_data_store.py b/tests/data/test_data_store.py index 17f837e2..84acf4f8 100644 --- a/tests/data/test_data_store.py +++ b/tests/data/test_data_store.py @@ -29,13 +29,7 @@ def test_constructor_default_values(self): def test_constructor_with_values(self): # When data = DataSet1D( - x=[1, 2, 3], - y=[4, 5, 6], - ye=[7, 8, 9], - xe=[10, 11, 12], - x_label='label_x', - 
y_label='label_y', - name='MyDataSet1D' + x=[1, 2, 3], y=[4, 5, 6], ye=[7, 8, 9], xe=[10, 11, 12], x_label='label_x', y_label='label_y', name='MyDataSet1D' ) # Then @@ -116,9 +110,7 @@ def test_is_simulation_property(self): def test_data_points(self): # When - data = DataSet1D( - x=[1, 2, 3], y=[4, 5, 6], ye=[7, 8, 9], xe=[10, 11, 12] - ) + data = DataSet1D(x=[1, 2, 3], y=[4, 5, 6], ye=[7, 8, 9], xe=[10, 11, 12]) # Then points = list(data.data_points()) @@ -126,9 +118,7 @@ def test_data_points(self): def test_repr(self): # When - data = DataSet1D( - x=[1, 2, 3], y=[4, 5, 6], x_label='Q', y_label='R' - ) + data = DataSet1D(x=[1, 2, 3], y=[4, 5, 6], x_label='Q', y_label='R') # Then expected = "1D DataStore of 'Q' Vs 'R' with 3 data points" @@ -194,7 +184,7 @@ def test_setitem(self): item1 = DataSet1D(name='item1') item2 = DataSet1D(name='item2') store = DataStore(item1) - + # When store[0] = item2 @@ -314,4 +304,3 @@ def test_constructor_with_custom_datastores(self): assert project.sim_data == sim_store assert project.exp_data.name == 'CustomExp' assert project.sim_data.name == 'CustomSim' - diff --git a/tests/summary/test_summary.py b/tests/summary/test_summary.py index 319a1c9b..7a3c727f 100644 --- a/tests/summary/test_summary.py +++ b/tests/summary/test_summary.py @@ -177,7 +177,7 @@ def test_save_sld_plot(self, project: Project, tmp_path) -> None: # Expect assert os.path.exists(file_path) - @pytest.mark.skip(reason="Matplotlib issue with headless CI environments") + @pytest.mark.skip(reason='Matplotlib issue with headless CI environments') def test_save_fit_experiment_plot(self, project: Project, tmp_path) -> None: # When summary = Summary(project) diff --git a/tests/test_data.py b/tests/test_data.py index c16aa259..0ee95d94 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -224,8 +224,7 @@ def test_load_txt_three_columns(self): assert coords_name in er_data['coords'] # xe should be zeros for 3-column file - assert_almost_equal(er_data['coords'][coords_name].variances, - np.zeros_like(er_data['coords'][coords_name].values)) + assert_almost_equal(er_data['coords'][coords_name].variances, np.zeros_like(er_data['coords'][coords_name].values)) def test_load_txt_with_zero_errors(self): fpath = os.path.join(PATH_STATIC, 'ref_zero_var.txt') @@ -246,6 +245,7 @@ def test_load_txt_file_not_found(self): def test_load_txt_insufficient_columns(self): # Create a temporary file with insufficient columns import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: f.write('1.0 2.0\n') # Only 2 columns temp_path = f.name @@ -272,7 +272,7 @@ def test_load_orso_multiple_datasets(self): if data_key.replace('R_', '') in coord_key: coord_key_found = True break - assert coord_key_found, f"No corresponding coord found for {data_key}" + assert coord_key_found, f'No corresponding coord found for {data_key}' def test_load_orso_with_attrs(self): fpath = os.path.join(PATH_STATIC, 'test_example1.ort') diff --git a/tests/test_fitting.py b/tests/test_fitting.py index 0b02ed82..fdeba385 100644 --- a/tests/test_fitting.py +++ b/tests/test_fitting.py @@ -86,7 +86,7 @@ def test_fitting_with_zero_variance(): # First, load the raw data to count zero variance points raw_data = np.loadtxt(fpath, delimiter=',', comments='#') zero_variance_count = np.sum(raw_data[:, 2] == 0.0) # Error column - assert zero_variance_count == 6, f"Expected 6 zero variance points, got {zero_variance_count}" + assert zero_variance_count == 6, f'Expected 6 zero variance points, got {zero_variance_count}' # 
Load data through the measurement module (which already filters zero variance) data = load(fpath) @@ -129,12 +129,11 @@ def test_fitting_with_zero_variance(): # Capture warnings during fitting - check if zero variance points still exist in the data # and are properly handled by the fitting method with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + warnings.simplefilter('always') analysed = fitter.fit(data) # Check if any zero variance warnings were issued during fitting - fitting_warnings = [str(warning.message) for warning in w - if "zero variance during fitting" in str(warning.message)] + fitting_warnings = [str(warning.message) for warning in w if 'zero variance during fitting' in str(warning.message)] # The fitting method should handle zero variance points gracefully # If there are any zero variance points remaining in the data, they should be masked @@ -142,15 +141,15 @@ def test_fitting_with_zero_variance(): if len(fitting_warnings) > 0: # Verify the warning message format and that it mentions masking points for warning_msg in fitting_warnings: - assert "Masked" in warning_msg and "zero variance during fitting" in warning_msg - print(f"Info: {warning_msg}") # Log for debugging + assert 'Masked' in warning_msg and 'zero variance during fitting' in warning_msg + print(f'Info: {warning_msg}') # Log for debugging # Basic checks that fitting completed # The keys will be based on the filename, not just '0' model_keys = [k for k in analysed.keys() if k.endswith('_model')] sld_keys = [k for k in analysed.keys() if k.startswith('SLD_')] - assert len(model_keys) > 0, f"No model keys found in {list(analysed.keys())}" - assert len(sld_keys) > 0, f"No SLD keys found in {list(analysed.keys())}" + assert len(model_keys) > 0, f'No model keys found in {list(analysed.keys())}' + assert len(sld_keys) > 0, f'No SLD keys found in {list(analysed.keys())}' assert 'success' in analysed.keys() @@ -172,14 +171,12 @@ def test_fitting_with_manual_zero_variance(): variances[30:32] = 0.0 # 2 more zero variance points # Create scipp DataGroup manually - data = sc.DataGroup({ - 'coords': { - 'Qz_0': sc.array(dims=['Qz_0'], values=qz_values) - }, - 'data': { - 'R_0': sc.array(dims=['Qz_0'], values=r_values, variances=variances) + data = sc.DataGroup( + { + 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=qz_values)}, + 'data': {'R_0': sc.array(dims=['Qz_0'], values=r_values, variances=variances)}, } - }) + ) # Create a simple model for fitting si = Material(2.07, 0, 'Si') @@ -214,16 +211,15 @@ def test_fitting_with_manual_zero_variance(): # Capture warnings during fitting with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + warnings.simplefilter('always') analysed = fitter.fit(data) # Check that warnings were issued about zero variance points - fitting_warnings = [str(warning.message) for warning in w - if "zero variance during fitting" in str(warning.message)] + fitting_warnings = [str(warning.message) for warning in w if 'zero variance during fitting' in str(warning.message)] # Should have one warning about the 7 zero variance points (5 + 2) - assert len(fitting_warnings) == 1, f"Expected 1 warning, got {len(fitting_warnings)}: {fitting_warnings}" - assert "Masked 7 data point(s)" in fitting_warnings[0], f"Unexpected warning content: {fitting_warnings[0]}" + assert len(fitting_warnings) == 1, f'Expected 1 warning, got {len(fitting_warnings)}: {fitting_warnings}' + assert 'Masked 7 data point(s)' in fitting_warnings[0], f'Unexpected warning content: 
{fitting_warnings[0]}' # Basic checks that fitting completed despite zero variance points assert 'R_0_model' in analysed.keys() assert 'SLD_0' in analysed.keys() diff --git a/tests/test_measurement_comprehensive.py b/tests/test_measurement_comprehensive.py index 3842fbb5..e9bf6ffe 100644 --- a/tests/test_measurement_comprehensive.py +++ b/tests/test_measurement_comprehensive.py @@ -33,7 +33,7 @@ def test_load_function_with_orso_file(self): """Test that load() correctly identifies and loads ORSO files.""" fpath = os.path.join(PATH_STATIC, 'test_example1.ort') result = load(fpath) - + assert 'data' in result assert 'coords' in result assert len(result['data']) > 0 @@ -43,7 +43,7 @@ def test_load_function_with_txt_file(self): """Test that load() falls back to txt loading for non-ORSO files.""" fpath = os.path.join(PATH_STATIC, 'test_example1.txt') result = load(fpath) - + assert 'data' in result assert 'coords' in result assert 'R_test_example1' in result['data'] @@ -53,7 +53,7 @@ def test_load_as_dataset_returns_dataset1d(self): """Test that load_as_dataset returns a proper DataSet1D object.""" fpath = os.path.join(PATH_STATIC, 'test_example1.txt') dataset = load_as_dataset(fpath) - + assert isinstance(dataset, DataSet1D) assert hasattr(dataset, 'x') assert hasattr(dataset, 'y') @@ -65,7 +65,7 @@ def test_load_as_dataset_extracts_correct_basename(self): """Test that load_as_dataset correctly extracts file basename.""" fpath = os.path.join(PATH_STATIC, 'ref_concat_1.txt') dataset = load_as_dataset(fpath) - + # Should work without error and have data assert len(dataset.x) > 0 assert len(dataset.y) > 0 @@ -74,12 +74,12 @@ def test_merge_datagroups_preserves_all_data(self): """Test that merge_datagroups combines multiple data groups correctly.""" fpath1 = os.path.join(PATH_STATIC, 'test_example1.txt') fpath2 = os.path.join(PATH_STATIC, 'ref_concat_1.txt') - + group1 = load(fpath1) group2 = load(fpath2) - + merged = merge_datagroups(group1, group2) - + # Should have data from both groups assert len(merged['data']) >= len(group1['data']) assert len(merged['coords']) >= len(group1['coords']) @@ -88,9 +88,9 @@ def test_merge_datagroups_single_group(self): """Test that merge_datagroups works with a single group.""" fpath = os.path.join(PATH_STATIC, 'test_example1.ort') group = load(fpath) - + merged = merge_datagroups(group) - + # Should be equivalent to original assert len(merged['data']) == len(group['data']) assert len(merged['coords']) == len(group['coords']) @@ -99,7 +99,7 @@ def test_load_txt_handles_comma_delimiter(self): """Test that _load_txt correctly handles comma-delimited files.""" fpath = os.path.join(PATH_STATIC, 'ref_concat_1.txt') result = _load_txt(fpath) - + assert 'data' in result assert 'coords' in result # Should successfully parse comma-delimited data @@ -110,11 +110,10 @@ def test_load_txt_handles_three_columns(self): """Test that _load_txt handles files with only 3 columns (no xe).""" fpath = os.path.join(PATH_STATIC, 'ref_concat_1.txt') result = _load_txt(fpath) - + coords_key = list(result['coords'].keys())[0] # xe should be zeros - assert_array_equal(result['coords'][coords_key].variances, - np.zeros_like(result['coords'][coords_key].values)) + assert_array_equal(result['coords'][coords_key].variances, np.zeros_like(result['coords'][coords_key].values)) def test_load_txt_with_insufficient_columns(self): """Test that _load_txt raises error for files with too few columns.""" @@ -122,7 +121,7 @@ def test_load_txt_with_insufficient_columns(self): with 
tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: f.write('1.0 2.0\n3.0 4.0\n') temp_path = f.name - + try: with pytest.raises(ValueError, match='File must contain at least 3 columns'): _load_txt(temp_path) @@ -133,7 +132,7 @@ def test_load_orso_with_multiple_datasets(self): """Test that _load_orso handles files with multiple datasets.""" fpath = os.path.join(PATH_STATIC, 'test_example2.ort') result = load_data_from_orso_file(fpath) - + # Should have multiple data entries assert len(result['data']) > 1 assert 'attrs' in result @@ -142,7 +141,7 @@ def test_load_orso_preserves_metadata(self): """Test that _load_orso preserves ORSO metadata in attrs.""" fpath = os.path.join(PATH_STATIC, 'test_example1.ort') result = load_data_from_orso_file(fpath) - + assert 'attrs' in result # Should have orso_header in attrs for data_key in result['data']: @@ -159,15 +158,9 @@ def test_constructor_all_parameters(self): y = [10, 20, 30, 40] xe = [0.1, 0.1, 0.1, 0.1] ye = [1, 2, 3, 4] - - dataset = DataSet1D( - name='TestData', - x=x, y=y, xe=xe, ye=ye, - x_label='Q (Å⁻¹)', - y_label='Reflectivity', - model=None - ) - + + dataset = DataSet1D(name='TestData', x=x, y=y, xe=xe, ye=ye, x_label='Q (Å⁻¹)', y_label='Reflectivity', model=None) + assert dataset.name == 'TestData' assert_array_equal(dataset.x, np.array(x)) assert_array_equal(dataset.y, np.array(y)) @@ -182,7 +175,7 @@ def test_is_experiment_vs_simulation_properties(self): sim_data = DataSet1D(x=[1, 2], y=[3, 4]) assert sim_data.is_simulation is True assert sim_data.is_experiment is False - + # Dataset with model is experiment exp_data = DataSet1D(x=[1, 2], y=[3, 4], model=Mock()) assert exp_data.is_experiment is True @@ -190,13 +183,8 @@ def test_is_experiment_vs_simulation_properties(self): def test_data_points_iterator(self): """Test the data_points method returns correct tuples.""" - dataset = DataSet1D( - x=[1, 2, 3], - y=[10, 20, 30], - xe=[0.1, 0.2, 0.3], - ye=[1, 2, 3] - ) - + dataset = DataSet1D(x=[1, 2, 3], y=[10, 20, 30], xe=[0.1, 0.2, 0.3], ye=[1, 2, 3]) + points = list(dataset.data_points()) expected = [(1, 10, 1, 0.1), (2, 20, 2, 0.2), (3, 30, 3, 0.3)] assert points == expected @@ -205,20 +193,15 @@ def test_model_property_with_background_setting(self): """Test that setting model updates background to minimum y value.""" dataset = DataSet1D(x=[1, 2, 3, 4], y=[5, 1, 8, 3]) mock_model = Mock() - + dataset.model = mock_model - + assert mock_model.background == 1 # minimum of [5, 1, 8, 3] def test_repr_string_representation(self): """Test the string representation of DataSet1D.""" - dataset = DataSet1D( - x=[1, 2, 3], - y=[4, 5, 6], - x_label='Momentum Transfer', - y_label='Intensity' - ) - + dataset = DataSet1D(x=[1, 2, 3], y=[4, 5, 6], x_label='Momentum Transfer', y_label='Intensity') + expected = "1D DataStore of 'Momentum Transfer' Vs 'Intensity' with 3 data points" assert str(dataset) == expected @@ -230,19 +213,19 @@ def test_datastore_as_sequence(self): """Test DataStore behaves like a sequence.""" item1 = DataSet1D(name='item1', x=[1], y=[2]) item2 = DataSet1D(name='item2', x=[3], y=[4]) - + store = DataStore(item1, item2, name='TestStore') - + # Test sequence operations assert len(store) == 2 assert store[0].name == 'item1' assert store[1].name == 'item2' - + # Test item replacement item3 = DataSet1D(name='item3', x=[5], y=[6]) store[0] = item3 assert store[0].name == 'item3' - + # Test deletion del store[0] assert len(store) == 1 @@ -254,12 +237,12 @@ def 
test_datastore_experiments_and_simulations_filtering(self): exp2 = DataSet1D(name='exp2', x=[3], y=[4], model=Mock()) sim1 = DataSet1D(name='sim1', x=[5], y=[6]) sim2 = DataSet1D(name='sim2', x=[7], y=[8]) - + store = DataStore(exp1, sim1, exp2, sim2) - + experiments = store.experiments simulations = store.simulations - + assert len(experiments) == 2 assert len(simulations) == 2 assert all(item.is_experiment for item in experiments) @@ -269,9 +252,9 @@ def test_datastore_append_method(self): """Test append method adds items correctly.""" store = DataStore() item = DataSet1D(name='new_item', x=[1], y=[2]) - + store.append(item) - + assert len(store) == 1 assert store[0] == item @@ -282,7 +265,7 @@ class TestProjectDataComprehensive: def test_project_data_initialization(self): """Test ProjectData initializes with correct default values.""" project = ProjectData() - + assert project.name == 'DataStore' assert isinstance(project.exp_data, DataStore) assert isinstance(project.sim_data, DataStore) @@ -293,13 +276,9 @@ def test_project_data_with_custom_stores(self): """Test ProjectData with custom experiment and simulation stores.""" custom_exp = DataStore(name='CustomExp') custom_sim = DataStore(name='CustomSim') - - project = ProjectData( - name='MyProject', - exp_data=custom_exp, - sim_data=custom_sim - ) - + + project = ProjectData(name='MyProject', exp_data=custom_exp, sim_data=custom_sim) + assert project.name == 'MyProject' assert project.exp_data == custom_exp assert project.sim_data == custom_sim @@ -307,13 +286,13 @@ def test_project_data_with_custom_stores(self): def test_project_data_stores_independence(self): """Test that exp_data and sim_data are independent stores.""" project = ProjectData() - + exp_item = DataSet1D(name='exp', x=[1], y=[2], model=Mock()) sim_item = DataSet1D(name='sim', x=[3], y=[4]) - + project.exp_data.append(exp_item) project.sim_data.append(sim_item) - + assert len(project.exp_data) == 1 assert len(project.sim_data) == 1 assert project.exp_data[0] != project.sim_data[0] @@ -327,11 +306,11 @@ def test_complete_workflow_orso_file(self): # Load file fpath = os.path.join(PATH_STATIC, 'test_example1.ort') dataset = load_as_dataset(fpath) - + # Create project and add to experimental data project = ProjectData(name='MyAnalysis') project.exp_data.append(dataset) - + # Verify workflow assert len(project.exp_data) == 1 assert project.exp_data[0] == dataset @@ -342,11 +321,11 @@ def test_complete_workflow_txt_file(self): # Load file fpath = os.path.join(PATH_STATIC, 'ref_concat_1.txt') dataset = load_as_dataset(fpath) - + # Create project and add to simulation data (no model) project = ProjectData(name='MySimulation') project.sim_data.append(dataset) - + # Verify workflow assert len(project.sim_data) == 1 assert project.sim_data[0] == dataset @@ -357,13 +336,13 @@ def test_merge_multiple_files_workflow(self): # Load multiple files fpath1 = os.path.join(PATH_STATIC, 'test_example1.txt') fpath2 = os.path.join(PATH_STATIC, 'ref_concat_1.txt') - + group1 = load(fpath1) group2 = load(fpath2) - + # Merge data groups merged = merge_datagroups(group1, group2) - + # Create datasets from merged data # This tests that merged data can be used to create datasets assert len(merged['data']) >= 2 @@ -374,13 +353,13 @@ def test_error_handling_robustness(self): # Test mismatched array lengths with pytest.raises(ValueError, match='x and y must be the same length'): DataSet1D(x=[1, 2, 3], y=[4, 5]) - + # Test empty DataStore operations empty_store = DataStore() assert len(empty_store) == 0 
assert len(empty_store.experiments) == 0 assert len(empty_store.simulations) == 0 - + # Test file not found with pytest.raises(FileNotFoundError): _load_txt('nonexistent_file.txt') @@ -391,14 +370,14 @@ def test_data_consistency_checks(self): original_x = [1, 2, 3, 4] original_y = [10, 20, 30, 40] dataset = DataSet1D(x=original_x, y=original_y) - + # Store in datastore store = DataStore(dataset) - + # Add to project project = ProjectData() project.sim_data = store - + # Verify data consistency retrieved_dataset = project.sim_data[0] assert_array_equal(retrieved_dataset.x, np.array(original_x)) @@ -407,4 +386,4 @@ def test_data_consistency_checks(self): if __name__ == '__main__': # Run all tests if script is executed directly - pytest.main([__file__, '-v']) \ No newline at end of file + pytest.main([__file__, '-v']) diff --git a/tests/test_orso_utils.py b/tests/test_orso_utils.py index 89dd07db..3dd2ce50 100644 --- a/tests/test_orso_utils.py +++ b/tests/test_orso_utils.py @@ -18,14 +18,14 @@ @pytest.fixture def orso_data(): """Load the test ORSO data from Ni_example.ort.""" - return orso.load_orso(os.path.join(PATH_STATIC, "Ni_example.ort")) + return orso.load_orso(os.path.join(PATH_STATIC, 'Ni_example.ort')) def test_load_orso_model(orso_data): """Test loading a model from ORSO data.""" sample = load_orso_model(orso_data) assert sample is not None - assert sample.name == "Ni on Si" # Based on the file + assert sample.name == 'Ni on Si' # Based on the file def test_load_orso_data(orso_data): @@ -33,7 +33,7 @@ def test_load_orso_data(orso_data): data = load_orso_data(orso_data) assert data is not None # Check structure, e.g., has R_0 in data - assert "R_0" in data["data"] + assert 'R_0' in data['data'] def test_LoadOrso(orso_data): @@ -46,8 +46,9 @@ def test_LoadOrso(orso_data): def test_load_data_from_orso_file(): """Test loading data from ORSO file.""" - data = load_data_from_orso_file(os.path.join(PATH_STATIC, "Ni_example.ort")) + data = load_data_from_orso_file(os.path.join(PATH_STATIC, 'Ni_example.ort')) assert data is not None # Check it's a sc.DataGroup import scipp as sc - assert isinstance(data, sc.DataGroup) \ No newline at end of file + + assert isinstance(data, sc.DataGroup) diff --git a/tests/test_ort_file.py b/tests/test_ort_file.py index ff66c743..fe40e748 100644 --- a/tests/test_ort_file.py +++ b/tests/test_ort_file.py @@ -24,59 +24,57 @@ def make_pooch(base_url: str, registry: dict[str, str | None]) -> pooch.Pooch: """Make a Pooch object to download test data.""" return pooch.create( - path=pooch.os_cache("data"), - env="POOCH_DIR", + path=pooch.os_cache('data'), + env='POOCH_DIR', base_url=base_url, registry=registry, ) -@pytest.fixture(scope="module") +@pytest.fixture(scope='module') def data_registry(): return make_pooch( - base_url="https://pub-6c25ef91903d4301a3338bd53b370098.r2.dev", + base_url='https://pub-6c25ef91903d4301a3338bd53b370098.r2.dev', registry={ - "amor_reduced_iofq.ort": None, + 'amor_reduced_iofq.ort': None, }, ) -@pytest.fixture(scope="module") +@pytest.fixture(scope='module') def load_data(data_registry): - path = data_registry.fetch("amor_reduced_iofq.ort") - logging.info("Loading data from %s", path) + path = data_registry.fetch('amor_reduced_iofq.ort') + logging.info('Loading data from %s', path) data = load(path) return data -@pytest.fixture(scope="module") +@pytest.fixture(scope='module') def fit_model(load_data): data = load_data # Rescale data - reflectivity = data["data"]["R_0"].values + reflectivity = data['data']['R_0'].values scale_factor = 
1 / np.max(reflectivity) - data["data"]["R_0"].values *= scale_factor + data['data']['R_0'].values *= scale_factor # Create a model for the sample - si = Material(sld=2.07, isld=0.0, name="Si") - sio2 = Material(sld=3.47, isld=0.0, name="SiO2") - d2o = Material(sld=6.33, isld=0.0, name="D2O") - dlipids = Material(sld=5.0, isld=0.0, name="DLipids") + si = Material(sld=2.07, isld=0.0, name='Si') + sio2 = Material(sld=3.47, isld=0.0, name='SiO2') + d2o = Material(sld=6.33, isld=0.0, name='D2O') + dlipids = Material(sld=5.0, isld=0.0, name='DLipids') - superphase = Layer(material=si, thickness=0, roughness=0, name="Si superphase") - sio2_layer = Layer(material=sio2, thickness=20, roughness=4, name="SiO2 layer") - dlipids_layer = Layer( - material=dlipids, thickness=40, roughness=4, name="DLipids layer" - ) - subphase = Layer(material=d2o, thickness=0, roughness=5, name="D2O subphase") + superphase = Layer(material=si, thickness=0, roughness=0, name='Si superphase') + sio2_layer = Layer(material=sio2, thickness=20, roughness=4, name='SiO2 layer') + dlipids_layer = Layer(material=dlipids, thickness=40, roughness=4, name='DLipids layer') + subphase = Layer(material=d2o, thickness=0, roughness=5, name='D2O subphase') multi_sample = Sample( Multilayer(superphase), Multilayer(sio2_layer), Multilayer(dlipids_layer), Multilayer(subphase), - name="Multilayer Structure", + name='Multilayer Structure', ) multi_layer_model = Model( @@ -84,7 +82,7 @@ def fit_model(load_data): scale=1, background=0.000001, resolution_function=PercentageFwhm(0), - name="Multilayer Model", + name='Multilayer Model', ) # Set the fitting parameters @@ -122,66 +120,60 @@ def fit_model(load_data): def test_read_reduced_data__check_structure(load_data): - data_keys = load_data["data"].keys() - coord_keys = load_data["coords"].keys() + data_keys = load_data['data'].keys() + coord_keys = load_data['coords'].keys() for key in data_keys: if key in coord_keys: - assert len(load_data["data"][key].values) == len( - load_data["coords"][key].values - ) + assert len(load_data['data'][key].values) == len(load_data['coords'][key].values) def test_validate_physical_data__r_values_non_negative(load_data): - for key in load_data["data"].keys(): - assert all(load_data["data"][key].values >= 0) + for key in load_data['data'].keys(): + assert all(load_data['data'][key].values >= 0) def test_validate_physical_data__r_values_finite(load_data): - for key in load_data["data"].keys(): - assert all(np.isfinite(load_data["data"][key].values)) + for key in load_data['data'].keys(): + assert all(np.isfinite(load_data['data'][key].values)) -@pytest.mark.skip("Currently no warning implemented") +@pytest.mark.skip('Currently no warning implemented') def test_validate_physical_data__r_values_ureal_positive(load_data): - a = load_data["data"]["R_0"].values - b = 1 + 2 * np.sqrt(load_data["data"]["R_0"].variances) + a = load_data['data']['R_0'].values + b = 1 + 2 * np.sqrt(load_data['data']['R_0'].variances) for val_a, val_b in zip(a, b): if val_a > val_b: pytest.warns( - UserWarning, - reason=f"Reflectivity value {val_a} is unphysically large compared to its uncertainty {val_b}" + UserWarning, reason=f'Reflectivity value {val_a} is unphysically large compared to its uncertainty {val_b}' ) - assert all( - load_data["data"]["R_0"].values - <= 1 + 2 * np.sqrt(load_data["data"]["R_0"].variances) - ) + assert all(load_data['data']['R_0'].values <= 1 + 2 * np.sqrt(load_data['data']['R_0'].variances)) def test_validate_physical_data__q_values_non_negative(load_data): - 
for key in load_data["coords"].keys(): - assert all(load_data["coords"][key].values >= 0) + for key in load_data['coords'].keys(): + assert all(load_data['coords'][key].values >= 0) def test_validate_physical_data__q_values_ureal_positive(load_data): - for key in load_data["coords"].keys(): + for key in load_data['coords'].keys(): # Reflectometry data is usually with the range of 0-5, # so 10 is a safe upper limit - assert all(load_data["coords"][key].values < 10) + assert all(load_data['coords'][key].values < 10) def test_validate_physical_data__q_values_finite(load_data): - for key in load_data["coords"].keys(): - assert all(np.isfinite(load_data["coords"][key].values < 10)) + for key in load_data['coords'].keys(): + assert all(np.isfinite(load_data['coords'][key].values < 10)) -@pytest.mark.skip("Currently no meta data to check") +@pytest.mark.skip('Currently no meta data to check') def test_validate_meta_data__required_meta_data() -> None: - pytest.fail(reason="Currently no meta data to check") + pytest.fail(reason='Currently no meta data to check') def test_analyze_reduced_data__fit_model_success(fit_model): - assert fit_model["success"] is True + assert fit_model['success'] is True def test_analyze_reduced_data__fit_model_reasonable(fit_model): - assert fit_model["reduced_chi"] < 0.01 + assert fit_model['reduced_chi'] < 0.01 diff --git a/tests/test_project.py b/tests/test_project.py index b86210ec..587ad601 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -17,8 +17,11 @@ from easyreflectometry.model import PercentageFwhm from easyreflectometry.model import Pointwise from easyreflectometry.project import Project +from easyreflectometry.sample import Layer from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection +from easyreflectometry.sample import Multilayer +from easyreflectometry.sample import Sample PATH_STATIC = os.path.join(os.path.dirname(easyreflectometry.__file__), '..', '..', 'tests', '_static') @@ -576,6 +579,7 @@ def test_create(self, tmp_path): def test_load_experiment(self): # When + global_object.map._clear() project = Project() model_5 = Model() project.models = ModelCollection(Model(), Model(), Model(), Model(), Model(), model_5) @@ -594,6 +598,7 @@ def test_load_experiment(self): def test_experimental_data_at_index(self): # When + global_object.map._clear() project = Project() project.models = ModelCollection(Model()) fpath = os.path.join(PATH_STATIC, 'example.ort') @@ -613,6 +618,7 @@ def test_experimental_data_at_index(self): def test_q(self): # When + global_object.map._clear() project = Project() # Then @@ -700,3 +706,127 @@ def test_current_experiment_index_setter_out_of_range(self): assert False, 'Expected ValueError for out-of-range index' except ValueError: pass + + def test_get_materials_from_model(self): + # When + project = Project() + material_1 = Material(sld=2.07, isld=0.0, name='Material 1') + material_2 = Material(sld=3.47, isld=0.0, name='Material 2') + material_3 = Material(sld=6.36, isld=0.0, name='Material 3') + + layer_1 = Layer(material=material_1, thickness=10, roughness=0, name='Layer 1') + layer_2 = Layer(material=material_2, thickness=20, roughness=1, name='Layer 2') + layer_3 = Layer(material=material_3, thickness=0, roughness=2, name='Layer 3') + + sample = Sample(Multilayer([layer_1, layer_2]), Multilayer([layer_3])) + model = Model(sample=sample) + + # Then + materials = project._get_materials_from_model(model) + + # Expect + assert len(materials) == 3 + assert materials[0] == 
material_1 + assert materials[1] == material_2 + assert materials[2] == material_3 + + def test_get_materials_from_model_duplicate_materials(self): + # When + project = Project() + # Use the same material in multiple layers + shared_material = Material(sld=2.07, isld=0.0, name='Shared Material') + material_2 = Material(sld=3.47, isld=0.0, name='Material 2') + + layer_1 = Layer(material=shared_material, thickness=10, roughness=0, name='Layer 1') + layer_2 = Layer(material=material_2, thickness=20, roughness=1, name='Layer 2') + layer_3 = Layer(material=shared_material, thickness=30, roughness=2, name='Layer 3') + + sample = Sample(Multilayer([layer_1, layer_2, layer_3])) + model = Model(sample=sample) + + # Then + materials = project._get_materials_from_model(model) + + # Expect - should only include unique materials + assert len(materials) == 2 + assert materials[0] == shared_material + assert materials[1] == material_2 + + def test_add_sample_from_orso(self): + # When + global_object.map._clear() + project = Project() + project.default_model() + + initial_model_count = len(project._models) + initial_material_count = len(project._materials) + + material_1 = Material(sld=4.0, isld=0.0, name='New Material 1') + material_2 = Material(sld=5.0, isld=0.0, name='New Material 2') + layer_1 = Layer(material=material_1, thickness=50, roughness=1, name='New Layer 1') + layer_2 = Layer(material=material_2, thickness=100, roughness=2, name='New Layer 2') + new_sample = Sample(Multilayer([layer_1, layer_2])) + + # Then + project.add_sample_from_orso(new_sample) + + # Expect + assert len(project._models) == initial_model_count + 1 + assert project._models[-1].sample == new_sample + # The interface should be set by add_sample_from_orso + assert project._models[-1].interface == project._calculator + assert len(project._materials) == initial_material_count + 2 + assert material_1 in project._materials + assert material_2 in project._materials + assert project.current_model_index == len(project._models) - 1 + + def test_add_sample_from_orso_multiple_additions(self): + # When + global_object.map._clear() + project = Project() + + material_1 = Material(sld=2.0, isld=0.0, name='Material A') + layer_1 = Layer(material=material_1, thickness=10, roughness=0, name='Layer A') + sample_1 = Sample(Multilayer([layer_1])) + + material_2 = Material(sld=3.0, isld=0.0, name='Material B') + layer_2 = Layer(material=material_2, thickness=20, roughness=1, name='Layer B') + sample_2 = Sample(Multilayer([layer_2])) + + # Then + project.add_sample_from_orso(sample_1) + project.add_sample_from_orso(sample_2) + + # Expect + assert len(project._models) == 2 + assert project._models[0].sample == sample_1 + assert project._models[1].sample == sample_2 + assert len(project._materials) == 2 + assert material_1 in project._materials + assert material_2 in project._materials + assert project.current_model_index == 1 + + def test_add_sample_from_orso_with_shared_materials(self): + # When + global_object.map._clear() + project = Project() + + # Create first sample with a material + shared_material = Material(sld=2.0, isld=0.0, name='Shared Material') + layer_1 = Layer(material=shared_material, thickness=10, roughness=0, name='Layer 1') + sample_1 = Sample(Multilayer([layer_1])) + project.add_sample_from_orso(sample_1) + + initial_material_count = len(project._materials) + + # Create second sample using the same material + layer_2 = Layer(material=shared_material, thickness=20, roughness=1, name='Layer 2') + sample_2 = 
Sample(Multilayer([layer_2])) + + # Then + project.add_sample_from_orso(sample_2) + + # Expect - shared material should not be duplicated + assert len(project._models) == 2 + # The shared material instance is already in the collection, so count should stay the same + assert len(project._materials) == initial_material_count
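
Note on the new Project ORSO helpers in src/easyreflectometry/project.py: set_sample_from_orso replaces the whole model collection with a single model, while add_sample_from_orso appends a model, wires it to the project's current calculator interface, merges the sample's materials (without duplicates) into the project's material collection, and makes the new model current. A minimal sketch of the intended call sequence, assuming an .ort file on disk (the bare path here is illustrative; the tests resolve Ni_example.ort via PATH_STATIC) and the same orsopy loader used in tests/test_orso_utils.py:

    from orsopy.fileio import orso

    from easyreflectometry.orso_utils import load_orso_model
    from easyreflectometry.project import Project

    project = Project()

    # Parse the ORSO header into an easyreflectometry Sample ...
    orso_data = orso.load_orso('Ni_example.ort')  # illustrative path
    sample = load_orso_model(orso_data)

    # ... and append it as a new model; the calculator interface is set and the
    # sample's materials are merged into the project's material collection.
    project.add_sample_from_orso(sample)
    assert project.current_model_index == len(project.models) - 1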
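
Relatedly, LoadOrso in orso_utils returns the sample and the reflectivity data in one call, and load_data_from_orso_file wraps loading straight from a path into a scipp DataGroup (raising ValueError on a malformed file). A sketch under the same assumptions as above:

    import scipp as sc

    from easyreflectometry.orso_utils import LoadOrso, load_data_from_orso_file

    # One-shot: Sample plus reflectivity data from an already-parsed ORSO object.
    sample, data = LoadOrso(orso_data)

    # Or straight from a file path.
    data_group = load_data_from_orso_file('Ni_example.ort')  # illustrative path
    assert isinstance(data_group, sc.DataGroup)
    assert 'R_0' in data_group['data']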
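
Note on the fitting.py hunk: MultiFitter.fit now masks data points with zero variance instead of letting them enter the fit, and emits a UserWarning stating the reflectivity index and the number of masked points. A sketch of how a caller can surface that warning, mirroring the pattern in tests/test_fitting.py (here `fitter` and `data` stand for an already-configured MultiFitter and a loaded DataGroup):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        analysed = fitter.fit(data)

    # e.g. "Masked 7 data point(s) in reflectivity 0 due to zero variance during fitting."
    for w in caught:
        if 'zero variance during fitting' in str(w.message):
            print(w.message)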