Skip to content

Commit

Permalink
fix: detect incompatible case_sensitive+preserve_case instances
Browse files Browse the repository at this point in the history
  • Loading branch information
joanise committed Nov 15, 2023
1 parent 3811a9e commit d768d74
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
13 changes: 11 additions & 2 deletions g2p/mappings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ class _MappingModelDefinition(BaseModel):
"""List of case equivalencies for preserve_case that are not already in the Unicode standard"""

preserve_case: bool = False
"""Preserve source case in output"""
"""Preserve source case in output (requires case_sensitive=False)"""

escape_special: bool = False
"""Escape special characters in rules"""
Expand Down Expand Up @@ -731,7 +731,7 @@ def check_mapping_types(self) -> "_MappingModelDefinition":
and not self.rules
and self.rules_path is None
):
LOGGER.warn(
LOGGER.warning(
exceptions.MalformedMapping(
"You have to either specify some rules or a path to a file containing rules."
)
Expand Down Expand Up @@ -773,6 +773,15 @@ def validate_case_equivalencies(cls, v):
)
return v

@model_validator(mode="after")
def validate_preserve_case(self):
"""preserve_case=True requires case_sensitive=False"""
if self.preserve_case and self.case_sensitive:
raise exceptions.MalformedMapping(
"Sorry, preserve_case=True requires case_sensitive=False."
)
return self

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("rules_path", "abbreviations_path", "alignments_path", pre=True)
Expand Down
3 changes: 3 additions & 0 deletions g2p/tests/test_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ def test_case_preservation(self):
case_equivalencies={"λ": "\u2144\u2144\u2144"},
)

with self.assertRaises(MalformedMapping):
_ = Mapping(rules=[], case_sensitive=True, preserve_case=True)

def test_normalize_edges(self):
# Remove non-deletion edges with the same index as deletions
bad_edges = [
Expand Down

0 comments on commit d768d74

Please sign in to comment.