From a1bacdc11eb2c3099ad7a802874871c613d86a55 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 22 Aug 2025 15:45:38 +0530 Subject: [PATCH] FEAT: Adding Columns API in cursor --- mssql_python/cursor.py | 116 +++++++ mssql_python/pybind/ddbc_bindings.cpp | 54 +++- mssql_python/pybind/ddbc_bindings.h | 4 + tests/test_004_cursor.py | 448 ++++++++++++++++++++++++++ 4 files changed, 621 insertions(+), 1 deletion(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 59935411..f4b86624 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -1348,6 +1348,122 @@ def statistics(self, table, catalog=None, schema=None, unique=False, quick=True) result_rows.append(row) return result_rows + + def columns(self, table=None, catalog=None, schema=None, column=None): + """ + Creates a result set of column information in the specified tables + using the SQLColumns function. + + Args: + table (str, optional): The table name pattern. Default is None (all tables). + catalog (str, optional): The catalog name. Default is None. + schema (str, optional): The schema name pattern. Default is None. + column (str, optional): The column name pattern. Default is None (all columns). + + Returns: + list: A list of Row objects containing column information with these columns: + - table_cat: Catalog name + - table_schem: Schema name + - table_name: Table name + - column_name: Column name + - data_type: SQL data type code + - type_name: Data source-dependent type name + - column_size: Column size + - buffer_length: Transfer size in bytes + - decimal_digits: Number of decimal digits + - num_prec_radix: Numeric precision radix + - nullable: Is NULL allowed + - remarks: Comments about column + - column_def: Default value + - sql_data_type: SQL data type + - sql_datetime_sub: Datetime/interval subcode + - char_octet_length: Maximum length in bytes of a character/binary type + - ordinal_position: Column sequence number (starting with 1) + - is_nullable: "NO" means column does not allow NULL, "YES" means it does + + Example: + # Get all columns in table 'Customers' + columns = cursor.columns(table='Customers') + + # Get all columns in table 'Customers' in schema 'dbo' + columns = cursor.columns(table='Customers', schema='dbo') + + # Get column named 'CustomerID' in any table + columns = cursor.columns(column='CustomerID') + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Convert None values to empty strings as required by ODBC API + catalog_p = "" if catalog is None else catalog + schema_p = "" if schema is None else schema + table_p = "" if table is None else table + column_p = "" if column is None else column + + # Call the SQLColumns function + retcode = ddbc_bindings.DDBCSQLColumns( + self.hstmt, + catalog_p, + schema_p, + table_p, + column_p + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except Exception: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, int, str, int, int, int, int, int, str, str, int, int, int, int, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("column_name", column_types[3], None, 128, 128, 0, False), + ("data_type", column_types[4], None, 10, 10, 0, False), + ("type_name", column_types[5], None, 128, 128, 0, False), + ("column_size", column_types[6], None, 10, 10, 0, True), + ("buffer_length", column_types[7], None, 10, 10, 0, True), + ("decimal_digits", column_types[8], None, 10, 10, 0, True), + ("num_prec_radix", column_types[9], None, 10, 10, 0, True), + ("nullable", column_types[10], None, 10, 10, 0, False), + ("remarks", column_types[11], None, 254, 254, 0, True), + ("column_def", column_types[12], None, 254, 254, 0, True), + ("sql_data_type", column_types[13], None, 10, 10, 0, False), + ("sql_datetime_sub", column_types[14], None, 10, 10, 0, True), + ("char_octet_length", column_types[15], None, 10, 10, 0, True), + ("ordinal_position", column_types[16], None, 10, 10, 0, False), + ("is_nullable", column_types[17], None, 254, 254, 0, True) + ] + + # Define column names in ODBC standard order + column_names = [ + "table_cat", "table_schem", "table_name", "column_name", "data_type", + "type_name", "column_size", "buffer_length", "decimal_digits", + "num_prec_radix", "nullable", "remarks", "column_def", "sql_data_type", + "sql_datetime_sub", "char_octet_length", "ordinal_position", "is_nullable" + ] + + # Fetch all rows + rows_data = [] + ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + # Create a column map for attribute access + column_map = {name: i for i, name in enumerate(column_names)} + + # Create Row objects with the column map + result_rows = [] + for row_data in rows_data: + row = Row(self, self.description, row_data) + row._column_map = column_map + result_rows.append(row) + + return result_rows @staticmethod def _select_best_sample_value(column): diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 2beb3663..e20e0348 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -129,6 +129,7 @@ SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr; SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr; SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr; SQLStatisticsFunc SQLStatistics_ptr = nullptr; +SQLColumnsFunc SQLColumns_ptr = nullptr; // Transaction APIs SQLEndTranFunc SQLEndTran_ptr = nullptr; @@ -791,6 +792,7 @@ DriverHandle LoadDriverOrThrowException() { SQLPrimaryKeys_ptr = GetFunctionPointer(handle, "SQLPrimaryKeysW"); SQLSpecialColumns_ptr = GetFunctionPointer(handle, "SQLSpecialColumnsW"); SQLStatistics_ptr = GetFunctionPointer(handle, "SQLStatisticsW"); + SQLColumns_ptr = GetFunctionPointer(handle, "SQLColumnsW"); SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); @@ -810,7 +812,8 @@ DriverHandle LoadDriverOrThrowException() { SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr && - SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr; + SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr && + SQLColumns_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -1047,6 +1050,47 @@ SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, #endif } +SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, + const std::wstring& catalog, + const std::wstring& schema, + const std::wstring& table, + const std::wstring& column) { + if (!SQLColumns_ptr) { + ThrowStdException("SQLColumns function not loaded"); + } + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + std::vector columnBuf = WStringToSQLWCHAR(column); + + return SQLColumns_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, + column.empty() ? nullptr : columnBuf.data(), + column.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLColumns_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 0 : SQL_NTS, + column.empty() ? nullptr : (SQLWCHAR*)column.c_str(), + column.empty() ? 0 : SQL_NTS); +#endif +} + // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -2853,6 +2897,14 @@ PYBIND11_MODULE(ddbc_bindings, m) { SQLUSMALLINT reserved) { return SQLStatistics_wrap(StatementHandle, catalog, schema, table, unique, reserved); }); + m.def("DDBCSQLColumns", [](SqlHandlePtr StatementHandle, + const std::wstring& catalog, + const std::wstring& schema, + const std::wstring& table, + const std::wstring& column) { + return SQLColumns_wrap(StatementHandle, catalog, schema, table, column); + }); + // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index edaeb6b0..d757ad95 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -119,6 +119,9 @@ typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWC typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLUSMALLINT, SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT); // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -168,6 +171,7 @@ extern SQLForeignKeysFunc SQLForeignKeys_ptr; extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr; extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr; extern SQLStatisticsFunc SQLStatistics_ptr; +extern SQLColumnsFunc SQLColumns_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index b71a5e6d..01761950 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -3267,6 +3267,454 @@ def test_statistics_cleanup(cursor, db_connection): except Exception as e: pytest.fail(f"Test cleanup failed: {e}") +def test_columns_setup(cursor, db_connection): + """Create test tables for columns method testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_cols_schema') EXEC('CREATE SCHEMA pytest_cols_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test") + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") + + # Create test table with various column types + cursor.execute(""" + CREATE TABLE pytest_cols_schema.columns_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + description NVARCHAR(MAX) NULL, + price DECIMAL(10, 2) NULL, + created_date DATETIME DEFAULT GETDATE(), + is_active BIT NOT NULL DEFAULT 1, + binary_data VARBINARY(MAX) NULL, + notes TEXT NULL, + [computed_col] AS (name + ' - ' + CAST(id AS VARCHAR(10))) + ) + """) + + # Create table with special column names and edge cases - fix the problematic column name + cursor.execute(""" + CREATE TABLE pytest_cols_schema.columns_special_test ( + [ID] INT PRIMARY KEY, + [User Name] NVARCHAR(100) NULL, + [Spaces Multiple] VARCHAR(50) NULL, + [123_numeric_start] INT NULL, + [MAX] VARCHAR(20) NULL, -- SQL keyword as column name + [SELECT] INT NULL, -- SQL keyword as column name + [Column.With.Dots] VARCHAR(20) NULL, + [Column/With/Slashes] VARCHAR(20) NULL, + [Column_With_Underscores] VARCHAR(20) NULL -- Changed from problematic nested brackets + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_columns_all(cursor, db_connection): + """Test columns returns information about all columns in all tables""" + try: + # First set up our test tables + test_columns_setup(cursor, db_connection) + + # Get all columns (no filters) + cols = cursor.columns() + + # Verify we got results + assert cols is not None, "columns() should return results" + assert len(cols) > 0, "columns() should return at least one column" + + # Verify our test tables' columns are in the results + # Use case-insensitive comparison to avoid driver case sensitivity issues + found_test_table = False + for col in cols: + if (hasattr(col, 'table_name') and + col.table_name and + col.table_name.lower() == 'columns_test' and + hasattr(col, 'table_schem') and + col.table_schem and + col.table_schem.lower() == 'pytest_cols_schema'): + found_test_table = True + break + + assert found_test_table, "Test table columns should be included in results" + + # Verify structure of results + first_row = cols[0] + assert hasattr(first_row, 'table_cat'), "Result should have table_cat column" + assert hasattr(first_row, 'table_schem'), "Result should have table_schem column" + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'column_name'), "Result should have column_name column" + assert hasattr(first_row, 'data_type'), "Result should have data_type column" + assert hasattr(first_row, 'type_name'), "Result should have type_name column" + assert hasattr(first_row, 'column_size'), "Result should have column_size column" + assert hasattr(first_row, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(first_row, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(first_row, 'num_prec_radix'), "Result should have num_prec_radix column" + assert hasattr(first_row, 'nullable'), "Result should have nullable column" + assert hasattr(first_row, 'remarks'), "Result should have remarks column" + assert hasattr(first_row, 'column_def'), "Result should have column_def column" + assert hasattr(first_row, 'sql_data_type'), "Result should have sql_data_type column" + assert hasattr(first_row, 'sql_datetime_sub'), "Result should have sql_datetime_sub column" + assert hasattr(first_row, 'char_octet_length'), "Result should have char_octet_length column" + assert hasattr(first_row, 'ordinal_position'), "Result should have ordinal_position column" + assert hasattr(first_row, 'is_nullable'), "Result should have is_nullable column" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_specific_table(cursor, db_connection): + """Test columns returns information about a specific table""" + try: + # Get columns for the test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ) + + # Verify we got results + assert len(cols) == 9, "Should find exactly 9 columns in columns_test" + + # Verify all column names are present (case insensitive) + col_names = [col.column_name.lower() for col in cols] + expected_names = ['id', 'name', 'description', 'price', 'created_date', + 'is_active', 'binary_data', 'notes', 'computed_col'] + + for name in expected_names: + assert name in col_names, f"Column {name} should be in results" + + # Verify details of a specific column (id) + id_col = next(col for col in cols if col.column_name.lower() == 'id') + assert id_col.nullable == 0, "id column should be non-nullable" + assert id_col.ordinal_position == 1, "id should be the first column" + assert id_col.is_nullable == "NO", "is_nullable should be NO for id column" + + # Check data types (but don't assume specific ODBC type codes since they vary by driver) + # Instead check that the type_name is correct + id_type = id_col.type_name.lower() + assert 'int' in id_type, f"id column should be INTEGER type, got {id_type}" + + # Check a nullable column + desc_col = next(col for col in cols if col.column_name.lower() == 'description') + assert desc_col.nullable == 1, "description column should be nullable" + assert desc_col.is_nullable == "YES", "is_nullable should be YES for description column" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_special_chars(cursor, db_connection): + """Test columns with special characters and edge cases""" + try: + # Get columns for the special table + cols = cursor.columns( + table='columns_special_test', + schema='pytest_cols_schema' + ) + + # Verify we got results + assert len(cols) == 9, "Should find exactly 9 columns in columns_special_test" + + # Check that special column names are handled correctly + col_names = [col.column_name for col in cols] + + # Create case-insensitive lookup + col_names_lower = [name.lower() if name else None for name in col_names] + + # Check for columns with special characters - note that column names might be + # returned with or without brackets/quotes depending on the driver + assert any('user name' in name.lower() for name in col_names), "Column with spaces should be in results" + assert any('id' == name.lower() for name in col_names), "ID column should be in results" + assert any('123_numeric_start' in name.lower() for name in col_names), "Column starting with numbers should be in results" + assert any('max' == name.lower() for name in col_names), "MAX column should be in results" + assert any('select' == name.lower() for name in col_names), "SELECT column should be in results" + assert any('column.with.dots' in name.lower() for name in col_names), "Column with dots should be in results" + assert any('column/with/slashes' in name.lower() for name in col_names), "Column with slashes should be in results" + assert any('column_with_underscores' in name.lower() for name in col_names), "Column with underscores should be in results" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_specific_column(cursor, db_connection): + """Test columns with specific column filter""" + try: + # Get specific column + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='name' + ) + + # Verify we got just one result + assert len(cols) == 1, "Should find exactly 1 column named 'name'" + + # Verify column details + col = cols[0] + assert col.column_name.lower() == 'name', "Column name should be 'name'" + assert col.table_name.lower() == 'columns_test', "Table name should be 'columns_test'" + assert col.table_schem.lower() == 'pytest_cols_schema', "Schema should be 'pytest_cols_schema'" + assert col.nullable == 0, "name column should be non-nullable" + + # Get column using pattern (% wildcard) + pattern_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='%date%' + ) + + # Should find created_date column + assert len(pattern_cols) == 1, "Should find 1 column matching '%date%'" + assert pattern_cols[0].column_name.lower() == 'created_date', "Should find created_date column" + + # Get multiple columns with pattern + multi_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='%d%' # Should match id, description, created_date + ) + + # At least 3 columns should match this pattern + assert len(multi_cols) >= 3, "Should find at least 3 columns matching '%d%'" + match_names = [col.column_name.lower() for col in multi_cols] + assert 'id' in match_names, "id should match '%d%'" + assert 'description' in match_names, "description should match '%d%'" + assert 'created_date' in match_names, "created_date should match '%d%'" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_with_underscore_pattern(cursor, db_connection): + """Test columns with underscore wildcard pattern""" + try: + # Get columns with underscore pattern (one character wildcard) + # Looking for 'id' (exactly 2 chars) + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='__' + ) + + # Should find 'id' column + id_found = False + for col in cols: + if col.column_name.lower() == 'id' and col.table_name.lower() == 'columns_test': + id_found = True + break + + assert id_found, "Should find 'id' column with pattern '__'" + + # Try a more complex pattern with both % and _ + # For example: '%_d%' matches any column with 'd' as the second or later character + pattern_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='%_d%' + ) + + # Should match 'id' (if considering case-insensitive) and 'created_date' + match_names = [col.column_name.lower() for col in pattern_cols + if col.table_name.lower() == 'columns_test'] + + # At least 'created_date' should match this pattern + assert 'created_date' in match_names, "created_date should match '%_d%'" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_nonexistent(cursor): + """Test columns with non-existent table or column""" + # Test with non-existent table + table_cols = cursor.columns(table='nonexistent_table_xyz123') + assert len(table_cols) == 0, "Should return empty list for non-existent table" + + # Test with non-existent column in existing table + col_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='nonexistent_column_xyz123' + ) + assert len(col_cols) == 0, "Should return empty list for non-existent column" + + # Test with non-existent schema + schema_cols = cursor.columns( + table='columns_test', + schema='nonexistent_schema_xyz123' + ) + assert len(schema_cols) == 0, "Should return empty list for non-existent schema" + +def test_columns_data_types(cursor, db_connection): + """Test columns returns correct data type information""" + try: + # Get all columns from test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ) + + # Create a dictionary mapping column names to their details + col_dict = {col.column_name.lower(): col for col in cols} + + # Check data types by name (case insensitive checks) + # Note: We're checking type_name as a string to avoid SQL type code inconsistencies + # between drivers + + # INT column + assert 'int' in col_dict['id'].type_name.lower(), "id should be INT type" + + # NVARCHAR column + assert any(name in col_dict['name'].type_name.lower() + for name in ['nvarchar', 'varchar', 'char', 'wchar']), "name should be NVARCHAR type" + + # DECIMAL column + assert any(name in col_dict['price'].type_name.lower() + for name in ['decimal', 'numeric', 'money']), "price should be DECIMAL type" + + # BIT column + assert any(name in col_dict['is_active'].type_name.lower() + for name in ['bit', 'boolean']), "is_active should be BIT type" + + # TEXT column + assert any(name in col_dict['notes'].type_name.lower() + for name in ['text', 'char', 'varchar']), "notes should be TEXT type" + + # Check nullable flag + assert col_dict['id'].nullable == 0, "id should be non-nullable" + assert col_dict['description'].nullable == 1, "description should be nullable" + + # Check column size + assert col_dict['name'].column_size == 100, "name should have size 100" + + # Check decimal digits for numeric type + assert col_dict['price'].decimal_digits == 2, "price should have 2 decimal digits" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_catalog_filter(cursor, db_connection): + """Test columns with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get columns with current catalog + cols = cursor.columns( + table='columns_test', + catalog=current_db, + schema='pytest_cols_schema' + ) + + # Verify catalog filter worked + assert len(cols) > 0, "Should find columns with correct catalog" + + # Check catalog in results + for col in cols: + # Some drivers might return None for catalog + if col.table_cat is not None: + assert col.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Test with non-existent catalog + fake_cols = cursor.columns( + table='columns_test', + catalog='nonexistent_db_xyz123', + schema='pytest_cols_schema' + ) + assert len(fake_cols) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_schema_pattern(cursor, db_connection): + """Test columns with schema name pattern""" + try: + # Get columns with schema pattern + cols = cursor.columns( + table='columns_test', + schema='pytest_%' + ) + + # Should find our test table columns + test_cols = [col for col in cols if col.table_name.lower() == 'columns_test'] + assert len(test_cols) > 0, "Should find columns using schema pattern" + + # Try a more specific pattern + specific_cols = cursor.columns( + table='columns_test', + schema='pytest_cols%' + ) + + # Should still find our test table columns + test_cols = [col for col in specific_cols if col.table_name.lower() == 'columns_test'] + assert len(test_cols) > 0, "Should find columns using specific schema pattern" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_table_pattern(cursor, db_connection): + """Test columns with table name pattern""" + try: + # Get columns with table pattern + cols = cursor.columns( + table='columns_%', + schema='pytest_cols_schema' + ) + + # Should find columns from both test tables + tables_found = set() + for col in cols: + if col.table_name: + tables_found.add(col.table_name.lower()) + + assert 'columns_test' in tables_found, "Should find columns_test with pattern columns_%" + assert 'columns_special_test' in tables_found, "Should find columns_special_test with pattern columns_%" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_ordinal_position(cursor, db_connection): + """Test ordinal_position is correct in columns results""" + try: + # Get columns for the test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ) + + # Sort by ordinal position + sorted_cols = sorted(cols, key=lambda col: col.ordinal_position) + + # Verify positions are consecutive starting from 1 + for i, col in enumerate(sorted_cols, 1): + assert col.ordinal_position == i, f"Column {col.column_name} should have ordinal_position {i}" + + # First column should be id (primary key) + assert sorted_cols[0].column_name.lower() == 'id', "First column should be id" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test") + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_cols_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + def test_close(db_connection): """Test closing the cursor""" try: