Skip to content

Commit

Permalink
Merge pull request stanford-crfm#1282 from stanford-crfm/yifanmai/944…
Browse files Browse the repository at this point in the history
…-copy-cache-script

Add copy_cache script
  • Loading branch information
teetone authored Jan 2, 2023
2 parents 84ed754 + 231daa4 commit b356bdb
Showing 1 changed file with 141 additions and 0 deletions.
141 changes: 141 additions & 0 deletions scripts/cache/copy_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Utility for copying caches from SQLite to MongoDB.
Example usage:
python3 scripts/cache/copy_cache.py --organization openai \
prod_env/cache/ mongodb://username:password@mongodbhost/crfm-models
python3 scripts/cache/copy_cache.py --all \
prod_env/cache/ mongodb://username:password@mongodbhost/crfm-models
"""

import argparse
import json
import os

from sqlitedict import SqliteDict
from helm.common.cache import _MongoKeyValueStore
from helm.common.hierarchical_logger import hlog, htrack
from typing import Optional


_SQLITE_FILE_SUFFIX = ".sqlite"


@htrack("Copying all caches")
def copy_all_caches(cache_dir: str, mongo_host: str, dry_run: bool):
hlog(f"Opening Sqlite dir {cache_dir}")
with os.scandir(cache_dir) as it:
for entry in it:
if entry.name.endswith(_SQLITE_FILE_SUFFIX) and entry.is_file():
organization = entry.name[: -len(_SQLITE_FILE_SUFFIX)]
copy_cache(
cache_dir=cache_dir,
mongo_host=mongo_host,
organization=organization,
dry_run=dry_run,
)


@htrack("Copying single cache")
def copy_cache(
cache_dir: str,
mongo_host: str,
organization: str,
dry_run: bool,
range_start: Optional[int] = None,
range_end: Optional[int] = None,
):
if dry_run:
hlog("Dry run mode, skipping writing to mongo")
if range_start:
hlog(f"Start of range: {range_start}")
if range_end:
hlog(f"End of range: {range_end}")
num_items = 0
num_written = 0
num_skipped = 0
num_failed = 0
cache_path = os.path.join(cache_dir, f"{organization}.sqlite")
hlog(f"Opening Sqlite cache {cache_path}")
with SqliteDict(cache_path) as source_cache:
hlog(f"Copying to MongoDB {mongo_host}")
with _MongoKeyValueStore(mongo_host, collection_name=organization) as target_cache:
for key, value in source_cache.items():
if not dry_run and (not range_start or num_items >= range_start):
try:
target_cache.put(json.loads(key), value)
num_written += 1
except Exception:
num_failed += 1
else:
num_skipped += 1
num_items += 1
if num_items % 1000 == 0:
hlog(f"Processed {num_items} items so far")
hlog(
f"Copied {num_written} and skipped {num_skipped} and "
+ f"failed {num_failed} items from {cache_path} so far"
)
if range_end and num_items >= range_end:
break

hlog(f"Processed {num_items} total items from {cache_path}")
hlog(
f"Copied {num_written} and skipped {num_skipped} and failed "
+ f"{num_failed} total items from {cache_path}"
)
hlog(f"Finished copying Sqlite cache {cache_path} to MongoDB {mongo_host}")


def main():
parser = argparse.ArgumentParser(description="Copy items from Sqlite to mongo")
parser.add_argument("cache_dir", type=str, help="Directory for the .sqlite files")
parser.add_argument("mongo_host", type=str, help="Mongo host to copy items to")
parser.add_argument("--organization", type=str, help="Organization to copy cache for")
parser.add_argument("--range-start", type=int, help="The start of the range to copy")
parser.add_argument("--range-end", type=int, help="The end of the range to copy (exclusive)")
parser.add_argument(
"--all",
action="store_true",
default=None,
help="Copy caches for all organizations",
)
parser.add_argument(
"--bulk-write",
action="store_true",
default=None,
help="Uses bulk writes",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=None,
help="Skips actually writing to mongo",
)
args = parser.parse_args()

if (args.range_start or args.range_end) and not args.organization:
raise ValueError("--range_start and --range_end require --organization to be specified")

if args.all:
copy_all_caches(
cache_dir=args.cache_dir,
mongo_host=args.mongo_host,
dry_run=bool(args.dry_run),
)
elif args.organization:
copy_cache(
cache_dir=args.cache_dir,
mongo_host=args.mongo_host,
organization=args.organization,
dry_run=bool(args.dry_run),
range_start=args.range_start,
range_end=args.range_end,
)
else:
raise ValueError("Either --all or --organization must be specified")


if __name__ == "__main__":
main()

0 comments on commit b356bdb

Please sign in to comment.