Skip to content

Commit

Permalink
Add fail_if_symlink to fns.io functions (openvinotoolkit#3150)
Browse files Browse the repository at this point in the history
### Changes 

- Add `fail_if_symlink` check to check symbolic links before load
statistics files
- Remove try block for `fns.io.save_file` in dump_statistics

### Reason for changes

Prevent problems with symlinks

---------

Co-authored-by: Alexander Dokuchaev <[email protected]>
  • Loading branch information
kshpv and AlexanderDokuchaev authored Dec 16, 2024
1 parent 5a55a7d commit c479989
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
6 changes: 1 addition & 5 deletions nncf/common/tensor_statistics/statistics_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,7 @@ def dump_statistics(

# Update the mapping
metadata["mapping"][unique_sanitized_name] = original_name

try:
fns.io.save_file(statistics_value, file_path)
except Exception as e:
raise nncf.InternalError(f"Failed to write data to file {file_path}: {e}")
fns.io.save_file(statistics_value, file_path)

if additional_metadata:
metadata |= additional_metadata
Expand Down
3 changes: 3 additions & 0 deletions nncf/tensor/functions/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pathlib import Path
from typing import Dict, Optional

from nncf.common.utils.os import fail_if_symlink
from nncf.tensor import Tensor
from nncf.tensor.definitions import TensorBackend
from nncf.tensor.definitions import TensorDeviceType
Expand All @@ -35,6 +36,7 @@ def load_file(
then the default device is determined by backend.
:return: A dictionary where the keys are tensor names and the values are Tensor objects.
"""
fail_if_symlink(file_path)
loaded_dict = get_io_backend_fn("load_file", backend)(file_path, device=device)
return {key: Tensor(val) for key, val in loaded_dict.items()}

Expand All @@ -50,6 +52,7 @@ def save_file(
:param data: A dictionary where the keys are tensor names and the values are Tensor objects.
:param file_path: The path to the file where the tensor data will be saved.
"""
fail_if_symlink(file_path)
if isinstance(data, dict):
return dispatch_dict(save_file, data, file_path)
raise NotImplementedError(f"Function `save_file` is not implemented for {type(data)}")
15 changes: 15 additions & 0 deletions tests/cross_fw/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import pytest

import nncf
import nncf.tensor.functions as fns
from nncf.experimental.common.tensor_statistics import statistical_functions as s_fns
from nncf.tensor import Tensor
Expand Down Expand Up @@ -1709,6 +1710,20 @@ def test_save_load_file(self, tmp_path, data):
assert loaded_stat[tensor_key].device == tensor.device
assert loaded_stat[tensor_key].dtype == tensor.dtype

def test_save_load_symlink_error(self, tmp_path):
file_path = tmp_path / "test_tensor"
symlink_path = tmp_path / "symlink_test_tensor"
symlink_path.symlink_to(file_path)

tensor_key = "tensor_key"
tensor = Tensor(self.to_tensor([1, 2]))
stat = {tensor_key: tensor}

with pytest.raises(nncf.ValidationError, match="symbolic link"):
fns.io.save_file(stat, symlink_path)
with pytest.raises(nncf.ValidationError, match="symbolic link"):
fns.io.load_file(symlink_path, backend=tensor.backend)

@pytest.mark.parametrize("data", [[3.0, 2.0, 2.0], [1, 2, 3]])
@pytest.mark.parametrize("dtype", [TensorDataType.float32, TensorDataType.int32, TensorDataType.uint8, None])
def test_fn_tensor(self, data, dtype):
Expand Down

0 comments on commit c479989

Please sign in to comment.