From 6d5ac80aace7b3e703f2b2c16ab753f8bb9f94b2 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 21 Aug 2025 14:55:44 +0530 Subject: [PATCH] FEAT: Adding output converter --- mssql_python/connection.py | 70 ++++++++++ mssql_python/row.py | 57 +++++++- tests/test_003_connection.py | 244 ++++++++++++++++++++++++++++++++++- 3 files changed, 369 insertions(+), 2 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index fe400ec3..e8452d4d 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -208,6 +208,76 @@ def execute(self, sql, *args): cursor = self.cursor() cursor.execute(sql, *args) return cursor + + def add_output_converter(self, sqltype, func) -> None: + """ + Register an output converter function that will be called whenever a value + with the given SQL type is read from the database. + + Args: + sqltype (int): The integer SQL type value to convert, which can be one of the + defined standard constants (e.g. SQL_VARCHAR) or a database-specific + value (e.g. -151 for the SQL Server 2008 geometry data type). + func (callable): The converter function which will be called with a single parameter, + the value, and should return the converted value. If the value is NULL + then the parameter passed to the function will be None, otherwise it + will be a bytes object. + + Returns: + None + """ + if not hasattr(self, '_output_converters'): + self._output_converters = {} + self._output_converters[sqltype] = func + # Pass to the underlying connection if native implementation supports it + if hasattr(self._conn, 'add_output_converter'): + self._conn.add_output_converter(sqltype, func) + log('info', f"Added output converter for SQL type {sqltype}") + + def get_output_converter(self, sqltype): + """ + Get the output converter function for the specified SQL type. + + Args: + sqltype (int or type): The SQL type value or Python type to get the converter for + + Returns: + callable or None: The converter function or None if no converter is registered + """ + if not hasattr(self, '_output_converters'): + return None + return self._output_converters.get(sqltype) + + def remove_output_converter(self, sqltype): + """ + Remove the output converter function for the specified SQL type. + + Args: + sqltype (int or type): The SQL type value to remove the converter for + + Returns: + None + """ + if hasattr(self, '_output_converters') and sqltype in self._output_converters: + del self._output_converters[sqltype] + # Pass to the underlying connection if native implementation supports it + if hasattr(self._conn, 'remove_output_converter'): + self._conn.remove_output_converter(sqltype) + log('info', f"Removed output converter for SQL type {sqltype}") + + def clear_output_converters(self) -> None: + """ + Remove all output converter functions. + + Returns: + None + """ + if hasattr(self, '_output_converters'): + self._output_converters.clear() + # Pass to the underlying connection if native implementation supports it + if hasattr(self._conn, 'clear_output_converters'): + self._conn.clear_output_converters() + log('info', "Cleared all output converters") def commit(self) -> None: """ diff --git a/mssql_python/row.py b/mssql_python/row.py index 1f54e8c8..01c96fa7 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -20,7 +20,13 @@ def __init__(self, cursor, description, values, column_map=None): column_map: Optional pre-built column map (for optimization) """ self._cursor = cursor - self._values = values + self._description = description + + # Apply output converters if available + if hasattr(cursor.connection, '_output_converters') and cursor.connection._output_converters: + self._values = self._apply_output_converters(values) + else: + self._values = values # TODO: ADO task - Optimize memory usage by sharing column map across rows # Instead of storing the full cursor_description in each Row object: @@ -38,6 +44,55 @@ def __init__(self, cursor, description, values, column_map=None): self._column_map = column_map + def _apply_output_converters(self, values): + """ + Apply output converters to raw values. + + Args: + values: Raw values from the database + + Returns: + List of converted values + """ + if not self._description: + return values + + converted_values = list(values) + + for i, (value, desc) in enumerate(zip(values, self._description)): + if desc is None or value is None: + continue + + # Get SQL type from description + sql_type = desc[1] # type_code is at index 1 in description tuple + + # Try to get a converter for this type + converter = self._cursor.connection.get_output_converter(sql_type) + + # If no converter found for the SQL type but the value is a string or bytes, + # try the WVARCHAR converter as a fallback + if converter is None and isinstance(value, (str, bytes)): + from mssql_python.constants import ConstantsDDBC + converter = self._cursor.connection.get_output_converter(ConstantsDDBC.SQL_WVARCHAR.value) + + # If we found a converter, apply it + if converter: + try: + # If value is already a Python type (str, int, etc.), + # we need to convert it to bytes for our converters + if isinstance(value, str): + # Encode as UTF-16LE for string values (SQL_WVARCHAR format) + value_bytes = value.encode('utf-16-le') + converted_values[i] = converter(value_bytes) + else: + converted_values[i] = converter(value) + except Exception as e: + # If conversion fails, keep the original value + # You might want to log this error + pass + + return converted_values + def __getitem__(self, index): """Allow accessing by numeric index: row[0]""" return self._values[index] diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 8b3af574..a2ecf0ac 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -23,6 +23,9 @@ import time from mssql_python import Connection, connect, pooling import threading +import struct +from datetime import datetime, timedelta, timezone +from mssql_python.constants import ConstantsDDBC def drop_table_if_exists(cursor, table_name): """Drop the table if it exists""" @@ -31,6 +34,26 @@ def drop_table_if_exists(cursor, table_name): except Exception as e: pytest.fail(f"Failed to drop table {table_name}: {e}") +# Add these helper functions after other helper functions +def handle_datetimeoffset(dto_value): + """Converter function for SQL Server's DATETIMEOFFSET type""" + if dto_value is None: + return None + + # The format depends on the ODBC driver and how it returns binary data + # This matches SQL Server's format for DATETIMEOFFSET + tup = struct.unpack("<6hI2h", dto_value) # e.g., (2017, 3, 16, 10, 35, 18, 500000000, -6, 0) + return datetime( + tup[0], tup[1], tup[2], tup[3], tup[4], tup[5], tup[6] // 1000, + timezone(timedelta(hours=tup[7], minutes=tup[8])) + ) + +def custom_string_converter(value): + """A simple converter that adds a prefix to string values""" + if value is None: + return None + return "CONVERTED: " + value.decode('utf-16-le') # SQL_WVARCHAR is UTF-16LE encoded + def test_connection_string(conn_str): # Check if the connection string is not None assert conn_str is not None, "Connection string should not be None" @@ -645,4 +668,223 @@ def test_connection_execute_many_parameters(db_connection): # Verify all parameters were correctly passed for i, value in enumerate(params): - assert result[0][i] == value, f"Parameter at position {i} not correctly passed" \ No newline at end of file + assert result[0][i] == value, f"Parameter at position {i} not correctly passed" + +def test_add_output_converter(db_connection): + """Test adding an output converter""" + # Add a converter + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Verify it was added correctly + assert hasattr(db_connection, '_output_converters') + assert sql_wvarchar in db_connection._output_converters + assert db_connection._output_converters[sql_wvarchar] == custom_string_converter + + # Clean up + db_connection.clear_output_converters() + +def test_get_output_converter(db_connection): + """Test getting an output converter""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Initial state - no converter + assert db_connection.get_output_converter(sql_wvarchar) is None + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Get the converter + converter = db_connection.get_output_converter(sql_wvarchar) + assert converter == custom_string_converter + + # Get a non-existent converter + assert db_connection.get_output_converter(999) is None + + # Clean up + db_connection.clear_output_converters() + +def test_remove_output_converter(db_connection): + """Test removing an output converter""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + assert db_connection.get_output_converter(sql_wvarchar) is not None + + # Remove the converter + db_connection.remove_output_converter(sql_wvarchar) + assert db_connection.get_output_converter(sql_wvarchar) is None + + # Remove a non-existent converter (should not raise) + db_connection.remove_output_converter(999) + +def test_clear_output_converters(db_connection): + """Test clearing all output converters""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + sql_timestamp_offset = ConstantsDDBC.SQL_TIMESTAMPOFFSET.value + + # Add multiple converters + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + db_connection.add_output_converter(sql_timestamp_offset, handle_datetimeoffset) + + # Verify converters were added + assert db_connection.get_output_converter(sql_wvarchar) is not None + assert db_connection.get_output_converter(sql_timestamp_offset) is not None + + # Clear all converters + db_connection.clear_output_converters() + + # Verify all converters were removed + assert db_connection.get_output_converter(sql_wvarchar) is None + assert db_connection.get_output_converter(sql_timestamp_offset) is None + +def test_converter_integration(db_connection): + """ + Test that converters work during fetching. + + This test verifies that output converters work at the Python level + without requiring native driver support. + """ + cursor = db_connection.cursor() + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Test with string converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Test a simple string query + cursor.execute("SELECT N'test string' AS test_col") + row = cursor.fetchone() + + # Check if the type matches what we expect for SQL_WVARCHAR + # For Cursor.description, the second element is the type code + column_type = cursor.description[0][1] + + # If the cursor description has SQL_WVARCHAR as the type code, + # then our converter should be applied + if column_type == sql_wvarchar: + assert row[0].startswith("CONVERTED:"), "Output converter not applied" + else: + # If the type code is different, adjust the test or the converter + print(f"Column type is {column_type}, not {sql_wvarchar}") + # Add converter for the actual type used + db_connection.clear_output_converters() + db_connection.add_output_converter(column_type, custom_string_converter) + + # Re-execute the query + cursor.execute("SELECT N'test string' AS test_col") + row = cursor.fetchone() + assert row[0].startswith("CONVERTED:"), "Output converter not applied" + + # Clean up + db_connection.clear_output_converters() + +def test_output_converter_with_null_values(db_connection): + """Test that output converters handle NULL values correctly""" + cursor = db_connection.cursor() + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add converter for string type + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Execute a query with NULL values + cursor.execute("SELECT CAST(NULL AS NVARCHAR(50)) AS null_col") + value = cursor.fetchone()[0] + + # NULL values should remain None regardless of converter + assert value is None + + # Clean up + db_connection.clear_output_converters() + +def test_chaining_output_converters(db_connection): + """Test that output converters can be chained (replaced)""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Define a second converter + def another_string_converter(value): + if value is None: + return None + return "ANOTHER: " + value.decode('utf-16-le') + + # Add first converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Verify first converter is registered + assert db_connection.get_output_converter(sql_wvarchar) == custom_string_converter + + # Replace with second converter + db_connection.add_output_converter(sql_wvarchar, another_string_converter) + + # Verify second converter replaced the first + assert db_connection.get_output_converter(sql_wvarchar) == another_string_converter + + # Clean up + db_connection.clear_output_converters() + +def test_temporary_converter_replacement(db_connection): + """Test temporarily replacing a converter and then restoring it""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Save original converter + original_converter = db_connection.get_output_converter(sql_wvarchar) + + # Define a temporary converter + def temp_converter(value): + if value is None: + return None + return "TEMP: " + value.decode('utf-16-le') + + # Replace with temporary converter + db_connection.add_output_converter(sql_wvarchar, temp_converter) + + # Verify temporary converter is in use + assert db_connection.get_output_converter(sql_wvarchar) == temp_converter + + # Restore original converter + db_connection.add_output_converter(sql_wvarchar, original_converter) + + # Verify original converter is restored + assert db_connection.get_output_converter(sql_wvarchar) == original_converter + + # Clean up + db_connection.clear_output_converters() + +def test_multiple_output_converters(db_connection): + """Test that multiple output converters can work together""" + cursor = db_connection.cursor() + + # Execute a query to get the actual type codes used + cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col") + int_type = cursor.description[0][1] # Type code for integer column + str_type = cursor.description[1][1] # Type code for string column + + # Add converter for string type + db_connection.add_output_converter(str_type, custom_string_converter) + + # Add converter for integer type + def int_converter(value): + if value is None: + return None + # Convert from bytes to int and multiply by 2 + if isinstance(value, bytes): + return int.from_bytes(value, byteorder='little') * 2 + elif isinstance(value, int): + return value * 2 + return value + + db_connection.add_output_converter(int_type, int_converter) + + # Test query with both types + cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col") + row = cursor.fetchone() + + # Verify converters worked + assert row[0] == 84, f"Integer converter failed, got {row[0]} instead of 84" + assert isinstance(row[1], str) and "CONVERTED:" in row[1], f"String converter failed, got {row[1]}" + + # Clean up + db_connection.clear_output_converters() \ No newline at end of file