Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python: add OverwriteReason field to MagikaPrediction #895

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ semver guidelines for more details about this.
- Upgrade model from `standard_v2_1` to `standard_v3_0`. This should result in a 3x faster inference speed, with the same overall accuracy. This new model should also be ~20% faster than `standard_v1`.
- New API: `get_output_content_types()`. This API returns the list of all possible outputs by the module. I.e., all possible values for `MagikaResult.prediction.output.label`. This is the list that is relevant for most clients.
- New API: `get_model_content_types()`. This API returns the list of all possible outputs of the deep learning model. I.e., all possible values for `MagikaResult.prediction.dl.label`. Note that, in general, the list of "model outputs" is different than the "tool outputs" as in some cases the model is not even used, or the model's output is overwritten due to a low-confidence score, or other reasons. This API is useful mostly for debugging purposes; the vast majority of client should use `get_output_content_types()`.
- `MagikaPrediction` now has an `overwrite_reason` field, specifying why and if the model's prediction was overwritten.

## [0.6.0-rc3] - 2024-11-20

Expand Down
5 changes: 4 additions & 1 deletion python/scripts/magika_python_module_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from magika import Magika, MagikaError, PredictionMode, colors
from magika.logger import get_logger
from magika.types import ContentTypeLabel, MagikaResult
from magika.types.overwrite_reason import OverwriteReason

VERSION = importlib.metadata.version("magika")

Expand Down Expand Up @@ -291,8 +292,10 @@ def main(
result.prediction.dl.label != ContentTypeLabel.UNDEFINED
and result.prediction.dl.label
!= result.prediction.output.label
and result.prediction.overwrite_reason
== OverwriteReason.NONE
):
# it seems that we had a too-low confidence prediction
# It seems that we had a low-confidence prediction
# from the model. Let's warn the user about our best
# bet.
output += (
Expand Down
40 changes: 26 additions & 14 deletions python/src/magika/magika.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ModelConfig,
ModelFeatures,
ModelOutput,
OverwriteReason,
PredictionMode,
Status,
)
Expand Down Expand Up @@ -575,15 +576,18 @@ def _get_results_from_features(
# both the raw DL model output and the final output we return to
# the user.

output_ct_label = self._get_output_ct_label_from_dl_result(
model_output.ct_label, model_output.score
output_ct_label, overwrite_reason = (
self._get_output_ct_label_from_dl_result(
model_output.ct_label, model_output.score
)
)

results[str(path)] = self._get_result_from_labels_and_score(
path=path,
dl_ct_label=model_output.ct_label,
output_ct_label=output_ct_label,
score=model_output.score,
overwrite_reason=overwrite_reason,
)

return results
Expand All @@ -600,13 +604,18 @@ def _get_result_from_features(

def _get_output_ct_label_from_dl_result(
self, dl_ct_label: ContentTypeLabel, score: float
) -> ContentTypeLabel:
# overwrite ct_label if specified in the config
dl_ct_label = self._model_config.overwrite_map.get(dl_ct_label, dl_ct_label)
) -> Tuple[ContentTypeLabel, OverwriteReason]:
overwrite_reason = OverwriteReason.NONE

# Overwrite dl_ct_label if specified in the overwrite_map model config
output_ct_label = self._model_config.overwrite_map.get(dl_ct_label, dl_ct_label)
if output_ct_label != dl_ct_label:
overwrite_reason = OverwriteReason.OVERWRITE_MAP

if self._prediction_mode == PredictionMode.BEST_GUESS:
# We take the model predictions, no matter what the score is.
output_ct_label = dl_ct_label
# We take the (potentially overwritten) model prediction, no matter
# what the score is.
pass
elif (
self._prediction_mode == PredictionMode.HIGH_CONFIDENCE
and score
Expand All @@ -615,41 +624,44 @@ def _get_output_ct_label_from_dl_result(
)
):
# The model score is higher than the per-content-type
# high-confidence threshold.
output_ct_label = dl_ct_label
# high-confidence threshold, so we keep it.
pass
elif (
self._prediction_mode == PredictionMode.MEDIUM_CONFIDENCE
and score >= self._model_config.medium_confidence_threshold
):
# We take the model prediction only if the score is above a given
# relatively loose threshold.
output_ct_label = dl_ct_label
# The model score is higher than the generic medium-confidence
# threshold, so we keep it.
pass
else:
# We are not in a condition to trust the model, we opt to return
# generic labels. Note that here we use an implicit assumption that
# the model has, at the very least, got the binary vs. text category
# right. This allows us to pick between unknown and txt without the
# need to read or scan the file bytes once again.
if self._get_ct_info(dl_ct_label).is_text:
overwrite_reason = OverwriteReason.LOW_CONFIDENCE
if self._get_ct_info(output_ct_label).is_text:
output_ct_label = ContentTypeLabel.TXT
else:
output_ct_label = ContentTypeLabel.UNKNOWN

return output_ct_label
return output_ct_label, overwrite_reason

def _get_result_from_labels_and_score(
self,
path: Path,
dl_ct_label: ContentTypeLabel,
output_ct_label: ContentTypeLabel,
score: float,
overwrite_reason: OverwriteReason = OverwriteReason.NONE,
) -> MagikaResult:
return MagikaResult(
path=path,
prediction=MagikaPrediction(
dl=self._get_ct_info(dl_ct_label),
output=self._get_ct_info(output_ct_label),
score=score,
overwrite_reason=overwrite_reason,
),
)

Expand Down
2 changes: 2 additions & 0 deletions python/src/magika/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ModelFeatures,
ModelOutput,
)
from magika.types.overwrite_reason import OverwriteReason # noqa: F401
from magika.types.prediction_mode import PredictionMode # noqa: F401
from magika.types.status import Status # noqa: F401

Expand All @@ -35,6 +36,7 @@
"ModelConfig",
"ModelFeatures",
"ModelOutput",
"OverwriteReason",
"PredictionMode",
"Status",
]
2 changes: 2 additions & 0 deletions python/src/magika/types/magika_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from dataclasses import dataclass

