Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ class ConstantsDDBC(Enum):
SQL_NULLABLE = 1
SQL_MAX_NUMERIC_LEN = 16
SQL_ATTR_QUERY_TIMEOUT = 0
SQL_SCOPE_CURROW = 0
SQL_BEST_ROWID = 1
SQL_ROWVER = 2
SQL_NO_NULLS = 0
SQL_NULLABLE_UNKNOWN = 2

class AuthType(Enum):
"""Constants for authentication types"""
Expand Down
188 changes: 188 additions & 0 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,194 @@ def foreignKeys(self, table=None, catalog=None, schema=None, foreignTable=None,
result_rows.append(row)

return result_rows

def rowIdColumns(self, table, catalog=None, schema=None, nullable=True):
"""
Executes SQLSpecialColumns with SQL_BEST_ROWID which creates a result set of
columns that uniquely identify a row.

Args:
table (str): The table name
catalog (str, optional): The catalog name (database). Defaults to None.
schema (str, optional): The schema name. Defaults to None.
nullable (bool, optional): Whether to include nullable columns. Defaults to True.

Returns:
list: A list of rows with the following columns:
- scope: One of SQL_SCOPE_CURROW, SQL_SCOPE_TRANSACTION, or SQL_SCOPE_SESSION
- column_name: Column name
- data_type: The ODBC SQL data type constant (e.g. SQL_CHAR)
- type_name: Type name
- column_size: Column size
- buffer_length: Buffer length
- decimal_digits: Decimal digits
- pseudo_column: One of SQL_PC_UNKNOWN, SQL_PC_NOT_PSEUDO, SQL_PC_PSEUDO
"""
self._check_closed()

# Always reset the cursor first to ensure clean state
self._reset_cursor()

# Convert None values to empty strings as required by ODBC API
catalog_p = "" if catalog is None else catalog
schema_p = "" if schema is None else schema
table_p = table # Table name is required

# Set the identifier type to SQL_BEST_ROWID (1)
identifier_type = ddbc_sql_const.SQL_BEST_ROWID.value

# Set scope to SQL_SCOPE_CURROW (0) - default scope
scope = ddbc_sql_const.SQL_SCOPE_CURROW.value

# Set nullable flag
nullable_flag = ddbc_sql_const.SQL_NULLABLE.value if nullable else ddbc_sql_const.SQL_NO_NULLS.value

# Call the SQLSpecialColumns function
retcode = ddbc_bindings.DDBCSQLSpecialColumns(
self.hstmt,
identifier_type,
catalog_p,
schema_p,
table_p,
scope,
nullable_flag
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode)

# Initialize description from column metadata
column_metadata = []
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except Exception:
# If describe fails, create a manual description for the standard columns
column_types = [int, str, int, str, int, int, int, int]
self.description = [
("scope", column_types[0], None, 10, 10, 0, False),
("column_name", column_types[1], None, 128, 128, 0, False),
("data_type", column_types[2], None, 10, 10, 0, False),
("type_name", column_types[3], None, 128, 128, 0, False),
("column_size", column_types[4], None, 10, 10, 0, False),
("buffer_length", column_types[5], None, 10, 10, 0, False),
("decimal_digits", column_types[6], None, 10, 10, 0, True),
("pseudo_column", column_types[7], None, 10, 10, 0, False)
]

# Define column names in ODBC standard order
column_names = [
"scope", "column_name", "data_type", "type_name",
"column_size", "buffer_length", "decimal_digits", "pseudo_column"
]

# Fetch all rows
rows_data = []
ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)

# Create a column map for attribute access
column_map = {name: i for i, name in enumerate(column_names)}

# Create Row objects with the column map
result_rows = []
for row_data in rows_data:
row = Row(self, self.description, row_data)
row._column_map = column_map
result_rows.append(row)

return result_rows

