Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,122 @@ def statistics(self, table, catalog=None, schema=None, unique=False, quick=True)
result_rows.append(row)

return result_rows

def columns(self, table=None, catalog=None, schema=None, column=None):
"""
Creates a result set of column information in the specified tables
using the SQLColumns function.

Args:
table (str, optional): The table name pattern. Default is None (all tables).
catalog (str, optional): The catalog name. Default is None.
schema (str, optional): The schema name pattern. Default is None.
column (str, optional): The column name pattern. Default is None (all columns).

Returns:
list: A list of Row objects containing column information with these columns:
- table_cat: Catalog name
- table_schem: Schema name
- table_name: Table name
- column_name: Column name
- data_type: SQL data type code
- type_name: Data source-dependent type name
- column_size: Column size
- buffer_length: Transfer size in bytes
- decimal_digits: Number of decimal digits
- num_prec_radix: Numeric precision radix
- nullable: Is NULL allowed
- remarks: Comments about column
- column_def: Default value
- sql_data_type: SQL data type
- sql_datetime_sub: Datetime/interval subcode
- char_octet_length: Maximum length in bytes of a character/binary type
- ordinal_position: Column sequence number (starting with 1)
- is_nullable: "NO" means column does not allow NULL, "YES" means it does

Example:
# Get all columns in table 'Customers'
columns = cursor.columns(table='Customers')

# Get all columns in table 'Customers' in schema 'dbo'
columns = cursor.columns(table='Customers', schema='dbo')

# Get column named 'CustomerID' in any table
columns = cursor.columns(column='CustomerID')
"""
self._check_closed()

# Always reset the cursor first to ensure clean state
self._reset_cursor()

# Convert None values to empty strings as required by ODBC API
catalog_p = "" if catalog is None else catalog
schema_p = "" if schema is None else schema
table_p = "" if table is None else table
column_p = "" if column is None else column

# Call the SQLColumns function
retcode = ddbc_bindings.DDBCSQLColumns(
self.hstmt,
catalog_p,
schema_p,
table_p,
column_p
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode)

# Initialize description from column metadata
column_metadata = []
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except Exception:
# If describe fails, create a manual description for the standard columns
column_types = [str, str, str, str, int, str, int, int, int, int, int, str, str, int, int, int, int, str]
self.description = [
("table_cat", column_types[0], None, 128, 128, 0, True),
("table_schem", column_types[1], None, 128, 128, 0, True),
("table_name", column_types[2], None, 128, 128, 0, False),
("column_name", column_types[3], None, 128, 128, 0, False),
("data_type", column_types[4], None, 10, 10, 0, False),
("type_name", column_types[5], None, 128, 128, 0, False),
("column_size", column_types[6], None, 10, 10, 0, True),
("buffer_length", column_types[7], None, 10, 10, 0, True),
("decimal_digits", column_types[8], None, 10, 10, 0, True),
("num_prec_radix", column_types[9], None, 10, 10, 0, True),
("nullable", column_types[10], None, 10, 10, 0, False),
("remarks", column_types[11], None, 254, 254, 0, True),
("column_def", column_types[12], None, 254, 254, 0, True),
("sql_data_type", column_types[13], None, 10, 10, 0, False),
("sql_datetime_sub", column_types[14], None, 10, 10, 0, True),
("char_octet_length", column_types[15], None, 10, 10, 0, True),
("ordinal_position", column_types[16], None, 10, 10, 0, False),
("is_nullable", column_types[17], None, 254, 254, 0, True)
]

# Define column names in ODBC standard order
column_names = [
"table_cat", "table_schem", "table_name", "column_name", "data_type",
"type_name", "column_size", "buffer_length", "decimal_digits",
"num_prec_radix", "nullable", "remarks", "column_def", "sql_data_type",
"sql_datetime_sub", "char_octet_length", "ordinal_position", "is_nullable"
]

# Fetch all rows
rows_data = []
ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)

# Create a column map for attribute access
column_map = {name: i for i, name in enumerate(column_names)}

# Create Row objects with the column map
result_rows = []
for row_data in rows_data:
row = Row(self, self.description, row_data)
row._column_map = column_map
result_rows.append(row)

return result_rows

@staticmethod
def _select_best_sample_value(column):
Expand Down
54 changes: 53 additions & 1 deletion mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr;
SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr;
SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr;
SQLStatisticsFunc SQLStatistics_ptr = nullptr;
SQLColumnsFunc SQLColumns_ptr = nullptr;

