diff --git a/files_to_prompt/cli.py b/files_to_prompt/cli.py index 7eee04f..239c979 100644 --- a/files_to_prompt/cli.py +++ b/files_to_prompt/cli.py @@ -1,5 +1,6 @@ import os import sys +import sqlite3 from fnmatch import fnmatch import click @@ -21,6 +22,10 @@ "yml": "yaml", "sh": "bash", "rb": "ruby", + "sql": "sql", + "sqlite": "sql", + "sqlite3": "sql", + "db": "sql", } @@ -98,6 +103,58 @@ def print_as_markdown(writer, path, content, line_numbers): writer(f"{backticks}") +def is_sqlite3_file(file_path): + """Check if the given file is a SQLite3 database.""" + try: + # Read the first 16 bytes to check for SQLite3 header + with open(file_path, "rb") as f: + header = f.read(16) + return header[:16].startswith(b'SQLite format 3') + except (IOError, OSError): + return False + +def get_sqlite_schema(file_path): + """Extract schema information from a SQLite3 database file.""" + try: + conn = sqlite3.connect(file_path) + cursor = conn.cursor() + + # Get tables schema + cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name") + tables = cursor.fetchall() + + # Get views schema + cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='view' ORDER BY name") + views = cursor.fetchall() + + # Get indexes schema + cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='index' AND name NOT LIKE 'sqlite_%' ORDER BY name") + indexes = cursor.fetchall() + + # Format the results + schema_parts = [] + + if tables: + schema_parts.append("-- Tables") + for table_name, table_sql in tables: + schema_parts.append(f"{table_sql};") + + if views: + schema_parts.append("\n-- Views") + for view_name, view_sql in views: + schema_parts.append(f"{view_sql};") + + if indexes: + schema_parts.append("\n-- Indexes") + for idx_name, idx_sql in indexes: + schema_parts.append(f"{idx_sql};") + + conn.close() + return "\n".join(schema_parts) + + except sqlite3.Error as e: + return f"Error extracting schema: {str(e)}" + def process_path( path, extensions, @@ -110,14 +167,24 @@ def process_path( claude_xml, markdown, line_numbers=False, + extract_sqlite=False, ): if os.path.isfile(path): - try: - with open(path, "r") as f: - print_path(writer, path, f.read(), claude_xml, markdown, line_numbers) - except UnicodeDecodeError: - warning_message = f"Warning: Skipping file {path} due to UnicodeDecodeError" - click.echo(click.style(warning_message, fg="red"), err=True) + if extract_sqlite and is_sqlite3_file(path): + try: + schema = get_sqlite_schema(path) + content = f"-- SQLite3 Database Schema\n{schema}" + print_path(writer, path, content, claude_xml, markdown, line_numbers) + except Exception as e: + warning_message = f"Warning: Error processing SQLite file {path}: {str(e)}" + click.echo(click.style(warning_message, fg="red"), err=True) + else: + try: + with open(path, "r") as f: + print_path(writer, path, f.read(), claude_xml, markdown, line_numbers) + except UnicodeDecodeError: + warning_message = f"Warning: Skipping file {path} due to UnicodeDecodeError" + click.echo(click.style(warning_message, fg="red"), err=True) elif os.path.isdir(path): for root, dirs, files in os.walk(path): if not include_hidden: @@ -155,21 +222,37 @@ def process_path( for file in sorted(files): file_path = os.path.join(root, file) - try: - with open(file_path, "r") as f: + if extract_sqlite and is_sqlite3_file(file_path): + try: + schema = get_sqlite_schema(file_path) + content = f"-- SQLite3 Database Schema\n{schema}" print_path( writer, file_path, - f.read(), + content, claude_xml, markdown, line_numbers, ) - except UnicodeDecodeError: - warning_message = ( - f"Warning: Skipping file {file_path} due to UnicodeDecodeError" - ) - click.echo(click.style(warning_message, fg="red"), err=True) + except Exception as e: + warning_message = f"Warning: Error processing SQLite file {file_path}: {str(e)}" + click.echo(click.style(warning_message, fg="red"), err=True) + else: + try: + with open(file_path, "r") as f: + print_path( + writer, + file_path, + f.read(), + claude_xml, + markdown, + line_numbers, + ) + except UnicodeDecodeError: + warning_message = ( + f"Warning: Skipping file {file_path} due to UnicodeDecodeError" + ) + click.echo(click.style(warning_message, fg="red"), err=True) def read_paths_from_stdin(use_null_separator): @@ -244,6 +327,12 @@ def read_paths_from_stdin(use_null_separator): is_flag=True, help="Use NUL character as separator when reading from stdin", ) +@click.option( + "extract_sqlite", + "--extract-sqlite", + is_flag=True, + help="Extract schema information from SQLite3 database files instead of treating them as binary", +) @click.version_option() def cli( paths, @@ -257,6 +346,7 @@ def cli( markdown, line_numbers, null, + extract_sqlite, ): """ Takes one or more paths to files or directories and outputs every file, @@ -291,6 +381,17 @@ def cli( ```python Contents of file1.py ``` + + If the `--extract-sqlite` flag is provided, SQLite3 database files will have their schema + extracted instead of being skipped as binary files: + + \b + path/to/database.sqlite + --- + -- SQLite3 Database Schema + -- Tables + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, email TEXT); + --- """ # Reset global_index for pytest global global_index @@ -327,6 +428,7 @@ def cli( claude_xml, markdown, line_numbers, + extract_sqlite, ) if claude_xml: writer("")