from magika.types.content_type_info import ContentTypeInfo
from magika.types.overwrite_reason import OverwriteReason


@dataclass(frozen=True)
class MagikaPrediction:
dl: ContentTypeInfo
output: ContentTypeInfo
score: float
overwrite_reason: OverwriteReason
24 changes: 24 additions & 0 deletions python/src/magika/types/overwrite_reason.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import enum

from magika.types.strenum import LowerCaseStrEnum


class OverwriteReason(LowerCaseStrEnum):
NONE = enum.auto()
LOW_CONFIDENCE = enum.auto()
OVERWRITE_MAP = enum.auto()
112 changes: 49 additions & 63 deletions python/tests/test_magika_python_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
MagikaResult,
Status,
)
from magika.types.overwrite_reason import OverwriteReason
from tests import utils


Expand Down Expand Up @@ -185,92 +186,77 @@ def test_magika_module_with_python_and_non_python_content() -> None:
def test_magika_module_with_different_prediction_modes() -> None:
model_dir = utils.get_default_model_dir()
m = Magika(model_dir=model_dir, prediction_mode=PredictionMode.BEST_GUESS)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.01)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.01) == (
ContentTypeLabel.PYTHON,
OverwriteReason.NONE,
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.40)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.40) == (
ContentTypeLabel.PYTHON,
OverwriteReason.NONE,
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.60)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.60) == (
ContentTypeLabel.PYTHON,
OverwriteReason.NONE,
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.99)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.99) == (
ContentTypeLabel.PYTHON,
OverwriteReason.NONE,
)

m = Magika(model_dir=model_dir, prediction_mode=PredictionMode.MEDIUM_CONFIDENCE)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.01)
== ContentTypeLabel.TXT
)
assert (
m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, m._model_config.medium_confidence_threshold - 0.01
)
== ContentTypeLabel.TXT
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.01) == (
ContentTypeLabel.TXT,
OverwriteReason.LOW_CONFIDENCE,
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.60)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, m._model_config.medium_confidence_threshold - 0.01
) == (ContentTypeLabel.TXT, OverwriteReason.LOW_CONFIDENCE)
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.60) == (
ContentTypeLabel.PYTHON,
OverwriteReason.NONE,
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.99)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.99) == (
ContentTypeLabel.PYTHON,
OverwriteReason.NONE,
)

m = Magika(model_dir=model_dir, prediction_mode=PredictionMode.HIGH_CONFIDENCE)
high_confidence_threshold = m._model_config.thresholds.get(
ContentTypeLabel.PYTHON, m._model_config.medium_confidence_threshold
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.01)
== ContentTypeLabel.TXT
)
assert (
m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, high_confidence_threshold - 0.01
)
== ContentTypeLabel.TXT
)
assert (
m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, high_confidence_threshold + 0.01
)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.01) == (
ContentTypeLabel.TXT,
OverwriteReason.LOW_CONFIDENCE,
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.99)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, high_confidence_threshold - 0.01
) == (ContentTypeLabel.TXT, OverwriteReason.LOW_CONFIDENCE)
assert m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, high_confidence_threshold + 0.01
) == (ContentTypeLabel.PYTHON, OverwriteReason.NONE)
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.99) == (
ContentTypeLabel.PYTHON,
OverwriteReason.NONE,
)

# test that the default is HIGH_CONFIDENCE
m = Magika(model_dir=model_dir)
high_confidence_threshold = m._model_config.thresholds.get(
ContentTypeLabel.PYTHON, m._model_config.medium_confidence_threshold
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.01)
== ContentTypeLabel.TXT
)
assert (
m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, high_confidence_threshold - 0.01
)
== ContentTypeLabel.TXT
)
assert (
m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, high_confidence_threshold + 0.01
)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.01) == (
ContentTypeLabel.TXT,
OverwriteReason.LOW_CONFIDENCE,
)
assert (
m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.99)
== ContentTypeLabel.PYTHON
assert m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, high_confidence_threshold - 0.01
) == (ContentTypeLabel.TXT, OverwriteReason.LOW_CONFIDENCE)
assert m._get_output_ct_label_from_dl_result(
ContentTypeLabel.PYTHON, high_confidence_threshold + 0.01
) == (ContentTypeLabel.PYTHON, OverwriteReason.NONE)
assert m._get_output_ct_label_from_dl_result(ContentTypeLabel.PYTHON, 0.99) == (
ContentTypeLabel.PYTHON,
OverwriteReason.NONE,
)


Expand Down
Loading