diff --git a/CHANGELOG.md b/CHANGELOG.md index 76378e76b..125bf72f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ - "How to create STAC catalogs" tutorial ([#775](https://github.com/stac-utils/pystac/pull/775)) - Add a `variables` argument, to accompany `dimensions`, for the `apply` method of stac objects extended with datacube ([#782](https://github.com/stac-utils/pystac/pull/782)) +- Deepcopy collection properties on clone. Implement `clone` method for `Summaries` ([#794](https://github.com/stac-utils/pystac/pull/794)) ## [v1.4.0] diff --git a/pystac/catalog.py b/pystac/catalog.py index 95f268df7..c9ab7114d 100644 --- a/pystac/catalog.py +++ b/pystac/catalog.py @@ -517,7 +517,7 @@ def clone(self) -> "Catalog": id=self.id, description=self.description, title=self.title, - stac_extensions=self.stac_extensions, + stac_extensions=self.stac_extensions.copy(), extra_fields=deepcopy(self.extra_fields), catalog_type=self.catalog_type, ) diff --git a/pystac/collection.py b/pystac/collection.py index af99a4320..069367711 100644 --- a/pystac/collection.py +++ b/pystac/collection.py @@ -562,13 +562,13 @@ def clone(self) -> "Collection": description=self.description, extent=self.extent.clone(), title=self.title, - stac_extensions=self.stac_extensions, - extra_fields=self.extra_fields, + stac_extensions=self.stac_extensions.copy(), + extra_fields=deepcopy(self.extra_fields), catalog_type=self.catalog_type, license=self.license, - keywords=self.keywords, - providers=self.providers, - summaries=self.summaries, + keywords=self.keywords.copy() if self.keywords is not None else None, + providers=deepcopy(self.providers), + summaries=self.summaries.clone(), ) clone._resolved_objects.cache(clone) diff --git a/pystac/summaries.py b/pystac/summaries.py index 9652b4b7b..f7850628c 100644 --- a/pystac/summaries.py +++ b/pystac/summaries.py @@ -1,3 +1,4 @@ +from copy import deepcopy import sys import numbers from enum import Enum @@ -287,6 +288,21 @@ def is_empty(self) -> bool: any(self.lists) or any(self.ranges) or any(self.schemas) or any(self.other) ) + def clone(self) -> "Summaries": + """Clones this object. + + Returns: + Summaries: The clone of this object + """ + summaries = Summaries( + summaries=deepcopy(self._summaries), maxcount=self.maxcount + ) + summaries.lists = deepcopy(self.lists) + summaries.other = deepcopy(self.other) + summaries.ranges = deepcopy(self.ranges) + summaries.schemas = deepcopy(self.schemas) + return summaries + def to_dict(self) -> Dict[str, Any]: return { **{k: v for k, v in self.lists.items() if len(v) < self.maxcount}, diff --git a/tests/extensions/test_scientific.py b/tests/extensions/test_scientific.py index 0d881d2f7..9edf3c3b3 100644 --- a/tests/extensions/test_scientific.py +++ b/tests/extensions/test_scientific.py @@ -452,7 +452,7 @@ def test_set_doi_summaries(self) -> None: sci_summaries = ScientificExtension.summaries(collection) sci_summaries.doi = [PUB2_DOI] - new_dois = ScientificExtension.summaries(self.collection).doi + new_dois = ScientificExtension.summaries(collection).doi assert new_dois is not None self.assertListEqual([PUB2_DOI], new_dois) diff --git a/tests/test_collection.py b/tests/test_collection.py index 2a307c38f..61e4dc415 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -80,6 +80,19 @@ def test_clone_uses_previous_catalog_type(self) -> None: clone = catalog.clone() self.assertEqual(clone.catalog_type, CatalogType.SELF_CONTAINED) + def test_clone_cant_mutate_original(self) -> None: + collection = TestCases.test_case_8() + assert collection.keywords is not None + self.assertListEqual(collection.keywords, ["disaster", "open"]) + clone = collection.clone() + clone.extra_fields["test"] = "extra" + self.assertNotIn("test", collection.extra_fields) + assert clone.keywords is not None + clone.keywords.append("clone") + self.assertListEqual(clone.keywords, ["disaster", "open", "clone"]) + self.assertListEqual(collection.keywords, ["disaster", "open"]) + self.assertNotEqual(id(collection.summaries), id(clone.summaries)) + def test_multiple_extents(self) -> None: cat1 = TestCases.test_case_1() country = cat1.get_child("country-1") diff --git a/tests/test_summaries.py b/tests/test_summaries.py index c542da203..adb06daad 100644 --- a/tests/test_summaries.py +++ b/tests/test_summaries.py @@ -59,6 +59,17 @@ def test_summary_not_empty(self) -> None: summaries = Summarizer().summarize(coll.get_all_items()) self.assertFalse(summaries.is_empty()) + def test_clone_summary(self) -> None: + coll = TestCases.test_case_5() + summaries = Summarizer().summarize(coll.get_all_items()) + summaries_dict = summaries.to_dict() + self.assertEqual(len(summaries_dict["eo:bands"]), 4) + self.assertEqual(len(summaries_dict["proj:epsg"]), 1) + clone = summaries.clone() + self.assertTrue(isinstance(clone, Summaries)) + clone_dict = clone.to_dict() + self.assertDictEqual(clone_dict, summaries_dict) + class RangeSummaryTest(unittest.TestCase): def setUp(self) -> None: