Skip to content

Commit

Permalink
[cleanup] minor, moving a util from core xformers to LRA
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux authored Jun 3, 2022
1 parent 52d1dd0 commit 34addd8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 25 deletions.
2 changes: 1 addition & 1 deletion xformers/benchmarks/LRA/run_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

from xformers.benchmarks.LRA.code.dataset import LRADataset
from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual
from xformers.benchmarks.utils import temp_files_ctx
from xformers.components.attention import ATTENTION_REGISTRY
from xformers.utils import temp_files_ctx


class Task(str, Enum):
Expand Down
25 changes: 24 additions & 1 deletion xformers/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging
import os
import tempfile
from collections import namedtuple
from typing import Any, Dict, List
from typing import Any, Dict, Generator, List

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -179,3 +182,23 @@ def pretty_barplot(results, title, units: str, filename=None, dash_key=""):

plt.savefig(filename, bbox_inches="tight")
plt.close(f)


def rmf(filename: str) -> None:
"""Remove a file like rm -f."""
try:
os.remove(filename)
except FileNotFoundError:
pass


@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
"""A context to get tempfiles and ensure they are cleaned up."""
files = [tempfile.mkstemp()[1] for _ in range(num)]

yield tuple(files)

# temp files could have been removed, so we use rmf.
for name in files:
rmf(name)
24 changes: 1 addition & 23 deletions xformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
# LICENSE file in the root directory of this source tree.


import contextlib
import importlib
import os
import sys
import tempfile
from collections import namedtuple
from dataclasses import fields
from typing import Any, Callable, Dict, Generator, List
from typing import Any, Callable, Dict, List

Item = namedtuple("Item", ["constructor", "config"])

Expand Down Expand Up @@ -79,23 +77,3 @@ def generate_matching_config(superset: Dict[str, Any], config_class: Any) -> Any
subset[k] = None

return config_class(**subset)


def rmf(filename: str) -> None:
"""Remove a file like rm -f."""
try:
os.remove(filename)
except FileNotFoundError:
pass


@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
"""A context to get tempfiles and ensure they are cleaned up."""
files = [tempfile.mkstemp()[1] for _ in range(num)]

yield tuple(files)

# temp files could have been removed, so we use rmf.
for name in files:
rmf(name)

0 comments on commit 34addd8

Please sign in to comment.