Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include SQLite3 schemas #52

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 116 additions & 14 deletions files_to_prompt/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import sqlite3
from fnmatch import fnmatch

import click
Expand All @@ -21,6 +22,10 @@
"yml": "yaml",
"sh": "bash",
"rb": "ruby",
"sql": "sql",
"sqlite": "sql",
"sqlite3": "sql",
"db": "sql",
}


Expand Down Expand Up @@ -98,6 +103,58 @@ def print_as_markdown(writer, path, content, line_numbers):
writer(f"{backticks}")


def is_sqlite3_file(file_path):
    """Return True if *file_path* begins with the SQLite 3 database header.

    Only the first 16 bytes of the file are read. The header must match the
    complete magic string, including its terminating NUL byte, so a plain
    text file that merely starts with the words "SQLite format 3" is not
    mistaken for a database.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True when the file starts with the 16-byte SQLite 3 magic string;
        False otherwise, including when the file cannot be opened or read.
    """
    # The on-disk magic is exactly 16 bytes: the text plus a trailing NUL.
    sqlite3_magic = b"SQLite format 3\x00"
    try:
        with open(file_path, "rb") as f:
            return f.read(len(sqlite3_magic)) == sqlite3_magic
    except OSError:
        # Missing file, directory, permission problem, etc.: not a database
        # as far as the caller is concerned.
        return False

def get_sqlite_schema(file_path):
    """Extract the schema of a SQLite3 database file as SQL text.

    The result lists the stored ``CREATE`` statements grouped under the
    headings ``-- Tables``, ``-- Views`` and ``-- Indexes`` (in that order),
    each group sorted by object name and each statement terminated with
    ``;``. Groups with no entries are omitted.

    Args:
        file_path: Path to the SQLite3 database file.

    Returns:
        The formatted schema text, or a string of the form
        ``"Error extracting schema: <message>"`` if SQLite raises an error
        while opening or querying the file.
    """
    # (heading, sqlite_master type) in output order. Headings after the first
    # carry a leading newline so groups are separated by a blank line.
    sections = (
        ("-- Tables", "table"),
        ("\n-- Views", "view"),
        ("\n-- Indexes", "index"),
    )
    try:
        conn = sqlite3.connect(file_path)
        try:
            schema_parts = []
            for heading, object_type in sections:
                # Rows with a NULL ``sql`` column are skipped here: they have
                # no CREATE statement to print. This replaces a name-based
                # LIKE 'sqlite_%' filter, whose ``_`` wildcard also dropped
                # user indexes such as ``sqlitex_idx``.
                rows = conn.execute(
                    "SELECT sql FROM sqlite_master "
                    "WHERE type = ? AND sql IS NOT NULL ORDER BY name",
                    (object_type,),
                ).fetchall()
                if rows:
                    schema_parts.append(heading)
                    schema_parts.extend(f"{sql};" for (sql,) in rows)
            return "\n".join(schema_parts)
        finally:
            # Always release the connection, even when a query fails.
            conn.close()
    except sqlite3.Error as e:
        return f"Error extracting schema: {str(e)}"

def process_path(
path,
extensions,
Expand All @@ -110,14 +167,24 @@ def process_path(
claude_xml,
markdown,
line_numbers=False,
extract_sqlite=False,
):
if os.path.isfile(path):
try:
with open(path, "r") as f:
print_path(writer, path, f.read(), claude_xml, markdown, line_numbers)
except UnicodeDecodeError:
warning_message = f"Warning: Skipping file {path} due to UnicodeDecodeError"
click.echo(click.style(warning_message, fg="red"), err=True)
if extract_sqlite and is_sqlite3_file(path):
try:
schema = get_sqlite_schema(path)
content = f"-- SQLite3 Database Schema\n{schema}"
print_path(writer, path, content, claude_xml, markdown, line_numbers)
except Exception as e:
warning_message = f"Warning: Error processing SQLite file {path}: {str(e)}"
click.echo(click.style(warning_message, fg="red"), err=True)
else:
try:
with open(path, "r") as f:
print_path(writer, path, f.read(), claude_xml, markdown, line_numbers)
except UnicodeDecodeError:
warning_message = f"Warning: Skipping file {path} due to UnicodeDecodeError"
click.echo(click.style(warning_message, fg="red"), err=True)
elif os.path.isdir(path):
for root, dirs, files in os.walk(path):
if not include_hidden:
Expand Down Expand Up @@ -155,21 +222,37 @@ def process_path(

for file in sorted(files):
file_path = os.path.join(root, file)
try:
with open(file_path, "r") as f:
if extract_sqlite and is_sqlite3_file(file_path):
try:
schema = get_sqlite_schema(file_path)
content = f"-- SQLite3 Database Schema\n{schema}"
print_path(
writer,
file_path,
f.read(),
content,
claude_xml,
markdown,
line_numbers,
)
except UnicodeDecodeError:
warning_message = (
f"Warning: Skipping file {file_path} due to UnicodeDecodeError"
)
click.echo(click.style(warning_message, fg="red"), err=True)
except Exception as e:
warning_message = f"Warning: Error processing SQLite file {file_path}: {str(e)}"
click.echo(click.style(warning_message, fg="red"), err=True)
else:
try:
with open(file_path, "r") as f:
print_path(
writer,
file_path,
f.read(),
claude_xml,
markdown,
line_numbers,
)
except UnicodeDecodeError:
warning_message = (
f"Warning: Skipping file {file_path} due to UnicodeDecodeError"
)
click.echo(click.style(warning_message, fg="red"), err=True)


def read_paths_from_stdin(use_null_separator):
Expand Down Expand Up @@ -244,6 +327,12 @@ def read_paths_from_stdin(use_null_separator):
is_flag=True,
help="Use NUL character as separator when reading from stdin",
)
@click.option(
"extract_sqlite",
"--extract-sqlite",
is_flag=True,
help="Extract schema information from SQLite3 database files instead of treating them as binary",
)
@click.version_option()
def cli(
paths,
Expand All @@ -257,6 +346,7 @@ def cli(
markdown,
line_numbers,
null,
extract_sqlite,
):
"""
Takes one or more paths to files or directories and outputs every file,
Expand Down Expand Up @@ -291,6 +381,17 @@ def cli(
```python
Contents of file1.py
```

If the `--extract-sqlite` flag is provided, SQLite3 database files will have their schema
extracted instead of being skipped as binary files:

\b
path/to/database.sqlite
---
-- SQLite3 Database Schema
-- Tables
CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, email TEXT);
---
"""
# Reset global_index for pytest
global global_index
Expand Down Expand Up @@ -327,6 +428,7 @@ def cli(
claude_xml,
markdown,
line_numbers,
extract_sqlite,
)
if claude_xml:
writer("</documents>")
Expand Down