// Transaction APIs
SQLEndTranFunc SQLEndTran_ptr = nullptr;
Expand Down Expand Up @@ -791,6 +792,7 @@ DriverHandle LoadDriverOrThrowException() {
SQLPrimaryKeys_ptr = GetFunctionPointer<SQLPrimaryKeysFunc>(handle, "SQLPrimaryKeysW");
SQLSpecialColumns_ptr = GetFunctionPointer<SQLSpecialColumnsFunc>(handle, "SQLSpecialColumnsW");
SQLStatistics_ptr = GetFunctionPointer<SQLStatisticsFunc>(handle, "SQLStatisticsW");
SQLColumns_ptr = GetFunctionPointer<SQLColumnsFunc>(handle, "SQLColumnsW");

SQLEndTran_ptr = GetFunctionPointer<SQLEndTranFunc>(handle, "SQLEndTran");
SQLDisconnect_ptr = GetFunctionPointer<SQLDisconnectFunc>(handle, "SQLDisconnect");
Expand All @@ -810,7 +812,8 @@ DriverHandle LoadDriverOrThrowException() {
SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr &&
SQLFreeStmt_ptr && SQLGetDiagRec_ptr &&
SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr &&
SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr;
SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr &&
SQLColumns_ptr;

if (!success) {
ThrowStdException("Failed to load required function pointers from driver.");
Expand Down Expand Up @@ -1047,6 +1050,47 @@ SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle,
#endif
}

SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle,
const std::wstring& catalog,
const std::wstring& schema,
const std::wstring& table,
const std::wstring& column) {
if (!SQLColumns_ptr) {
ThrowStdException("SQLColumns function not loaded");
}

#if defined(__APPLE__) || defined(__linux__)
// Unix implementation
std::vector<SQLWCHAR> catalogBuf = WStringToSQLWCHAR(catalog);
std::vector<SQLWCHAR> schemaBuf = WStringToSQLWCHAR(schema);
std::vector<SQLWCHAR> tableBuf = WStringToSQLWCHAR(table);
std::vector<SQLWCHAR> columnBuf = WStringToSQLWCHAR(column);

return SQLColumns_ptr(
StatementHandle->get(),
catalog.empty() ? nullptr : catalogBuf.data(),
catalog.empty() ? 0 : SQL_NTS,
schema.empty() ? nullptr : schemaBuf.data(),
schema.empty() ? 0 : SQL_NTS,
table.empty() ? nullptr : tableBuf.data(),
table.empty() ? 0 : SQL_NTS,
column.empty() ? nullptr : columnBuf.data(),
column.empty() ? 0 : SQL_NTS);
#else
// Windows implementation
return SQLColumns_ptr(
StatementHandle->get(),
catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(),
catalog.empty() ? 0 : SQL_NTS,
schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(),
schema.empty() ? 0 : SQL_NTS,
table.empty() ? nullptr : (SQLWCHAR*)table.c_str(),
table.empty() ? 0 : SQL_NTS,
column.empty() ? nullptr : (SQLWCHAR*)column.c_str(),
column.empty() ? 0 : SQL_NTS);
#endif
}

// Helper function to check for driver errors
ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) {
LOG("Checking errors for retcode - {}" , retcode);
Expand Down Expand Up @@ -2853,6 +2897,14 @@ PYBIND11_MODULE(ddbc_bindings, m) {
SQLUSMALLINT reserved) {
return SQLStatistics_wrap(StatementHandle, catalog, schema, table, unique, reserved);
});
m.def("DDBCSQLColumns", [](SqlHandlePtr StatementHandle,
const std::wstring& catalog,
const std::wstring& schema,
const std::wstring& table,
const std::wstring& column) {
return SQLColumns_wrap(StatementHandle, catalog, schema, table, column);
});


// Add a version attribute
m.attr("__version__") = "1.0.0";
Expand Down
4 changes: 4 additions & 0 deletions mssql_python/pybind/ddbc_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWC
typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*,
SQLSMALLINT, SQLWCHAR*, SQLSMALLINT,
SQLUSMALLINT, SQLUSMALLINT);
typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*,
SQLSMALLINT, SQLWCHAR*, SQLSMALLINT,
SQLWCHAR*, SQLSMALLINT);

// Transaction APIs
typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT);
Expand Down Expand Up @@ -168,6 +171,7 @@ extern SQLForeignKeysFunc SQLForeignKeys_ptr;
extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr;
extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr;
extern SQLStatisticsFunc SQLStatistics_ptr;
extern SQLColumnsFunc SQLColumns_ptr;

// Transaction APIs
extern SQLEndTranFunc SQLEndTran_ptr;
Expand Down
Loading