def rowVerColumns(self, table, catalog=None, schema=None, nullable=True):
"""
Executes SQLSpecialColumns with SQL_ROWVER which creates a result set of
columns that are automatically updated when any value in the row is updated.

Args:
table (str): The table name
catalog (str, optional): The catalog name (database). Defaults to None.
schema (str, optional): The schema name. Defaults to None.
nullable (bool, optional): Whether to include nullable columns. Defaults to True.

Returns:
list: A list of rows with the following columns:
- scope: One of SQL_SCOPE_CURROW, SQL_SCOPE_TRANSACTION, or SQL_SCOPE_SESSION
- column_name: Column name
- data_type: The ODBC SQL data type constant (e.g. SQL_CHAR)
- type_name: Type name
- column_size: Column size
- buffer_length: Buffer length
- decimal_digits: Decimal digits
- pseudo_column: One of SQL_PC_UNKNOWN, SQL_PC_NOT_PSEUDO, SQL_PC_PSEUDO
"""
self._check_closed()

# Always reset the cursor first to ensure clean state
self._reset_cursor()

# Convert None values to empty strings as required by ODBC API
catalog_p = "" if catalog is None else catalog
schema_p = "" if schema is None else schema
table_p = table # Table name is required

# Set the identifier type to SQL_ROWVER (2)
identifier_type = ddbc_sql_const.SQL_ROWVER.value

# Set scope to SQL_SCOPE_CURROW (0) - default scope
scope = ddbc_sql_const.SQL_SCOPE_CURROW.value

# Set nullable flag
nullable_flag = ddbc_sql_const.SQL_NULLABLE.value if nullable else ddbc_sql_const.SQL_NO_NULLS.value

# Call the SQLSpecialColumns function
retcode = ddbc_bindings.DDBCSQLSpecialColumns(
self.hstmt,
identifier_type,
catalog_p,
schema_p,
table_p,
scope,
nullable_flag
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode)

# Initialize description from column metadata
column_metadata = []
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except Exception:
# If describe fails, create a manual description for the standard columns
column_types = [int, str, int, str, int, int, int, int]
self.description = [
("scope", column_types[0], None, 10, 10, 0, False),
("column_name", column_types[1], None, 128, 128, 0, False),
("data_type", column_types[2], None, 10, 10, 0, False),
("type_name", column_types[3], None, 128, 128, 0, False),
("column_size", column_types[4], None, 10, 10, 0, False),
("buffer_length", column_types[5], None, 10, 10, 0, False),
("decimal_digits", column_types[6], None, 10, 10, 0, True),
("pseudo_column", column_types[7], None, 10, 10, 0, False)
]

# Define column names in ODBC standard order
column_names = [
"scope", "column_name", "data_type", "type_name",
"column_size", "buffer_length", "decimal_digits", "pseudo_column"
]

# Fetch all rows
rows_data = []
ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)

# Create a column map for attribute access
column_map = {name: i for i, name in enumerate(column_names)}

# Create Row objects with the column map
result_rows = []
for row_data in rows_data:
row = Row(self, self.description, row_data)
row._column_map = column_map
result_rows.append(row)

return result_rows

@staticmethod
def _select_best_sample_value(column):
Expand Down
59 changes: 58 additions & 1 deletion mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ SQLGetTypeInfoFunc SQLGetTypeInfo_ptr = nullptr;
SQLProceduresFunc SQLProcedures_ptr = nullptr;
SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr;
SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr;
SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr;

