Skip to content

Commit

Permalink
Bind PostgreSQL values using strong type (#1217)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvorisek authored May 19, 2024
1 parent da01b0b commit 3bc5dce
Show file tree
Hide file tree
Showing 16 changed files with 253 additions and 61 deletions.
4 changes: 2 additions & 2 deletions src/Model.php
Original file line number Diff line number Diff line change
Expand Up @@ -1662,8 +1662,8 @@ public function insert(array $row)
$entity = $this->createEntity();

$hasRefs = false;
foreach ($row as $v) {
if (is_array($v)) {
foreach ($row as $k => $v) {
if (is_array($v) && $this->hasReference($k)) {
$hasRefs = true;

break;
Expand Down
21 changes: 15 additions & 6 deletions src/Persistence/Sql.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
class Sql extends Persistence
{
use Sql\BinaryTypeCompatibilityTypecastTrait;
use Sql\JsonTypeCompatibilityTypecastTrait;

public const HOOK_INIT_SELECT_QUERY = self::class . '@initSelectQuery';
public const HOOK_BEFORE_INSERT_QUERY = self::class . '@beforeInsertQuery';
Expand Down Expand Up @@ -648,8 +649,12 @@ public function typecastSaveField(Field $field, $value)
{
$value = parent::typecastSaveField($field, $value);

if ($value !== null && !$value instanceof Expression && $this->binaryTypeIsEncodeNeeded($field->type)) {
$value = $this->binaryTypeValueEncode($value);
if ($value !== null && !$value instanceof Expression) {
if ($this->binaryTypeIsEncodeNeeded($field->type)) {
$value = $this->binaryTypeValueEncode($value);
} elseif ($this->jsonTypeIsEncodeNeeded($field->type)) {
$value = $this->jsonTypeValueEncode($value);
}
}

return $value;
Expand All @@ -658,12 +663,16 @@ public function typecastSaveField(Field $field, $value)
#[\Override]
public function typecastLoadField(Field $field, $value)
{
$value = parent::typecastLoadField($field, $value);

if ($value !== null && $this->binaryTypeIsDecodeNeeded($field->type, $value)) {
$value = $this->binaryTypeValueDecode($value);
if ($value !== null) {
if ($this->binaryTypeIsDecodeNeeded($field->type, $value)) {
$value = $this->binaryTypeValueDecode($value);
} elseif ($this->jsonTypeIsDecodeNeeded($field->type, $value)) {
$value = $this->jsonTypeValueDecode($value);
}
}

$value = parent::typecastLoadField($field, $value);

return $value;
}

Expand Down
2 changes: 2 additions & 0 deletions src/Persistence/Sql/Expression.php
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,8 @@ protected function _execute(?object $connection, bool $fromExecuteStatement)
if (\Closure::bind(static fn () => $dummyPersistence->binaryTypeValueIsEncoded($val), null, Persistence\Sql::class)()) {
$val = \Closure::bind(static fn () => $dummyPersistence->binaryTypeValueDecode($val), null, Persistence\Sql::class)();
$type = ParameterType::BINARY;
} elseif (\Closure::bind(static fn () => $dummyPersistence->jsonTypeValueIsEncoded($val), null, Persistence\Sql::class)()) {
$val = \Closure::bind(static fn () => $dummyPersistence->jsonTypeValueDecode($val), null, Persistence\Sql::class)();
}
}
} elseif (is_resource($val)) {
Expand Down
74 changes: 74 additions & 0 deletions src/Persistence/Sql/JsonTypeCompatibilityTypecastTrait.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
<?php

declare(strict_types=1);

namespace Atk4\Data\Persistence\Sql;

use Atk4\Data\Exception;
use Doctrine\DBAL\Platforms\PostgreSQLPlatform;

trait JsonTypeCompatibilityTypecastTrait
{
private function jsonTypeValueGetPrefixConst(): string
{
return "atk4_json\ru5f8mzx4vsm8g2c9\r";
}

private function jsonTypeValueEncode(string $value): string
{
return $this->jsonTypeValueGetPrefixConst() . hash('crc32b', $value) . $value;
}

private function jsonTypeValueIsEncoded(string $value): bool
{
return str_starts_with($value, $this->jsonTypeValueGetPrefixConst());
}

private function jsonTypeValueDecode(string $value): string
{
if (!$this->jsonTypeValueIsEncoded($value)) {
throw new Exception('Unexpected unencoded json value');
}

$resCrc = substr($value, strlen($this->jsonTypeValueGetPrefixConst()), 8);
$res = substr($value, strlen($this->jsonTypeValueGetPrefixConst()) + 8);
if ($resCrc !== hash('crc32b', $res)) {
throw new Exception('Unexpected json value crc');
}

if ($this->jsonTypeValueIsEncoded($res)) {
throw new Exception('Unexpected double encoded json value');
}

return $res;
}

private function jsonTypeIsEncodeNeeded(string $type): bool
{
// json values for PostgreSQL database are stored natively, but we need
// to encode first to hold the json type info for PDO parameter type binding

$platform = $this->getDatabasePlatform();
if ($platform instanceof PostgreSQLPlatform) {
if ($type === 'json') {
return true;
}
}

return false;
}

/**
* @param scalar $value
*/
private function jsonTypeIsDecodeNeeded(string $type, $value): bool
{
if ($this->jsonTypeIsEncodeNeeded($type)) {
if ($this->jsonTypeValueIsEncoded($value)) {
return true;
}
}

return false;
}
}
8 changes: 4 additions & 4 deletions src/Persistence/Sql/Mssql/Query.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ static function ($sqlLeft, $sqlRight) use ($reuse, $makeSqlFx, $nullFromArgsOnly
#[\Override]
protected function _renderConditionLikeOperator(bool $negated, string $sqlLeft, string $sqlRight): string
{
return $this->_renderConditionBinaryReuseBool(
return ($negated ? 'not ' : '') . $this->_renderConditionBinaryReuseBool(
$sqlLeft,
$sqlRight,
function ($sqlLeft, $sqlRight) use ($negated) {
function ($sqlLeft, $sqlRight) {
$iifNtextFx = static function ($valueSql, $trueSql, $falseSql) {
$isNtextFx = static function ($sql, $negate) {
// "select top 0 ..." is always optimized into constant expression
Expand All @@ -90,7 +90,7 @@ function ($sqlLeft, $sqlRight) use ($negated) {
. ' or (' . $isBinaryFx($valueSql, true) . ' and ' . $falseSql . '))';
};

$makeSqlFx = function ($isNtext, $isBinary) use ($sqlLeft, $sqlRight, $negated) {
$makeSqlFx = function ($isNtext, $isBinary) use ($sqlLeft, $sqlRight) {
$quoteStringFx = fn (string $v) => $isNtext
? $this->escapeStringLiteral($v)
: '0x' . bin2hex($v);
Expand All @@ -114,7 +114,7 @@ function ($sqlLeft, $sqlRight) use ($negated) {

$sqlRightEscaped = $replaceFx($sqlRightEscaped, '[', '\[');

return $sqlLeft . ($negated ? ' not' : '') . ' like ' . $sqlRightEscaped
return $sqlLeft . ' like ' . $sqlRightEscaped
. ($isBinary ? ' collate Latin1_General_BIN' : '')
. ' escape ' . $quoteStringFx('\\');
};
Expand Down
14 changes: 14 additions & 0 deletions src/Persistence/Sql/Postgresql/Connection.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,23 @@
namespace Atk4\Data\Persistence\Sql\Postgresql;

use Atk4\Data\Persistence\Sql\Connection as BaseConnection;
use Doctrine\DBAL\Configuration;

class Connection extends BaseConnection
{
protected string $expressionClass = Expression::class;
protected string $queryClass = Query::class;

#[\Override]
protected static function createDbalConfiguration(): Configuration
{
$configuration = parent::createDbalConfiguration();

$configuration->setMiddlewares([
...$configuration->getMiddlewares(),
new InitializeSessionMiddleware(),
]);

return $configuration;
}
}
11 changes: 11 additions & 0 deletions src/Persistence/Sql/Postgresql/ExpressionTrait.php
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ static function ($matches) use ($params) {
$sql = 'cast(' . $sql . ' as BIGINT)';
} elseif (is_float($value)) {
$sql = 'cast(' . $sql . ' as DOUBLE PRECISION)';
} elseif (is_string($value)) {
$dummyPersistence = (new \ReflectionClass(Persistence\Sql::class))->newInstanceWithoutConstructor();
if (\Closure::bind(static fn () => $dummyPersistence->binaryTypeValueIsEncoded($value), null, Persistence\Sql::class)()) {
$sql = 'cast(' . $sql . ' as bytea)';
} elseif (\Closure::bind(static fn () => $dummyPersistence->jsonTypeValueIsEncoded($value), null, Persistence\Sql::class)()) {
$sql = 'cast(' . $sql . ' as json)';
} else {
$sql = 'cast(' . $sql . ' as citext)';
}
} else {
$sql = 'cast(' . $sql . ' as unknown)';
}

return $sql;
Expand Down
40 changes: 40 additions & 0 deletions src/Persistence/Sql/Postgresql/InitializeSessionMiddleware.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
<?php

declare(strict_types=1);

namespace Atk4\Data\Persistence\Sql\Postgresql;

use Doctrine\DBAL\Driver;
use Doctrine\DBAL\Driver\Connection;
use Doctrine\DBAL\Driver\Middleware;
use Doctrine\DBAL\Driver\Middleware\AbstractDriverMiddleware;

/**
* Setup "citext" server extension to be available as we use "citext" type for all bound string variables.
*
* Based on https://github.com/doctrine/dbal/blob/3.6.5/src/Driver/OCI8/Middleware/InitializeSession.php
*/
class InitializeSessionMiddleware implements Middleware
{
#[\Override]
public function wrap(Driver $driver): Driver
{
return new class($driver) extends AbstractDriverMiddleware {
#[\Override]
public function connect(
#[\SensitiveParameter]
array $params
): Connection {
$connection = parent::connect($params);

if ($connection->query('SELECT to_regtype(\'citext\')')->fetchOne() === null) {
// "CREATE EXTENSION IF NOT EXISTS ..." cannot be used as it requires
// CREATE privilege even if the extension is already installed
$connection->query('CREATE EXTENSION citext');
}

return $connection;
}
};
}
}
4 changes: 3 additions & 1 deletion src/Persistence/Sql/Postgresql/PlatformTrait.php
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ private function getCreateCaseInsensitiveDomainsSql(): array
$sqls[] = 'DO' . "\n"
. '$$' . "\n"
. 'BEGIN' . "\n"
. ' CREATE EXTENSION IF NOT EXISTS citext;' . "\n"
. ' IF to_regtype(\'citext\') IS NULL THEN' . "\n"
. ' CREATE EXTENSION citext;' . "\n"
. ' END IF;' . "\n"
. implode("\n", array_map(static function (string $domain): string {
return ' IF to_regtype(\'' . $domain . '\') IS NULL THEN' . "\n"
. ' CREATE DOMAIN ' . $domain . ' AS citext;' . "\n"
Expand Down
12 changes: 6 additions & 6 deletions src/Persistence/Sql/Postgresql/Query.php
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ function ($sqlLeft, $sqlRight) use ($makeSqlFx) {

return $iifByteaSqlFx(
$sqlLeft,
$makeSqlFx($escapeNonUtf8Fx($sqlLeft), $escapeNonUtf8Fx($sqlRight, true)),
$makeSqlFx('cast(' . $sqlLeft . ' as citext)', $sqlRight)
$makeSqlFx($escapeNonUtf8Fx($sqlLeft), $escapeNonUtf8Fx($sqlRight)),
$makeSqlFx('cast(' . $sqlLeft . ' as citext)', 'cast(' . $sqlRight . ' as citext)')
);
}
);
Expand All @@ -80,13 +80,13 @@ function ($sqlLeft, $sqlRight) use ($makeSqlFx) {
#[\Override]
protected function _renderConditionLikeOperator(bool $negated, string $sqlLeft, string $sqlRight): string
{
return $this->_renderConditionConditionalCastToText($sqlLeft, $sqlRight, function ($sqlLeft, $sqlRight) use ($negated) {
return ($negated ? 'not ' : '') . $this->_renderConditionConditionalCastToText($sqlLeft, $sqlRight, function ($sqlLeft, $sqlRight) {
$sqlRightEscaped = 'regexp_replace(' . $sqlRight . ', '
. $this->escapeStringLiteral('(\\\[\\\_%])|(\\\)') . ', '
. $this->escapeStringLiteral('\1\2\2') . ', '
. $this->escapeStringLiteral('g') . ')';

return $sqlLeft . ($negated ? ' not' : '') . ' like ' . $sqlRightEscaped
return $sqlLeft . ' like ' . $sqlRightEscaped
. ' escape ' . $this->escapeStringLiteral('\\');
});
}
Expand All @@ -95,8 +95,8 @@ protected function _renderConditionLikeOperator(bool $negated, string $sqlLeft,
#[\Override]
protected function _renderConditionRegexpOperator(bool $negated, string $sqlLeft, string $sqlRight, bool $binary = false): string
{
return $this->_renderConditionConditionalCastToText($sqlLeft, $sqlRight, static function ($sqlLeft, $sqlRight) use ($negated) {
return $sqlLeft . ' ' . ($negated ? '!' : '') . '~ ' . $sqlRight;
return ($negated ? 'not ' : '') . $this->_renderConditionConditionalCastToText($sqlLeft, $sqlRight, static function ($sqlLeft, $sqlRight) {
return $sqlLeft . ' ~ ' . $sqlRight;
});
}

Expand Down
10 changes: 8 additions & 2 deletions src/Schema/TestCase.php
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ protected function logQuery(string $sql, array $params, array $types): void
$i = 0;
$quotedTokenRegex = $this->getConnection()->expr()::QUOTED_TOKEN_REGEX;
$sql = preg_replace_callback(
'~' . $quotedTokenRegex . '\K|(\?)|cast\((\?|:\w+) as (BOOLEAN|INTEGER|BIGINT|DOUBLE PRECISION|BINARY_DOUBLE)\)|\((\?|:\w+) \+ 0\.00\)~',
'~' . $quotedTokenRegex . '\K|(\?)|cast\((\?|:\w+) as (BOOLEAN|INTEGER|BIGINT|DOUBLE PRECISION|BINARY_DOUBLE|citext|bytea|unknown)\)|\((\?|:\w+) \+ 0\.00\)~',
static function ($matches) use (&$types, &$params, &$i) {
if ($matches[0] === '') {
return '';
Expand All @@ -113,7 +113,9 @@ static function ($matches) use (&$types, &$params, &$i) {
return $matches[0];
}

$k = isset($matches[4]) ? ($matches[4] === '?' ? ++$i : $matches[4]) : ($matches[2] === '?' ? ++$i : $matches[2]);
$k = isset($matches[4])
? ($matches[4] === '?' ? ++$i : $matches[4])
: ($matches[2] === '?' ? ++$i : $matches[2]);

if ($matches[3] === 'BOOLEAN' && ($types[$k] === ParameterType::BOOLEAN || $types[$k] === ParameterType::INTEGER)
&& (is_bool($params[$k]) || $params[$k] === '0' || $params[$k] === '1')
Expand All @@ -131,6 +133,10 @@ static function ($matches) use (&$types, &$params, &$i) {
$params[$k] = (float) $params[$k];

return $matches[4] ?? $matches[2];
} elseif (($matches[3] === 'citext' || $matches[3] === 'bytea') && is_string($params[$k])) {
return $matches[2];
} elseif ($matches[3] === 'unknown' && $params[$k] === null) {
return $matches[2];
}

return $matches[0];
Expand Down
2 changes: 1 addition & 1 deletion tests/JoinSqlTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ public function testJoinActualFieldNamesAndPrefix(): void
$j->addField('phone', ['actual' => 'contact_phone']);
// reverse join
$j2 = $user->join('salaries.' . $userForeignIdFieldName, ['prefix' => 'j2_']);
$j2->addField('salary', ['actual' => 'amount']);
$j2->addField('salary', ['actual' => 'amount', 'type' => 'integer']);

// update
$user2 = $user->load(1);
Expand Down
Loading

0 comments on commit 3bc5dce

Please sign in to comment.