// Transaction APIs
SQLEndTranFunc SQLEndTran_ptr = nullptr;
Expand Down Expand Up @@ -787,6 +788,7 @@ DriverHandle LoadDriverOrThrowException() {
SQLProcedures_ptr = GetFunctionPointer<SQLProceduresFunc>(handle, "SQLProceduresW");
SQLForeignKeys_ptr = GetFunctionPointer<SQLForeignKeysFunc>(handle, "SQLForeignKeysW");
SQLPrimaryKeys_ptr = GetFunctionPointer<SQLPrimaryKeysFunc>(handle, "SQLPrimaryKeysW");
SQLSpecialColumns_ptr = GetFunctionPointer<SQLSpecialColumnsFunc>(handle, "SQLSpecialColumnsW");

SQLEndTran_ptr = GetFunctionPointer<SQLEndTranFunc>(handle, "SQLEndTran");
SQLDisconnect_ptr = GetFunctionPointer<SQLDisconnectFunc>(handle, "SQLDisconnect");
Expand All @@ -806,7 +808,7 @@ DriverHandle LoadDriverOrThrowException() {
SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr &&
SQLFreeStmt_ptr && SQLGetDiagRec_ptr &&
SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr &&
SQLPrimaryKeys_ptr;
SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr;

if (!success) {
ThrowStdException("Failed to load required function pointers from driver.");
Expand Down Expand Up @@ -1560,6 +1562,50 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta
return SQL_SUCCESS;
}

SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle,
SQLSMALLINT identifierType,
const std::wstring& catalog,
const std::wstring& schema,
const std::wstring& table,
SQLSMALLINT scope,
SQLSMALLINT nullable) {
if (!SQLSpecialColumns_ptr) {
ThrowStdException("SQLSpecialColumns function not loaded");
}

#if defined(__APPLE__) || defined(__linux__)
// Unix implementation
std::vector<SQLWCHAR> catalogBuf = WStringToSQLWCHAR(catalog);
std::vector<SQLWCHAR> schemaBuf = WStringToSQLWCHAR(schema);
std::vector<SQLWCHAR> tableBuf = WStringToSQLWCHAR(table);

return SQLSpecialColumns_ptr(
StatementHandle->get(),
identifierType,
catalog.empty() ? nullptr : catalogBuf.data(),
catalog.empty() ? 0 : SQL_NTS,
schema.empty() ? nullptr : schemaBuf.data(),
schema.empty() ? 0 : SQL_NTS,
table.empty() ? nullptr : tableBuf.data(),
table.empty() ? 0 : SQL_NTS,
scope,
nullable);
#else
// Windows implementation
return SQLSpecialColumns_ptr(
StatementHandle->get(),
identifierType,
catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(),
catalog.empty() ? 0 : SQL_NTS,
schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(),
schema.empty() ? 0 : SQL_NTS,
table.empty() ? nullptr : (SQLWCHAR*)table.c_str(),
table.empty() ? 0 : SQL_NTS,
scope,
nullable);
#endif
}

// Wrap SQLFetch to retrieve rows
SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) {
LOG("Fetch next row");
Expand Down Expand Up @@ -2746,6 +2792,17 @@ PYBIND11_MODULE(ddbc_bindings, m) {
const std::wstring& table) {
return SQLPrimaryKeys_wrap(StatementHandle, catalog, schema, table);
});
m.def("DDBCSQLSpecialColumns", [](SqlHandlePtr StatementHandle,
SQLSMALLINT identifierType,
const std::wstring& catalog,
const std::wstring& schema,
const std::wstring& table,
SQLSMALLINT scope,
SQLSMALLINT nullable) {
return SQLSpecialColumns_wrap(StatementHandle,
identifierType, catalog, schema, table,
scope, nullable);
});

// Add a version attribute
m.attr("__version__") = "1.0.0";
Expand Down
4 changes: 4 additions & 0 deletions mssql_python/pybind/ddbc_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ typedef SQLRETURN (SQL_API* SQLForeignKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT
SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT);
typedef SQLRETURN (SQL_API* SQLPrimaryKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*,
SQLSMALLINT, SQLWCHAR*, SQLSMALLINT);
typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT,
SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT,
SQLUSMALLINT, SQLUSMALLINT);

// Transaction APIs
typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT);
Expand Down Expand Up @@ -160,6 +163,7 @@ extern SQLGetTypeInfoFunc SQLGetTypeInfo_ptr;
extern SQLProceduresFunc SQLProcedures_ptr;
extern SQLForeignKeysFunc SQLForeignKeys_ptr;
extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr;
extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr;

// Transaction APIs
extern SQLEndTranFunc SQLEndTran_ptr;
Expand Down
Loading