From 03b706302f3c930ab6af54fe513effcd44f41960 Mon Sep 17 00:00:00 2001
From: vladvildanov
Date: Fri, 20 Dec 2024 19:38:45 +0200
Subject: [PATCH 01/17] Updated package name, added docs

---
 dev_requirements.txt       |  2 +-
 docs/advanced_features.rst | 33 +++++++++++++++++++++++++++++++++
 setup.py                   |  2 +-
 3 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/dev_requirements.txt b/dev_requirements.txt
index 945afc35dc..41cd08e166 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -16,4 +16,4 @@ uvloop
 vulture>=2.3.0
 wheel>=0.30.0
 numpy>=1.24.0
-redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main
+redis-entraid @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main

diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst
index de645bd764..e6bed9b94d 100644
--- a/docs/advanced_features.rst
+++ b/docs/advanced_features.rst
@@ -466,3 +466,36 @@ command is received.
 >>> with r.monitor() as m:
 >>>     for command in m.listen():
 >>>         print(command)
+
+
+Token-based authentication
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Since redis-py version 5.3.0, a new StreamableCredentialProvider interface has been introduced.
+This interface describes a CredentialProvider with the ability to stream events that will be handled by a listener.
+
+To keep redis-py's runtime dependencies minimal, we decided to ship StreamableCredentialProvider
+implementations as separate packages. So if you're interested in trying this feature, please add them as a separate
+dependency to your project.
+
+`EntraIdCredentialProvider` is the first implementation that allows you to integrate redis-py with the Azure Cache for Redis
+service. It allows you to obtain tokens from Microsoft EntraID and authenticate/re-authenticate your connections
+with them in the background.
+
+To get `EntraIdCredentialProvider` you need to install the following package:
+
+`pip install redispy-entraid-credentials`
+
+To set up a credential provider, first you have to create and configure an IdentityProvider and provide
+a TokenAuthConfig object.
+`Here's a quick guide how to do this
+`_
+
+Now all you have to do is pass an instance of `EntraIdCredentialProvider` via the constructor,
+available for sync and async clients:
+
+.. code:: python
+
+    >>> cred_provider = EntraIdCredentialProvider(auth_config)
+    >>> r = Redis(credential_provider=cred_provider)
+    >>> r.ping()
diff --git a/setup.py b/setup.py
index ee3a7c2023..1a43ad7bdb 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
     long_description_content_type="text/markdown",
     keywords=["Redis", "key-value store", "database"],
     license="MIT",
-    version="5.1.1",
+    version="5.3.0b1",
     packages=find_packages(
         include=[
             "redis",

From eaf7b64a207171bb2bee8f86039f5487a87138c6 Mon Sep 17 00:00:00 2001
From: vladvildanov
Date: Fri, 20 Dec 2024 20:10:24 +0200
Subject: [PATCH 02/17] Fixed TBA documentation

---
 docs/advanced_features.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst
index e6bed9b94d..10b7b4681b 100644
--- a/docs/advanced_features.rst
+++ b/docs/advanced_features.rst
@@ -484,7 +484,7 @@ with them in the background.
 
 To get `EntraIdCredentialProvider` you need to install the following package:
 
-`pip install redispy-entraid-credentials`
+`pip install redis-entraid`
 
 To set up a credential provider, first you have to create and configure an IdentityProvider and provide
 a TokenAuthConfig object.
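The docs above only demonstrate the synchronous client. Below is a minimal sketch of the same flow with the asyncio client; it assumes the `create_from_service_principal` factory that appears later in this series is importable from `redis_entraid.cred_provider`, and the credential values are placeholders, so treat it as illustrative rather than canonical:

import asyncio

from redis.asyncio import Redis
from redis_entraid.cred_provider import create_from_service_principal

# Placeholder credentials; substitute your real service-principal values.
CLIENT_ID = "<client-id>"
CLIENT_SECRET = "<client-secret>"
TENANT_ID = "<tenant-id>"


async def main():
    # The provider obtains tokens from Microsoft EntraID and re-authenticates
    # connections in the background, as described in the docs above.
    cred_provider = create_from_service_principal(CLIENT_ID, CLIENT_SECRET, TENANT_ID)
    r = Redis(credential_provider=cred_provider)
    print(await r.ping())
    await r.aclose()


asyncio.run(main())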
From 3b9fc63f3abfe03a071982f845510733a90179e6 Mon Sep 17 00:00:00 2001
From: vladvildanov
Date: Fri, 20 Dec 2024 20:14:26 +0200
Subject: [PATCH 03/17] Updated package version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 1a43ad7bdb..ec20600aaa 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
     long_description_content_type="text/markdown",
     keywords=["Redis", "key-value store", "database"],
     license="MIT",
-    version="5.3.0b1",
+    version="5.3.0b2",
     packages=find_packages(
         include=[
             "redis",

From 0b7c4f8e0f44bc866fcb754ffd172bea1204f6ed Mon Sep 17 00:00:00 2001
From: vladvildanov
Date: Fri, 20 Dec 2024 20:19:17 +0200
Subject: [PATCH 04/17] Updated package name to b3

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index ec20600aaa..ac6905be54 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
     long_description_content_type="text/markdown",
     keywords=["Redis", "key-value store", "database"],
     license="MIT",
-    version="5.3.0b2",
+    version="5.3.0b3",
     packages=find_packages(
         include=[
             "redis",

From 80561988786bb0a296ba49f8678fa1127028b412 Mon Sep 17 00:00:00 2001
From: vladvildanov
Date: Mon, 23 Dec 2024 12:13:15 +0200
Subject: [PATCH 05/17] Updated package version

---
 .github/workflows/pypi-publish.yaml | 1 +
 setup.py                            | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/pypi-publish.yaml b/.github/workflows/pypi-publish.yaml
index 108dfa6da5..e4815aa1b5 100644
--- a/.github/workflows/pypi-publish.yaml
+++ b/.github/workflows/pypi-publish.yaml
@@ -3,6 +3,7 @@ name: Publish tag to Pypi
 on:
   release:
     types: [published]
+  workflow_dispatch:
 
 permissions:
   contents: read  # to fetch code (actions/checkout)

diff --git a/setup.py b/setup.py
index c29bfb1a97..167cd5ee07 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
     long_description_content_type="text/markdown",
     keywords=["Redis", "key-value store", "database"],
     license="MIT",
-    version="5.3.0b3",
+    version="5.3.0b4",
     packages=find_packages(
         include=[
             "redis",

From 28964c1ec4fc481141f6025248845c5e22588a41 Mon Sep 17 00:00:00 2001
From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com>
Date: Tue, 11 Feb 2025 15:58:39 +0200
Subject: [PATCH 06/17] Backport from master (5.3.0b5) (#3506)

* Fixed flaky TokenManager test (#3468)

* Fixed flaky TokenManager test

* Fixed additional flaky test

* Removed token count assertion

* Skipped test on version 3.9

* Fix incorrect attribute reuse (#3456)

add CacheEntry

Co-authored-by: zhousheng06
Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com>

* Expand type for EncodedT (#3472)

As of PEP 688, type checkers will no longer implicitly consider bytearray to be compatible with bytes

* Moved self._lock initialisation to Pool constructor (#3473)

* Moved self._lock initialisation to Pool constructor

* Added test case

* Codestyle fixes

* Added correct annotations

* DOC-4423: add TCEs for various command pages (#3476)

Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com>

* DOC-4345 added testable JSON search examples for home page (#3407)

* DOC-4345 added testable JSON search examples for home page

* DOC-4345 avoid possible non-deterministic results in tests

* DOC-4345 close connection at end of example

* DOC-4345 remove unnecessary blank lines

* Adding unit test fixes to improve compatibility with MacOS. (#3486)

* Adding unit test fixes to improve compatibility with MacOS.
* Applying review comments * Unifying the exception msg validation pattern for both test_connection.py files --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * Add return type to `close` functions (#3496) * Add types to ConnectionPool.from_url (#3495) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * Add types to execute method of pipelines (#3494) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * DOC-4796 fixed capped lists example (#3493) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * typing for client __init__ (#3357) * typing for client __init__ * typing with string literals * retry_on_error more specific typing * retry typing * fix lint --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * test: Updated CredentialProvider test infrastructure (#3502) * test: Updated CredentialProvider test infrastructure * Added linter exclusion * Updated dev dependency * Codestyle fixes * Updated async test infra * Added missing constant * Updated package version * Updated testing versions and docs * Updated server versions * Fixed test --------- Co-authored-by: zs-neo <48560952+zs-neo@users.noreply.github.com> Co-authored-by: zhousheng06 Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Co-authored-by: David Dougherty Co-authored-by: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Co-authored-by: petyaslavova Co-authored-by: Patrick Arminio Co-authored-by: Artur Mostowski --- .github/actions/run-tests/action.yml | 6 +- .github/workflows/integration.yaml | 4 +- dev_requirements.txt | 2 +- docker-compose.yml | 3 +- docs/advanced_features.rst | 15 ++- doctests/cmds_cnxmgmt.py | 36 +++++++ doctests/cmds_hash.py | 24 +++++ doctests/cmds_list.py | 123 +++++++++++++++++++++++ doctests/cmds_servermgmt.py | 30 ++++++ doctests/cmds_set.py | 35 +++++++ doctests/dt_list.py | 6 +- doctests/home_json.py | 137 ++++++++++++++++++++++++++ redis/asyncio/client.py | 2 +- redis/client.py | 104 ++++++++++--------- redis/cluster.py | 6 +- redis/connection.py | 15 ++- redis/typing.py | 2 +- setup.py | 2 +- tests/conftest.py | 99 ++++++++++++------- tests/test_asyncio/conftest.py | 97 +++++++++++------- tests/test_asyncio/test_connection.py | 30 +++--- tests/test_auth/test_token_manager.py | 36 +++---- tests/test_commands.py | 2 +- tests/test_connection.py | 25 ++--- tests/test_connection_pool.py | 28 +++++- tests/test_multiprocessing.py | 4 + 26 files changed, 670 insertions(+), 203 deletions(-) create mode 100644 doctests/cmds_cnxmgmt.py create mode 100644 doctests/cmds_list.py create mode 100644 doctests/cmds_servermgmt.py create mode 100644 doctests/cmds_set.py create mode 100644 doctests/home_json.py diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index 5ca6bf5a09..e5dcef03ff 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -56,9 +56,9 @@ runs: # Mapping of redis version to stack version declare -A redis_stack_version_mapping=( - ["7.4.1"]="7.4.0-v1" - ["7.2.6"]="7.2.0-v13" - ["6.2.16"]="6.2.6-v17" + ["7.4.2"]="7.4.0-v3" + ["7.2.7"]="7.2.0-v15" + ["6.2.17"]="6.2.6-v19" ) if [[ -v redis_stack_version_mapping[$REDIS_VERSION] ]]; then diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 7c74de5290..7e92cfb92d 100644 --- 
a/.github/workflows/integration.yaml
+++ b/.github/workflows/integration.yaml
@@ -29,7 +29,7 @@ env:
   COVERAGE_CORE: sysmon
   REDIS_IMAGE: redis:latest
   REDIS_STACK_IMAGE: redis/redis-stack-server:latest
-  CURRENT_REDIS_VERSION: '7.4.1'
+  CURRENT_REDIS_VERSION: '7.4.2'
 
 jobs:
   dependency-audit:
@@ -74,7 +74,7 @@ jobs:
       max-parallel: 15
       fail-fast: false
       matrix:
-        redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}', '7.2.6', '6.2.16']
+        redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17']
         python-version: ['3.8', '3.12']
         parser-backend: ['plain']
         event-loop: ['asyncio']

diff --git a/dev_requirements.txt b/dev_requirements.txt
index be74470ec2..728536d6fb 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -16,4 +16,4 @@ uvloop
 vulture>=2.3.0
 wheel>=0.30.0
 numpy>=1.24.0
-redis-entraid==0.1.0b1
+redis-entraid==0.3.0b1

diff --git a/docker-compose.yml b/docker-compose.yml
index 7804f09c8a..60657d5653 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -103,7 +103,7 @@ services:
       - all
 
   redis-stack:
-    image: ${REDIS_STACK_IMAGE:-redis/redis-stack-server:edge}
+    image: ${REDIS_STACK_IMAGE:-redis/redis-stack-server:latest}
     container_name: redis-stack
     ports:
       - 6479:6379
@@ -112,6 +112,7 @@
     profiles:
       - standalone
       - all-stack
+      - all
 
   redis-stack-graph:
     image: redis/redis-stack-server:6.2.6-v15

diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst
index 10b7b4681b..cebf241e6c 100644
--- a/docs/advanced_features.rst
+++ b/docs/advanced_features.rst
@@ -471,31 +471,30 @@ command is received.
 Token-based authentication
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Since redis-py version 5.3.0, a new StreamableCredentialProvider interface has been introduced.
-This interface describes a CredentialProvider with the ability to stream events that will be handled by a listener.
+Since redis-py version 5.3.0, a new `StreamableCredentialProvider` interface has been introduced.
+This interface describes a `CredentialProvider` with the ability to stream events that will be handled by a listener.
 
-To keep redis-py's runtime dependencies minimal, we decided to ship StreamableCredentialProvider
+To keep redis-py's runtime dependencies minimal, we decided to ship `StreamableCredentialProvider`
 implementations as separate packages. So if you're interested in trying this feature, please add them as a separate
 dependency to your project.
 
 `EntraIdCredentialProvider` is the first implementation that allows you to integrate redis-py with the Azure Cache for Redis
-service. It allows you to obtain tokens from Microsoft EntraID and authenticate/re-authenticate your connections
+service. It allows you to obtain tokens from `Microsoft EntraID` and authenticate/re-authenticate your connections
 with them in the background.
 
 To get `EntraIdCredentialProvider` you need to install the following package:
 
 `pip install redis-entraid`
 
-To set up a credential provider, first you have to create and configure an IdentityProvider and provide
-a TokenAuthConfig object.
+To set up a credential provider, please use one of the factory methods bundled with the package.
 `Here's a quick guide how to do this
-`_
+`_
 
 Now all you have to do is pass an instance of `EntraIdCredentialProvider` via the constructor,
 available for sync and async clients:
 
 ..
code:: python - >>> cred_provider = EntraIdCredentialProvider(auth_config) + >>> cred_provider = create_from_service_principal(CLIENT_ID, CLIENT_SECRET, TENANT_ID) >>> r = Redis(credential_provider=cred_provider) >>> r.ping() diff --git a/doctests/cmds_cnxmgmt.py b/doctests/cmds_cnxmgmt.py new file mode 100644 index 0000000000..c691f723cf --- /dev/null +++ b/doctests/cmds_cnxmgmt.py @@ -0,0 +1,36 @@ +# EXAMPLE: cmds_cnxmgmt +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START auth1 +# REMOVE_START +r.config_set("requirepass", "temp_pass") +# REMOVE_END +res1 = r.auth(password="temp_pass") +print(res1) # >>> True + +res2 = r.auth(password="temp_pass", username="default") +print(res2) # >>> True + +# REMOVE_START +assert res1 == True +assert res2 == True +r.config_set("requirepass", "") +# REMOVE_END +# STEP_END + +# STEP_START auth2 +# REMOVE_START +r.acl_setuser("test-user", enabled=True, passwords=["+strong_password"], commands=["+acl"]) +# REMOVE_END +res = r.auth(username="test-user", password="strong_password") +print(res) # >>> True + +# REMOVE_START +assert res == True +r.acl_deluser("test-user") +# REMOVE_END +# STEP_END diff --git a/doctests/cmds_hash.py b/doctests/cmds_hash.py index 0bc1cb8038..65bbd52d60 100644 --- a/doctests/cmds_hash.py +++ b/doctests/cmds_hash.py @@ -61,3 +61,27 @@ r.delete("myhash") # REMOVE_END # STEP_END + +# STEP_START hgetall +res10 = r.hset("myhash", mapping={"field1": "Hello", "field2": "World"}) + +res11 = r.hgetall("myhash") +print(res11) # >>> { "field1": "Hello", "field2": "World" } + +# REMOVE_START +assert res11 == { "field1": "Hello", "field2": "World" } +r.delete("myhash") +# REMOVE_END +# STEP_END + +# STEP_START hvals +res10 = r.hset("myhash", mapping={"field1": "Hello", "field2": "World"}) + +res11 = r.hvals("myhash") +print(res11) # >>> [ "Hello", "World" ] + +# REMOVE_START +assert res11 == [ "Hello", "World" ] +r.delete("myhash") +# REMOVE_END +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_list.py b/doctests/cmds_list.py new file mode 100644 index 0000000000..cce2d540a8 --- /dev/null +++ b/doctests/cmds_list.py @@ -0,0 +1,123 @@ +# EXAMPLE: cmds_list +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START lpush +res1 = r.lpush("mylist", "world") +print(res1) # >>> 1 + +res2 = r.lpush("mylist", "hello") +print(res2) # >>> 2 + +res3 = r.lrange("mylist", 0, -1) +print(res3) # >>> [ "hello", "world" ] + +# REMOVE_START +assert res3 == [ "hello", "world" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START lrange +res4 = r.rpush("mylist", "one"); +print(res4) # >>> 1 + +res5 = r.rpush("mylist", "two") +print(res5) # >>> 2 + +res6 = r.rpush("mylist", "three") +print(res6) # >>> 3 + +res7 = r.lrange('mylist', 0, 0) +print(res7) # >>> [ 'one' ] + +res8 = r.lrange('mylist', -3, 2) +print(res8) # >>> [ 'one', 'two', 'three' ] + +res9 = r.lrange('mylist', -100, 100) +print(res9) # >>> [ 'one', 'two', 'three' ] + +res10 = r.lrange('mylist', 5, 10) +print(res10) # >>> [] + +# REMOVE_START +assert res7 == [ 'one' ] +assert res8 == [ 'one', 'two', 'three' ] +assert res9 == [ 'one', 'two', 'three' ] +assert res10 == [] +r.delete('mylist') +# REMOVE_END +# STEP_END + +# STEP_START llen +res11 = r.lpush("mylist", "World") +print(res11) # >>> 1 + +res12 = r.lpush("mylist", "Hello") +print(res12) # >>> 2 + +res13 = r.llen("mylist") +print(res13) # >>> 2 + +# REMOVE_START +assert res13 == 2 +r.delete("mylist") +# REMOVE_END +# STEP_END + +# 
STEP_START rpush +res14 = r.rpush("mylist", "hello") +print(res14) # >>> 1 + +res15 = r.rpush("mylist", "world") +print(res15) # >>> 2 + +res16 = r.lrange("mylist", 0, -1) +print(res16) # >>> [ "hello", "world" ] + +# REMOVE_START +assert res16 == [ "hello", "world" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START lpop +res17 = r.rpush("mylist", *["one", "two", "three", "four", "five"]) +print(res17) # >>> 5 + +res18 = r.lpop("mylist") +print(res18) # >>> "one" + +res19 = r.lpop("mylist", 2) +print(res19) # >>> ['two', 'three'] + +res17 = r.lrange("mylist", 0, -1) +print(res17) # >>> [ "four", "five" ] + +# REMOVE_START +assert res17 == [ "four", "five" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START rpop +res18 = r.rpush("mylist", *["one", "two", "three", "four", "five"]) +print(res18) # >>> 5 + +res19 = r.rpop("mylist") +print(res19) # >>> "five" + +res20 = r.rpop("mylist", 2) +print(res20) # >>> ['four', 'three'] + +res21 = r.lrange("mylist", 0, -1) +print(res21) # >>> [ "one", "two" ] + +# REMOVE_START +assert res21 == [ "one", "two" ] +r.delete("mylist") +# REMOVE_END +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_servermgmt.py b/doctests/cmds_servermgmt.py new file mode 100644 index 0000000000..6ad2b6acb2 --- /dev/null +++ b/doctests/cmds_servermgmt.py @@ -0,0 +1,30 @@ +# EXAMPLE: cmds_servermgmt +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START flushall +# REMOVE_START +r.set("foo", "1") +r.set("bar", "2") +r.set("baz", "3") +# REMOVE_END +res1 = r.flushall(asynchronous=False) +print(res1) # >>> True + +res2 = r.keys() +print(res2) # >>> [] + +# REMOVE_START +assert res1 == True +assert res2 == [] +# REMOVE_END +# STEP_END + +# STEP_START info +res3 = r.info() +print(res3) +# >>> {'redis_version': '7.4.0', 'redis_git_sha1': 'c9d29f6a',...} +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_set.py b/doctests/cmds_set.py new file mode 100644 index 0000000000..ece74e8cf0 --- /dev/null +++ b/doctests/cmds_set.py @@ -0,0 +1,35 @@ +# EXAMPLE: cmds_set +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START sadd +res1 = r.sadd("myset", "Hello", "World") +print(res1) # >>> 2 + +res2 = r.sadd("myset", "World") +print(res2) # >>> 0 + +res3 = r.smembers("myset") +print(res3) # >>> {'Hello', 'World'} + +# REMOVE_START +assert res3 == {'Hello', 'World'} +r.delete('myset') +# REMOVE_END +# STEP_END + +# STEP_START smembers +res4 = r.sadd("myset", "Hello", "World") +print(res4) # >>> 2 + +res5 = r.smembers("myset") +print(res5) # >>> {'Hello', 'World'} + +# REMOVE_START +assert res5 == {'Hello', 'World'} +r.delete('myset') +# REMOVE_END +# STEP_END \ No newline at end of file diff --git a/doctests/dt_list.py b/doctests/dt_list.py index be8a4b8562..111da8eb08 100644 --- a/doctests/dt_list.py +++ b/doctests/dt_list.py @@ -165,20 +165,20 @@ # REMOVE_END # STEP_START ltrim -res27 = r.lpush("bikes:repairs", "bike:1", "bike:2", "bike:3", "bike:4", "bike:5") +res27 = r.rpush("bikes:repairs", "bike:1", "bike:2", "bike:3", "bike:4", "bike:5") print(res27) # >>> 5 res28 = r.ltrim("bikes:repairs", 0, 2) print(res28) # >>> True res29 = r.lrange("bikes:repairs", 0, -1) -print(res29) # >>> ['bike:5', 'bike:4', 'bike:3'] +print(res29) # >>> ['bike:1', 'bike:2', 'bike:3'] # STEP_END # REMOVE_START assert res27 == 5 assert res28 is True -assert res29 == ["bike:5", "bike:4", "bike:3"] +assert res29 == ["bike:1", "bike:2", "bike:3"] r.delete("bikes:repairs") # 
REMOVE_END diff --git a/doctests/home_json.py b/doctests/home_json.py new file mode 100644 index 0000000000..922c83d2fe --- /dev/null +++ b/doctests/home_json.py @@ -0,0 +1,137 @@ +# EXAMPLE: py_home_json +""" +JSON examples from redis-py "home" page" + https://redis.io/docs/latest/develop/connect/clients/python/redis-py/#example-indexing-and-querying-json-documents +""" + +# STEP_START import +import redis +from redis.commands.json.path import Path +import redis.commands.search.aggregation as aggregations +import redis.commands.search.reducers as reducers +from redis.commands.search.field import TextField, NumericField, TagField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import Query +import redis.exceptions +# STEP_END + +# STEP_START connect +r = redis.Redis(decode_responses=True) +# STEP_END + +# REMOVE_START +try: + r.ft("idx:users").dropindex(True) +except redis.exceptions.ResponseError: + pass + +r.delete("user:1", "user:2", "user:3") +# REMOVE_END +# STEP_START create_data +user1 = { + "name": "Paul John", + "email": "paul.john@example.com", + "age": 42, + "city": "London" +} + +user2 = { + "name": "Eden Zamir", + "email": "eden.zamir@example.com", + "age": 29, + "city": "Tel Aviv" +} + +user3 = { + "name": "Paul Zamir", + "email": "paul.zamir@example.com", + "age": 35, + "city": "Tel Aviv" +} +# STEP_END + +# STEP_START make_index +schema = ( + TextField("$.name", as_name="name"), + TagField("$.city", as_name="city"), + NumericField("$.age", as_name="age") +) + +indexCreated = r.ft("idx:users").create_index( + schema, + definition=IndexDefinition( + prefix=["user:"], index_type=IndexType.JSON + ) +) +# STEP_END +# Tests for 'make_index' step. +# REMOVE_START +assert indexCreated +# REMOVE_END + +# STEP_START add_data +user1Set = r.json().set("user:1", Path.root_path(), user1) +user2Set = r.json().set("user:2", Path.root_path(), user2) +user3Set = r.json().set("user:3", Path.root_path(), user3) +# STEP_END +# Tests for 'add_data' step. +# REMOVE_START +assert user1Set +assert user2Set +assert user3Set +# REMOVE_END + +# STEP_START query1 +findPaulResult = r.ft("idx:users").search( + Query("Paul @age:[30 40]") +) + +print(findPaulResult) +# >>> Result{1 total, docs: [Document {'id': 'user:3', ... +# STEP_END +# Tests for 'query1' step. +# REMOVE_START +assert str(findPaulResult) == ( + "Result{1 total, docs: [Document {'id': 'user:3', 'payload': None, " + + "'json': '{\"name\":\"Paul Zamir\",\"email\":" + + "\"paul.zamir@example.com\",\"age\":35,\"city\":\"Tel Aviv\"}'}]}" +) +# REMOVE_END + +# STEP_START query2 +citiesResult = r.ft("idx:users").search( + Query("Paul").return_field("$.city", as_field="city") +).docs + +print(citiesResult) +# >>> [Document {'id': 'user:1', 'payload': None, ... +# STEP_END +# Tests for 'query2' step. +# REMOVE_START +citiesResult.sort(key=lambda doc: doc['id']) + +assert str(citiesResult) == ( + "[Document {'id': 'user:1', 'payload': None, 'city': 'London'}, " + + "Document {'id': 'user:3', 'payload': None, 'city': 'Tel Aviv'}]" +) +# REMOVE_END + +# STEP_START query3 +req = aggregations.AggregateRequest("*").group_by( + '@city', reducers.count().alias('count') +) + +aggResult = r.ft("idx:users").aggregate(req).rows +print(aggResult) +# >>> [['city', 'London', 'count', '1'], ['city', 'Tel Aviv', 'count', '2']] +# STEP_END +# Tests for 'query3' step. 
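+# Rows come back in server order, so sort them for a deterministic comparison.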
+# REMOVE_START +aggResult.sort(key=lambda row: row[1]) + +assert str(aggResult) == ( + "[['city', 'London', 'count', '1'], ['city', 'Tel Aviv', 'count', '2']]" +) +# REMOVE_END + +r.close() diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9478d539d7..7c17938714 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1554,7 +1554,7 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception): await self.reset() raise - async def execute(self, raise_on_error: bool = True): + async def execute(self, raise_on_error: bool = True) -> List[Any]: """Execute all the commands in the current pipeline""" stack = self.command_stack if not stack and not self.watching: diff --git a/redis/client.py b/redis/client.py index a7c1364a10..5a9f4fafb5 100755 --- a/redis/client.py +++ b/redis/client.py @@ -4,7 +4,17 @@ import time import warnings from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Type, + Union, +) from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( @@ -53,6 +63,11 @@ str_if_bytes, ) +if TYPE_CHECKING: + import ssl + + import OpenSSL + SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -175,47 +190,47 @@ def from_pool( def __init__( self, - host="localhost", - port=6379, - db=0, - password=None, - socket_timeout=None, - socket_connect_timeout=None, - socket_keepalive=None, - socket_keepalive_options=None, - connection_pool=None, - unix_socket_path=None, - encoding="utf-8", - encoding_errors="strict", - charset=None, - errors=None, - decode_responses=False, - retry_on_timeout=False, - retry_on_error=None, - ssl=False, - ssl_keyfile=None, - ssl_certfile=None, - ssl_cert_reqs="required", - ssl_ca_certs=None, - ssl_ca_path=None, - ssl_ca_data=None, - ssl_check_hostname=False, - ssl_password=None, - ssl_validate_ocsp=False, - ssl_validate_ocsp_stapled=False, - ssl_ocsp_context=None, - ssl_ocsp_expected_cert=None, - ssl_min_version=None, - ssl_ciphers=None, - max_connections=None, - single_connection_client=False, - health_check_interval=0, - client_name=None, - lib_name="redis-py", - lib_version=get_lib_version(), - username=None, - retry=None, - redis_connect_func=None, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + charset: Optional[str] = None, + errors: Optional[str] = None, + decode_responses: bool = False, + retry_on_timeout: bool = False, + retry_on_error: Optional[List[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_check_hostname: bool = False, + ssl_password: Optional[str] = None, + ssl_validate_ocsp: bool = False, + ssl_validate_ocsp_stapled: bool = False, + ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None, + ssl_ocsp_expected_cert: Optional[str] = None, + ssl_min_version: Optional["ssl.TLSVersion"] = None, + ssl_ciphers: 
Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Optional[Retry] = None, + redis_connect_func: Optional[Callable[[], None]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, cache: Optional[CacheInterface] = None, @@ -550,7 +565,7 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): self.close() - def close(self): + def close(self) -> None: # In case a connection property does not yet exist # (due to a crash earlier in the Redis() constructor), return # immediately as there is nothing to clean-up. @@ -1551,11 +1566,10 @@ def _disconnect_raise_reset( conn.retry_on_error is None or isinstance(error, tuple(conn.retry_on_error)) is False ): - self.reset() raise error - def execute(self, raise_on_error=True): + def execute(self, raise_on_error: bool = True) -> List[Any]: """Execute all the commands in the current pipeline""" stack = self.command_stack if not stack and not self.watching: diff --git a/redis/cluster.py b/redis/cluster.py index 38bd5dde1a..e8f47afe25 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1244,7 +1244,7 @@ def _execute_command(self, target_node, *args, **kwargs): raise ClusterError("TTL exhausted.") - def close(self): + def close(self) -> None: try: with self._lock: if self.nodes_manager: @@ -1686,7 +1686,7 @@ def initialize(self): # If initialize was called after a MovedError, clear it self._moved_exception = None - def close(self): + def close(self) -> None: self.default_node = None for node in self.nodes_cache.values(): if node.redis_connection: @@ -2067,7 +2067,7 @@ def annotate_exception(self, exception, number, command): ) exception.args = (msg,) + exception.args[1:] - def execute(self, raise_on_error=True): + def execute(self, raise_on_error: bool = True) -> List[Any]: """ Execute all the commands in the current pipeline """ diff --git a/redis/connection.py b/redis/connection.py index 9d29b4aba6..d47f46590b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,7 +9,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -904,9 +904,11 @@ def read_response( and self._cache.get(self._current_command_cache_key).status != CacheEntryStatus.IN_PROGRESS ): - return copy.deepcopy( + res = copy.deepcopy( self._cache.get(self._current_command_cache_key).cache_value ) + self._current_command_cache_key = None + return res response = self._conn.read_response( disable_decoding=disable_decoding, @@ -932,6 +934,8 @@ def read_response( cache_entry.cache_value = response self._cache.set(cache_entry) + self._current_command_cache_key = None + return response def pack_command(self, *args): @@ -1259,6 +1263,9 @@ def parse_url(url): return kwargs +_CP = TypeVar("_CP", bound="ConnectionPool") + + class ConnectionPool: """ Create a connection pool. 
``If max_connections`` is set, then this @@ -1274,7 +1281,7 @@ class ConnectionPool: """ @classmethod - def from_url(cls, url, **kwargs): + def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: """ Return a connection pool configured from the given URL. @@ -1374,6 +1381,7 @@ def __init__( # will notice the first thread already did the work and simply # release the lock. self._fork_lock = threading.Lock() + self._lock = threading.Lock() self.reset() def __repr__(self) -> (str, str): @@ -1391,7 +1399,6 @@ def get_protocol(self): return self.connection_kwargs.get("protocol", None) def reset(self) -> None: - self._lock = threading.Lock() self._created_connections = 0 self._available_connections = [] self._in_use_connections = set() diff --git a/redis/typing.py b/redis/typing.py index b4d442c444..24ad607480 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -20,7 +20,7 @@ Number = Union[int, float] -EncodedT = Union[bytes, memoryview] +EncodedT = Union[bytes, bytearray, memoryview] DecodedT = Union[str, int, float] EncodableT = Union[EncodedT, DecodedT] AbsExpiryT = Union[int, datetime] diff --git a/setup.py b/setup.py index 167cd5ee07..81bbedfe9f 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="5.3.0b4", + version="5.3.0b5", packages=find_packages( include=[ "redis", diff --git a/tests/conftest.py b/tests/conftest.py index a900cea8bf..fc732c0d72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import time from datetime import datetime, timezone from enum import Enum -from typing import Callable, TypeVar +from typing import Callable, TypeVar, Union from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse @@ -17,6 +17,7 @@ from redis import Sentinel from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.cache import ( CacheConfig, @@ -29,12 +30,21 @@ from redis.credentials import CredentialProvider from redis.exceptions import RedisClusterException from redis.retry import Retry -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) from redis_entraid.identity_provider import ( ManagedIdentityIdType, + ManagedIdentityProviderConfig, ManagedIdentityType, - create_provider_from_managed_identity, - create_provider_from_service_principal, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, ) from tests.ssl_utils import get_tls_certificates @@ -623,17 +633,33 @@ def identity_provider(request) -> IdentityProviderInterface: return mock_identity_provider() auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) if auth_type == "MANAGED_IDENTITY": - return _get_managed_identity_provider(request) + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) - return _get_service_principal_provider(request) +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, 
ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + + return _get_service_principal_provider_config(request) + + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_ID_VALUE", None) + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -641,23 +667,24 @@ def _get_managed_identity_provider(request): kwargs = {} identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( + return ManagedIdentityProviderConfig( identity_type=identity_type, resource=resource, id_type=id_type, id_value=id_value, - authority=authority, - **kwargs, + kwargs=kwargs, ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -671,14 +698,14 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(",") - return create_provider_from_service_principal( + return ServicePrincipalIdentityProviderConfig( client_id=client_id, client_credential=client_credential, scopes=scopes, timeout=timeout, token_kwargs=token_kwargs, - authority=authority, - **kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, ) @@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider: return cred_provider_class(**cred_provider_kwargs) idp = identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = cred_provider_kwargs.get("block_for_initial", False) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) - delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + 
retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), ) - auth_config = TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, ) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 8833426af1..fb6c51140e 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -17,14 +17,24 @@ from redis.asyncio.retry import Retry from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.credentials import CredentialProvider -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) from redis_entraid.identity_provider import ( ManagedIdentityIdType, + ManagedIdentityProviderConfig, ManagedIdentityType, - create_provider_from_managed_identity, - create_provider_from_service_principal, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, ) from tests.conftest import REDIS_INFO @@ -255,17 +265,33 @@ def identity_provider(request) -> IdentityProviderInterface: return mock_identity_provider() auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) if auth_type == "MANAGED_IDENTITY": - return _get_managed_identity_provider(request) + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) + + +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} - return _get_service_principal_provider(request) + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + return _get_service_principal_provider_config(request) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_ID_VALUE", None) + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -273,23 +299,24 @@ def _get_managed_identity_provider(request): kwargs = {} identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( + return ManagedIdentityProviderConfig( identity_type=identity_type, 
resource=resource, id_type=id_type, id_value=id_value, - authority=authority, - **kwargs, + kwargs=kwargs, ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -303,14 +330,14 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(",") - return create_provider_from_service_principal( + return ServicePrincipalIdentityProviderConfig( client_id=client_id, client_credential=client_credential, scopes=scopes, timeout=timeout, token_kwargs=token_kwargs, - authority=authority, - **kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, ) @@ -322,31 +349,29 @@ def get_credential_provider(request) -> CredentialProvider: return cred_provider_class(**cred_provider_kwargs) idp = identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = cred_provider_kwargs.get("block_for_initial", False) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) - delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), ) - auth_config = TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, ) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index e584fc6999..d4956f16e9 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -1,6 +1,7 @@ import asyncio import socket import types +from errno import ECONNREFUSED from unittest.mock import patch import pytest @@ -36,15 +37,16 @@ async def test_invalid_response(create_redis): fake_stream = MockStream(raw + b"\r\n") parser: _AsyncRESPBase = r.connection._parser - with mock.patch.object(parser, "_stream", fake_stream): - with pytest.raises(InvalidResponse) as cm: - await 
parser.read_response() + if isinstance(parser, _AsyncRESPBase): - assert str(cm.value) == f"Protocol Error: {raw!r}" + exp_err = f"Protocol Error: {raw!r}" else: - assert ( - str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte' - ) + exp_err = f'Protocol error, got "{raw.decode()}" as reply type byte' + + with mock.patch.object(parser, "_stream", fake_stream): + with pytest.raises(InvalidResponse, match=exp_err): + await parser.read_response() + await r.connection.disconnect() @@ -170,10 +172,9 @@ async def test_connect_timeout_error_without_retry(): conn._connect = mock.AsyncMock() conn._connect.side_effect = socket.timeout - with pytest.raises(TimeoutError) as e: + with pytest.raises(TimeoutError, match="Timeout connecting to server"): await conn.connect() assert conn._connect.call_count == 1 - assert str(e.value) == "Timeout connecting to server" @pytest.mark.onlynoncluster @@ -531,17 +532,14 @@ async def test_format_error_message(conn, error, expected_message): async def test_network_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = rf"^Error {ECONNREFUSED} connecting to 127.0.0.1:9999.(.+)$" + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(host="127.0.0.1", port=9999) await redis.set("a", "b") - assert str(e.value).startswith("Error 111 connecting to 127.0.0.1:9999. Connect") async def test_unix_socket_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(unix_socket_path="unix:///tmp/a.sock") await redis.set("a", "b") - assert ( - str(e.value) - == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." 
- ) diff --git a/tests/test_auth/test_token_manager.py b/tests/test_auth/test_token_manager.py index bb396e246c..cdbf60889d 100644 --- a/tests/test_auth/test_token_manager.py +++ b/tests/test_auth/test_token_manager.py @@ -17,17 +17,17 @@ class TestTokenManager: @pytest.mark.parametrize( - "exp_refresh_ratio,tokens_refreshed", + "exp_refresh_ratio", [ - (0.9, 2), - (0.28, 4), + 0.9, + 0.28, ], ids=[ - "Refresh ratio = 0.9, 2 tokens in 0,1 second", - "Refresh ratio = 0.28, 4 tokens in 0,1 second", + "Refresh ratio = 0.9", + "Refresh ratio = 0.28", ], ) - def test_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): + def test_success_token_renewal(self, exp_refresh_ratio): tokens = [] mock_provider = Mock(spec=IdentityProviderInterface) mock_provider.request_token.side_effect = [ @@ -39,14 +39,14 @@ def test_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 130, - (datetime.now(timezone.utc).timestamp() * 1000) + 30, + (datetime.now(timezone.utc).timestamp() * 1000) + 150, + (datetime.now(timezone.utc).timestamp() * 1000) + 50, {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 160, - (datetime.now(timezone.utc).timestamp() * 1000) + 60, + (datetime.now(timezone.utc).timestamp() * 1000) + 170, + (datetime.now(timezone.utc).timestamp() * 1000) + 70, {"oid": "test"}, ), SimpleToken( @@ -70,7 +70,7 @@ def on_next(token): mgr.start(mock_listener) sleep(0.1) - assert len(tokens) == tokens_refreshed + assert len(tokens) > 0 @pytest.mark.parametrize( "exp_refresh_ratio,tokens_refreshed", @@ -176,19 +176,13 @@ def test_token_renewal_with_skip_initial(self): mock_provider.request_token.side_effect = [ SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000) + 50, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 120, - (datetime.now(timezone.utc).timestamp() * 1000), - {"oid": "test"}, - ), - SimpleToken( - "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 140, + (datetime.now(timezone.utc).timestamp() * 1000) + 150, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), @@ -207,9 +201,9 @@ def on_next(token): mgr.start(mock_listener, skip_initial=True) # Should be less than a 0.1, or it will be flacky due to # additional token renewal. 
-        sleep(0.2)
+        sleep(0.1)
 
-        assert len(tokens) == 2
+        assert len(tokens) == 1
 
     @pytest.mark.asyncio
     async def test_async_token_renewal_with_skip_initial(self):
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 4cad4c14b6..0f5a9c7b16 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -4345,7 +4345,7 @@ def test_xgroup_create_entriesread(self, r: redis.Redis):
                 "pending": 0,
                 "last-delivered-id": b"0-0",
                 "entries-read": 7,
-                "lag": -6,
+                "lag": 1,
             }
         ]
         assert r.xinfo_groups(stream) == expected
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 7683a1416d..6c1498a329 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -1,8 +1,10 @@
 import copy
 import platform
 import socket
+import sys
 import threading
 import types
+from errno import ECONNREFUSED
 from typing import Any
 from unittest import mock
 from unittest.mock import call, patch
@@ -43,9 +45,8 @@ def test_invalid_response(r):
     raw = b"x"
     parser = r.connection._parser
     with mock.patch.object(parser._buffer, "readline", return_value=raw):
-        with pytest.raises(InvalidResponse) as cm:
+        with pytest.raises(InvalidResponse, match=f"Protocol Error: {raw!r}"):
             parser.read_response()
-    assert str(cm.value) == f"Protocol Error: {raw!r}"
 
 
 @skip_if_server_version_lt("4.0.0")
@@ -140,10 +141,9 @@ def test_connect_timeout_error_without_retry(self):
         conn._connect = mock.Mock()
         conn._connect.side_effect = socket.timeout
 
-        with pytest.raises(TimeoutError) as e:
+        with pytest.raises(TimeoutError, match="Timeout connecting to server"):
             conn.connect()
         assert conn._connect.call_count == 1
-        assert str(e.value) == "Timeout connecting to server"
         self.clear(conn)
 
@@ -249,6 +249,7 @@ def get_redis_connection():
     r1.close()
 
 
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Flaky test on Python 3.9")
 @pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args"))
 def test_redis_connection_pool(request, from_url):
     """Verify that basic Redis instances using `connection_pool`
@@ -347,20 +348,17 @@ def test_format_error_message(conn, error, expected_message):
 
 
 def test_network_connection_failure():
-    with pytest.raises(ConnectionError) as e:
+    exp_err = f"Error {ECONNREFUSED} connecting to localhost:9999. Connection refused."
+    with pytest.raises(ConnectionError, match=exp_err):
         redis = Redis(port=9999)
         redis.set("a", "b")
-    assert str(e.value) == "Error 111 connecting to localhost:9999. Connection refused."
 
 
 def test_unix_socket_connection_failure():
-    with pytest.raises(ConnectionError) as e:
+    exp_err = "Error 2 connecting to unix:///tmp/a.sock. No such file or directory."
+    with pytest.raises(ConnectionError, match=exp_err):
         redis = Redis(unix_socket_path="unix:///tmp/a.sock")
         redis.set("a", "b")
-    assert (
-        str(e.value)
-        == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory."
- ) class TestUnitConnectionPool: @@ -499,9 +497,9 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): ) proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) assert proxy_connection.read_response() == b"bar" + assert proxy_connection._current_command_cache_key is None assert proxy_connection.read_response() == b"bar" - mock_connection.read_response.assert_called_once() mock_cache.set.assert_has_calls( [ call( @@ -528,9 +526,6 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): call(CacheKey(command="GET", redis_keys=("foo",))), call(CacheKey(command="GET", redis_keys=("foo",))), call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), ] ) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index dee7c554d3..118294ee1b 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -7,10 +7,16 @@ import pytest import redis -from redis.connection import to_bool -from redis.utils import SSL_AVAILABLE - -from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt +from redis.cache import CacheConfig +from redis.connection import CacheProxyConnection, Connection, to_bool +from redis.utils import HIREDIS_AVAILABLE, SSL_AVAILABLE + +from .conftest import ( + _get_client, + skip_if_redis_enterprise, + skip_if_resp_version, + skip_if_server_version_lt, +) from .test_pubsub import wait_for_message @@ -196,6 +202,20 @@ def test_repr_contains_db_info_unix(self): expected = "path=abc,db=0,client_name=test-client" assert expected in repr(pool) + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + @pytest.mark.onlynoncluster + @skip_if_resp_version(2) + @skip_if_server_version_lt("7.4.0") + def test_initialise_pool_with_cache(self, master_host): + pool = redis.BlockingConnectionPool( + connection_class=Connection, + host=master_host[0], + port=master_host[1], + protocol=3, + cache_config=CacheConfig(), + ) + assert isinstance(pool.get_connection("_"), CacheProxyConnection) + class TestConnectionPoolURLParsing: def test_hostname(self): diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 5cda3190a6..116d20dab0 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -1,5 +1,6 @@ import contextlib import multiprocessing +import sys import pytest import redis @@ -8,6 +9,9 @@ from .conftest import _get_client +if sys.platform == "darwin": + multiprocessing.set_start_method("fork", force=True) + @contextlib.contextmanager def exit_callback(callback, *args): From 916afcb76514815ab59f39f1e21ec329e1515a1d Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 29 Apr 2025 16:48:48 +0300 Subject: [PATCH 07/17] When SlotNotCoveredError is raised, the cluster topology should be reinitialized as part of error handling and retrying of the commands. 
(#3621)

---
 redis/asyncio/cluster.py |  8 +++++++-
 redis/cluster.py         | 17 ++++++++++++++---
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 408fa19363..e1d4651a08 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -799,10 +799,16 @@ async def _execute_command(
                 # and try again with the new setup
                 await self.aclose()
                 raise
-            except ClusterDownError:
+            except (ClusterDownError, SlotNotCoveredError):
                 # ClusterDownError can occur during a failover and to get
                 # self-healed, we will try to reinitialize the cluster layout
                 # and retry executing the command
+
+                # SlotNotCoveredError can occur when the cluster is not fully
+                # initialized, or it can be a temporary issue.
+                # We will try to reinitialize the cluster topology
+                # and retry executing the command
+
                 await self.aclose()
                 await asyncio.sleep(0.25)
                 raise
diff --git a/redis/cluster.py b/redis/cluster.py
index e8f47afe25..a7b4678b46 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -424,7 +424,12 @@ class AbstractRedisCluster:
         list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())),
     )
 
-    ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError)
+    ERRORS_ALLOW_RETRY = (
+        ConnectionError,
+        TimeoutError,
+        ClusterDownError,
+        SlotNotCoveredError,
+    )
 
     def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
         """Replace the default cluster node.
@@ -1225,13 +1230,19 @@ def _execute_command(self, target_node, *args, **kwargs):
             except AskError as e:
                 redirect_addr = get_node_name(host=e.host, port=e.port)
                 asking = True
-            except ClusterDownError as e:
+            except (ClusterDownError, SlotNotCoveredError):
                 # ClusterDownError can occur during a failover and to get
                 # self-healed, we will try to reinitialize the cluster layout
                 # and retry executing the command
+
+                # SlotNotCoveredError can occur when the cluster is not fully
+                # initialized, or it can be a temporary issue.
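+                # (for example, while the cluster is still being set up or while slots are being migrated)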
+ # We will try to reinitialize the cluster topology + # and retry executing the command + time.sleep(0.25) self.nodes_manager.initialize() - raise e + raise except ResponseError: raise except Exception as e: From 338cb9960d71a8462eff20f2cd25771698d1698a Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 29 Apr 2025 19:51:22 +0300 Subject: [PATCH 08/17] Updating package version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 81bbedfe9f..6c2dbb6cd2 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="5.3.0b5", + version="5.3.0", packages=find_packages( include=[ "redis", From c5e4324f236ec78b2ec3b50507eb3baba0dc2607 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 20 Feb 2025 09:54:28 +0200 Subject: [PATCH 09/17] Deprecating unused arguments in connection pools's get_connection functions (#3517) --- redis/asyncio/client.py | 16 ++---- redis/asyncio/connection.py | 16 +++++- redis/client.py | 14 ++--- redis/cluster.py | 25 +++++---- redis/connection.py | 16 +++++- redis/utils.py | 65 ++++++++++++++++++++++ tests/test_asyncio/test_connection.py | 2 +- tests/test_asyncio/test_connection_pool.py | 64 ++++++++++----------- tests/test_asyncio/test_credentials.py | 2 +- tests/test_asyncio/test_encoding.py | 2 +- tests/test_asyncio/test_retry.py | 4 +- tests/test_asyncio/test_sentinel.py | 2 +- tests/test_cache.py | 6 +- tests/test_cluster.py | 4 +- tests/test_connection_pool.py | 60 ++++++++++---------- tests/test_credentials.py | 2 +- tests/test_multiprocessing.py | 10 ++-- tests/test_retry.py | 4 +- tests/test_sentinel.py | 2 +- 19 files changed, 199 insertions(+), 117 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 7c17938714..4254441073 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -375,7 +375,7 @@ async def initialize(self: _RedisT) -> _RedisT: if self.single_connection_client: async with self._single_conn_lock: if self.connection is None: - self.connection = await self.connection_pool.get_connection("_") + self.connection = await self.connection_pool.get_connection() self._event_dispatcher.dispatch( AfterSingleConnectionInstantiationEvent( @@ -638,7 +638,7 @@ async def execute_command(self, *args, **options): await self.initialize() pool = self.connection_pool command_name = args[0] - conn = self.connection or await pool.get_connection(command_name, **options) + conn = self.connection or await pool.get_connection() if self.single_connection_client: await self._single_conn_lock.acquire() @@ -712,7 +712,7 @@ def __init__(self, connection_pool: ConnectionPool): async def connect(self): if self.connection is None: - self.connection = await self.connection_pool.get_connection("MONITOR") + self.connection = await self.connection_pool.get_connection() async def __aenter__(self): await self.connect() @@ -900,9 +900,7 @@ async def connect(self): Ensure that the PubSub is connected """ if self.connection is None: - self.connection = await self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) + self.connection = await self.connection_pool.get_connection() # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) @@ -1370,9 +1368,7 @@ async def immediate_execute_command(self, *args, **options): conn = self.connection # if this is the 
first call, we need a connection if not conn: - conn = await self.connection_pool.get_connection( - command_name, self.shard_hint - ) + conn = await self.connection_pool.get_connection() self.connection = conn return await conn.retry.call_with_retry( @@ -1568,7 +1564,7 @@ async def execute(self, raise_on_error: bool = True) -> List[Any]: conn = self.connection if not conn: - conn = await self.connection_pool.get_connection("MULTI", self.shard_hint) + conn = await self.connection_pool.get_connection() # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4a743ff374..e67dc5b207 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -29,7 +29,7 @@ from ..auth.token import TokenInterface from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher -from ..utils import format_error_message +from ..utils import deprecated_args, format_error_message # the functionality is available in 3.11.x but has a major issue before # 3.11.3. See https://github.com/redis/redis-py/issues/2633 @@ -1087,7 +1087,12 @@ def can_get_connection(self) -> bool: or len(self._in_use_connections) < self.max_connections ) - async def get_connection(self, command_name, *keys, **options): + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.0.3", + ) + async def get_connection(self, command_name=None, *keys, **options): async with self._lock: """Get a connected connection from the pool""" connection = self.get_available_connection() @@ -1255,7 +1260,12 @@ def __init__( self._condition = asyncio.Condition() self.timeout = timeout - async def get_connection(self, command_name, *keys, **options): + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.0.3", + ) + async def get_connection(self, command_name=None, *keys, **options): """Gets a connection from the pool, blocking until one is available""" try: async with self._condition: diff --git a/redis/client.py b/redis/client.py index 5a9f4fafb5..fc535c8ca0 100755 --- a/redis/client.py +++ b/redis/client.py @@ -366,7 +366,7 @@ def __init__( self.connection = None self._single_connection_client = single_connection_client if self._single_connection_client: - self.connection = self.connection_pool.get_connection("_") + self.connection = self.connection_pool.get_connection() self._event_dispatcher.dispatch( AfterSingleConnectionInstantiationEvent( self.connection, ClientType.SYNC, self.single_connection_lock @@ -608,7 +608,7 @@ def _execute_command(self, *args, **options): """Execute a command and return a parsed response""" pool = self.connection_pool command_name = args[0] - conn = self.connection or pool.get_connection(command_name, **options) + conn = self.connection or pool.get_connection() if self._single_connection_client: self.single_connection_lock.acquire() @@ -667,7 +667,7 @@ class Monitor: def __init__(self, connection_pool): self.connection_pool = connection_pool - self.connection = self.connection_pool.get_connection("MONITOR") + self.connection = self.connection_pool.get_connection() def __enter__(self): self.connection.send_command("MONITOR") @@ -840,9 +840,7 @@ def execute_command(self, *args): # subscribed to one or more channels if self.connection is None: - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) + self.connection = 
self.connection_pool.get_connection() # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) @@ -1397,7 +1395,7 @@ def immediate_execute_command(self, *args, **options): conn = self.connection # if this is the first call, we need a connection if not conn: - conn = self.connection_pool.get_connection(command_name, self.shard_hint) + conn = self.connection_pool.get_connection() self.connection = conn return conn.retry.call_with_retry( @@ -1583,7 +1581,7 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: conn = self.connection if not conn: - conn = self.connection_pool.get_connection("MULTI", self.shard_hint) + conn = self.connection_pool.get_connection() # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn diff --git a/redis/cluster.py b/redis/cluster.py index a7b4678b46..a54a55f5ec 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -42,6 +42,7 @@ from redis.retry import Retry from redis.utils import ( HIREDIS_AVAILABLE, + deprecated_args, dict_merge, list_keys_to_dict, merge_result, @@ -54,10 +55,13 @@ def get_node_name(host: str, port: Union[str, int]) -> str: return f"{host}:{port}" +@deprecated_args( + allowed_args=["redis_node"], + reason="Use get_connection(redis_node) instead", + version="5.0.3", +) def get_connection(redis_node, *args, **options): - return redis_node.connection or redis_node.connection_pool.get_connection( - args[0], **options - ) + return redis_node.connection or redis_node.connection_pool.get_connection() def parse_scan_result(command, res, **options): @@ -1173,7 +1177,7 @@ def _execute_command(self, target_node, *args, **kwargs): moved = False redis_node = self.get_redis_connection(target_node) - connection = get_connection(redis_node, *args, **kwargs) + connection = get_connection(redis_node) if asking: connection.send_command("ASKING") redis_node.parse_response(connection, "ASKING", **kwargs) @@ -1652,7 +1656,7 @@ def initialize(self): if len(disagreements) > 5: raise RedisClusterException( f"startup_nodes could not agree on a valid " - f'slots cache: {", ".join(disagreements)}' + f"slots cache: {', '.join(disagreements)}" ) fully_covered = self.check_slots_coverage(tmp_slots) @@ -1850,9 +1854,7 @@ def execute_command(self, *args): self.node = node redis_connection = self.cluster.get_redis_connection(node) self.connection_pool = redis_connection.connection_pool - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) + self.connection = self.connection_pool.get_connection() # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) @@ -2073,8 +2075,7 @@ def annotate_exception(self, exception, number, command): """ cmd = " ".join(map(safe_str, command)) msg = ( - f"Command # {number} ({cmd}) of pipeline " - f"caused error: {exception.args[0]}" + f"Command # {number} ({cmd}) of pipeline caused error: {exception.args[0]}" ) exception.args = (msg,) + exception.args[1:] @@ -2212,8 +2213,8 @@ def _send_cluster_commands( if node_name not in nodes: redis_node = self.get_redis_connection(node) try: - connection = get_connection(redis_node, c.args) - except ConnectionError: + connection = get_connection(redis_node) + except (ConnectionError, TimeoutError): for n in nodes.values(): n.connection_pool.release(n.connection) # Connection retries 
are being handled in the node's diff --git a/redis/connection.py b/redis/connection.py index d47f46590b..3189690802 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -42,6 +42,7 @@ HIREDIS_AVAILABLE, SSL_AVAILABLE, compare_versions, + deprecated_args, ensure_string, format_error_message, get_lib_version, @@ -1461,8 +1462,14 @@ def _checkpid(self) -> None: finally: self._fork_lock.release() - def get_connection(self, command_name: str, *keys, **options) -> "Connection": + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.0.3", + ) + def get_connection(self, command_name=None, *keys, **options) -> "Connection": "Get a connection from the pool" + self._checkpid() with self._lock: try: @@ -1683,7 +1690,12 @@ def make_connection(self): self._connections.append(connection) return connection - def get_connection(self, command_name, *keys, **options): + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.0.3", + ) + def get_connection(self, command_name=None, *keys, **options): """ Get a connection, blocking for ``self.timeout`` until a connection is available from the pool. diff --git a/redis/utils.py b/redis/utils.py index 8693fb3c8f..66465636a1 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -122,6 +122,71 @@ def wrapper(*args, **kwargs): return decorator +def warn_deprecated_arg_usage( + arg_name: Union[list, str], + function_name: str, + reason: str = "", + version: str = "", + stacklevel: int = 2, +): + import warnings + + msg = ( + f"Call to '{function_name}' function with deprecated" + f" usage of input argument/s '{arg_name}'." + ) + if reason: + msg += f" ({reason})" + if version: + msg += f" -- Deprecated since version {version}." + warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) + + +def deprecated_args( + args_to_warn: list = ["*"], + allowed_args: list = [], + reason: str = "", + version: str = "", +): + """ + Decorator to mark specified args of a function as deprecated. + If '*' is in args_to_warn, all arguments will be marked as deprecated. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Get function argument names + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + + provided_args = dict(zip(arg_names, args)) + provided_args.update(kwargs) + + provided_args.pop("self", None) + for allowed_arg in allowed_args: + provided_args.pop(allowed_arg, None) + + for arg in args_to_warn: + if arg == "*" and len(provided_args) > 0: + warn_deprecated_arg_usage( + list(provided_args.keys()), + func.__name__, + reason, + version, + stacklevel=3, + ) + elif arg in provided_args: + warn_deprecated_arg_usage( + arg, func.__name__, reason, version, stacklevel=3 + ) + + return func(*args, **kwargs) + + return wrapper + + return decorator + + def _set_info_logger(): """ Set up a logger that log info logs to stdout. 
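To illustrate the decorator added to redis/utils.py above: a minimal usage sketch (not part of the patch; `fetch_connection` and its arguments are hypothetical). The wrapper warns only when one of the arguments listed in `args_to_warn` is actually supplied:

    import warnings

    from redis.utils import deprecated_args

    @deprecated_args(
        args_to_warn=["command_name"],
        reason="Use fetch_connection() without args instead",
        version="5.0.3",
    )
    def fetch_connection(command_name=None, *keys, **options):
        return "connection"

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        fetch_connection("GET")  # deprecated positional usage -> DeprecationWarning
        fetch_connection()       # no deprecated arguments -> no warning
    assert len(caught) == 1
    assert issubclass(caught[0].category, DeprecationWarning)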
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index d4956f16e9..38764d30cd 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -78,7 +78,7 @@ async def call_with_retry(self, _, __): mock_conn = mock.AsyncMock(spec=Connection) mock_conn.retry = Retry_() - async def get_conn(_): + async def get_conn(): # Validate only one client is created in single-client mode when # concurrent requests are made nonlocal init_call_count diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 83545b4ede..3d120e4ca7 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -29,8 +29,8 @@ def get_total_connected_connections(pool): @staticmethod async def create_two_conn(r: redis.Redis): if not r.single_connection_client: # Single already initialized connection - r.connection = await r.connection_pool.get_connection("_") - return await r.connection_pool.get_connection("_") + r.connection = await r.connection_pool.get_connection() + return await r.connection_pool.get_connection() @staticmethod def has_no_connected_connections(pool: redis.ConnectionPool): @@ -138,7 +138,7 @@ async def test_connection_creation(self): async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) as pool: - connection = await pool.get_connection("_") + connection = await pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -155,8 +155,8 @@ async def test_aclosing(self): async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - c1 = await pool.get_connection("_") - c2 = await pool.get_connection("_") + c1 = await pool.get_connection() + c2 = await pool.get_connection() assert c1 != c2 async def test_max_connections(self, master_host): @@ -164,17 +164,17 @@ async def test_max_connections(self, master_host): async with self.get_pool( max_connections=2, connection_kwargs=connection_kwargs ) as pool: - await pool.get_connection("_") - await pool.get_connection("_") + await pool.get_connection() + await pool.get_connection() with pytest.raises(redis.ConnectionError): - await pool.get_connection("_") + await pool.get_connection() async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - c1 = await pool.get_connection("_") + c1 = await pool.get_connection() await pool.release(c1) - c2 = await pool.get_connection("_") + c2 = await pool.get_connection() assert c1 == c2 async def test_repr_contains_db_info_tcp(self): @@ -223,7 +223,7 @@ async def test_connection_creation(self, master_host): "port": master_host[1], } async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - connection = await pool.get_connection("_") + connection = await pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -236,14 +236,14 @@ async def test_disconnect(self, master_host): "port": master_host[1], } async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - await pool.get_connection("_") + await pool.get_connection() await pool.disconnect() async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], 
"port": master_host[1]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - c1 = await pool.get_connection("_") - c2 = await pool.get_connection("_") + c1 = await pool.get_connection() + c2 = await pool.get_connection() assert c1 != c2 async def test_connection_pool_blocks_until_timeout(self, master_host): @@ -252,11 +252,11 @@ async def test_connection_pool_blocks_until_timeout(self, master_host): async with self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs ) as pool: - c1 = await pool.get_connection("_") + c1 = await pool.get_connection() start = asyncio.get_running_loop().time() with pytest.raises(redis.ConnectionError): - await pool.get_connection("_") + await pool.get_connection() # we should have waited at least some period of time assert asyncio.get_running_loop().time() - start >= 0.05 @@ -271,23 +271,23 @@ async def test_connection_pool_blocks_until_conn_available(self, master_host): async with self.get_pool( max_connections=1, timeout=2, connection_kwargs=connection_kwargs ) as pool: - c1 = await pool.get_connection("_") + c1 = await pool.get_connection() async def target(): await asyncio.sleep(0.1) await pool.release(c1) start = asyncio.get_running_loop().time() - await asyncio.gather(target(), pool.get_connection("_")) + await asyncio.gather(target(), pool.get_connection()) stop = asyncio.get_running_loop().time() assert (stop - start) <= 0.2 async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - c1 = await pool.get_connection("_") + c1 = await pool.get_connection() await pool.release(c1) - c2 = await pool.get_connection("_") + c2 = await pool.get_connection() assert c1 == c2 def test_repr_contains_db_info_tcp(self): @@ -552,23 +552,23 @@ def test_cert_reqs_options(self): import ssl class DummyConnectionPool(redis.ConnectionPool): - def get_connection(self, *args, **kwargs): + def get_connection(self): return self.make_connection() pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") - assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE + assert pool.get_connection().cert_reqs == ssl.CERT_NONE pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") - assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL + assert pool.get_connection().cert_reqs == ssl.CERT_OPTIONAL pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") - assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED + assert pool.get_connection().cert_reqs == ssl.CERT_REQUIRED pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") - assert pool.get_connection("_").check_hostname is False + assert pool.get_connection().check_hostname is False pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") - assert pool.get_connection("_").check_hostname is True + assert pool.get_connection().check_hostname is True class TestConnection: @@ -756,7 +756,7 @@ async def test_health_check_not_invoked_within_interval(self, r): async def test_health_check_in_pipeline(self, r): async with r.pipeline(transaction=False) as pipe: - pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection = await pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -767,7 +767,7 @@ async def 
test_health_check_in_pipeline(self, r): async def test_health_check_in_transaction(self, r): async with r.pipeline(transaction=True) as pipe: - pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection = await pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -779,7 +779,7 @@ async def test_health_check_in_transaction(self, r): async def test_health_check_in_watched_pipeline(self, r): await r.set("foo", "bar") async with r.pipeline(transaction=False) as pipe: - pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection = await pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -803,7 +803,7 @@ async def test_health_check_in_watched_pipeline(self, r): async def test_health_check_in_pubsub_before_subscribe(self, r): """A health check happens before the first [p]subscribe""" p = r.pubsub() - p.connection = await p.connection_pool.get_connection("_") + p.connection = await p.connection_pool.get_connection() p.connection.next_health_check = 0 with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command @@ -825,7 +825,7 @@ async def test_health_check_in_pubsub_after_subscribed(self, r): connection health """ p = r.pubsub() - p.connection = await p.connection_pool.get_connection("_") + p.connection = await p.connection_pool.get_connection() p.connection.next_health_check = 0 with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command @@ -865,7 +865,7 @@ async def test_health_check_in_pubsub_poll(self, r): check the connection's health. 
""" p = r.pubsub() - p.connection = await p.connection_pool.get_connection("_") + p.connection = await p.connection_pool.get_connection() with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command ) as m: diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index ca42d19090..1eb988ce71 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -274,7 +274,7 @@ async def test_change_username_password_on_existing_connection( await init_acl_user(r, username, password) r2 = await create_redis(flushdb=False, username=username, password=password) assert await r2.ping() is True - conn = await r2.connection_pool.get_connection("_") + conn = await r2.connection_pool.get_connection() await conn.send_command("PING") assert str_if_bytes(await conn.read_response()) == "PONG" assert conn.username == username diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py index 162ccb367d..74a9f28b2d 100644 --- a/tests/test_asyncio/test_encoding.py +++ b/tests/test_asyncio/test_encoding.py @@ -74,7 +74,7 @@ class TestMemoryviewsAreNotPacked: async def test_memoryviews_are_not_packed(self, r): arg = memoryview(b"some_arg") arg_list = ["SOME_COMMAND", arg] - c = r.connection or await r.connection_pool.get_connection("_") + c = r.connection or await r.connection_pool.get_connection() cmd = c.pack_command(*arg_list) assert cmd[1] is arg cmds = c.pack_commands([arg_list, arg_list]) diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index 8bc71c1479..cd251a986f 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -126,13 +126,13 @@ async def test_get_set_retry_object(self, request): assert r.get_retry()._retries == retry._retries assert isinstance(r.get_retry()._backoff, NoBackoff) new_retry_policy = Retry(ExponentialBackoff(), 3) - exiting_conn = await r.connection_pool.get_connection("_") + exiting_conn = await r.connection_pool.get_connection() r.set_retry(new_retry_policy) assert r.get_retry()._retries == new_retry_policy._retries assert isinstance(r.get_retry()._backoff, ExponentialBackoff) assert exiting_conn.retry._retries == new_retry_policy._retries await r.connection_pool.release(exiting_conn) - new_conn = await r.connection_pool.get_connection("_") + new_conn = await r.connection_pool.get_connection() assert new_conn.retry._retries == new_retry_policy._retries await r.connection_pool.release(new_conn) await r.aclose() diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index e553fdb00b..a27ba92bb8 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -269,7 +269,7 @@ async def mock_disconnect(): @pytest.mark.onlynoncluster async def test_repr_correctly_represents_connection_object(sentinel): pool = SentinelConnectionPool("mymaster", sentinel) - connection = await pool.get_connection("PING") + connection = await pool.get_connection() assert ( str(connection) diff --git a/tests/test_cache.py b/tests/test_cache.py index 67733dc9af..7010baff5f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -159,7 +159,7 @@ def test_cache_clears_on_disconnect(self, r, cache): == b"bar" ) # Force disconnection - r.connection_pool.get_connection("_").disconnect() + r.connection_pool.get_connection().disconnect() # Make sure cache is empty assert cache.size == 0 @@ -429,7 +429,7 @@ def test_cache_clears_on_disconnect(self, r, r2): # Force 
disconnection r.nodes_manager.get_node_from_slot( 12000 - ).redis_connection.connection_pool.get_connection("_").disconnect() + ).redis_connection.connection_pool.get_connection().disconnect() # Make sure cache is empty assert cache.size == 0 @@ -667,7 +667,7 @@ def test_cache_clears_on_disconnect(self, master, cache): == b"bar" ) # Force disconnection - master.connection_pool.get_connection("_").disconnect() + master.connection_pool.get_connection().disconnect() # Make sure cache_data is empty assert cache.size == 0 diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 1b9b9969c5..908ac26211 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -845,7 +845,7 @@ def test_cluster_get_set_retry_object(self, request): assert node.redis_connection.get_retry()._retries == retry._retries assert isinstance(node.redis_connection.get_retry()._backoff, NoBackoff) rand_node = r.get_random_node() - existing_conn = rand_node.redis_connection.connection_pool.get_connection("_") + existing_conn = rand_node.redis_connection.connection_pool.get_connection() # Change retry policy new_retry = Retry(ExponentialBackoff(), 3) r.set_retry(new_retry) @@ -857,7 +857,7 @@ def test_cluster_get_set_retry_object(self, request): node.redis_connection.get_retry()._backoff, ExponentialBackoff ) assert existing_conn.retry._retries == new_retry._retries - new_conn = rand_node.redis_connection.connection_pool.get_connection("_") + new_conn = rand_node.redis_connection.connection_pool.get_connection() assert new_conn.retry._retries == new_retry._retries def test_cluster_retry_object(self, r) -> None: diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 118294ee1b..387a0f4565 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -54,7 +54,7 @@ def test_connection_creation(self): pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) - connection = pool.get_connection("_") + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -71,24 +71,24 @@ def test_closing(self): def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection("_") - c2 = pool.get_connection("_") + c1 = pool.get_connection() + c2 = pool.get_connection() assert c1 != c2 def test_max_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) - pool.get_connection("_") - pool.get_connection("_") + pool.get_connection() + pool.get_connection() with pytest.raises(redis.ConnectionError): - pool.get_connection("_") + pool.get_connection() def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection("_") + c1 = pool.get_connection() pool.release(c1) - c2 = pool.get_connection("_") + c2 = pool.get_connection() assert c1 == c2 def test_repr_contains_db_info_tcp(self): @@ -133,15 +133,15 @@ def test_connection_creation(self, master_host): "port": master_host[1], } pool = self.get_pool(connection_kwargs=connection_kwargs) - connection = pool.get_connection("_") + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert 
connection.kwargs == connection_kwargs def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection("_") - c2 = pool.get_connection("_") + c1 = pool.get_connection() + c2 = pool.get_connection() assert c1 != c2 def test_connection_pool_blocks_until_timeout(self, master_host): @@ -150,11 +150,11 @@ def test_connection_pool_blocks_until_timeout(self, master_host): pool = self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs ) - pool.get_connection("_") + pool.get_connection() start = time.time() with pytest.raises(redis.ConnectionError): - pool.get_connection("_") + pool.get_connection() # we should have waited at least 0.1 seconds assert time.time() - start >= 0.1 @@ -167,7 +167,7 @@ def test_connection_pool_blocks_until_conn_available(self, master_host): pool = self.get_pool( max_connections=1, timeout=2, connection_kwargs=connection_kwargs ) - c1 = pool.get_connection("_") + c1 = pool.get_connection() def target(): time.sleep(0.1) @@ -175,15 +175,15 @@ def target(): start = time.time() Thread(target=target).start() - pool.get_connection("_") + pool.get_connection() assert time.time() - start >= 0.1 def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection("_") + c1 = pool.get_connection() pool.release(c1) - c2 = pool.get_connection("_") + c2 = pool.get_connection() assert c1 == c2 def test_repr_contains_db_info_tcp(self): @@ -214,7 +214,7 @@ def test_initialise_pool_with_cache(self, master_host): protocol=3, cache_config=CacheConfig(), ) - assert isinstance(pool.get_connection("_"), CacheProxyConnection) + assert isinstance(pool.get_connection(), CacheProxyConnection) class TestConnectionPoolURLParsing: @@ -489,23 +489,23 @@ def test_cert_reqs_options(self): import ssl class DummyConnectionPool(redis.ConnectionPool): - def get_connection(self, *args, **kwargs): + def get_connection(self): return self.make_connection() pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") - assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE + assert pool.get_connection().cert_reqs == ssl.CERT_NONE pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") - assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL + assert pool.get_connection().cert_reqs == ssl.CERT_OPTIONAL pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") - assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED + assert pool.get_connection().cert_reqs == ssl.CERT_REQUIRED pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") - assert pool.get_connection("_").check_hostname is False + assert pool.get_connection().check_hostname is False pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") - assert pool.get_connection("_").check_hostname is True + assert pool.get_connection().check_hostname is True class TestConnection: @@ -701,7 +701,7 @@ def test_health_check_not_invoked_within_interval(self, r): def test_health_check_in_pipeline(self, r): with r.pipeline(transaction=False) as pipe: - pipe.connection = pipe.connection_pool.get_connection("_") + pipe.connection = pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", 
wraps=pipe.connection.send_command @@ -712,7 +712,7 @@ def test_health_check_in_pipeline(self, r): def test_health_check_in_transaction(self, r): with r.pipeline(transaction=True) as pipe: - pipe.connection = pipe.connection_pool.get_connection("_") + pipe.connection = pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -724,7 +724,7 @@ def test_health_check_in_transaction(self, r): def test_health_check_in_watched_pipeline(self, r): r.set("foo", "bar") with r.pipeline(transaction=False) as pipe: - pipe.connection = pipe.connection_pool.get_connection("_") + pipe.connection = pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -748,7 +748,7 @@ def test_health_check_in_watched_pipeline(self, r): def test_health_check_in_pubsub_before_subscribe(self, r): "A health check happens before the first [p]subscribe" p = r.pubsub() - p.connection = p.connection_pool.get_connection("_") + p.connection = p.connection_pool.get_connection() p.connection.next_health_check = 0 with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command @@ -770,7 +770,7 @@ def test_health_check_in_pubsub_after_subscribed(self, r): connection health """ p = r.pubsub() - p.connection = p.connection_pool.get_connection("_") + p.connection = p.connection_pool.get_connection() p.connection.next_health_check = 0 with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command @@ -810,7 +810,7 @@ def test_health_check_in_pubsub_poll(self, r): check the connection's health. """ p = r.pubsub() - p.connection = p.connection_pool.get_connection("_") + p.connection = p.connection_pool.get_connection() with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command ) as m: diff --git a/tests/test_credentials.py b/tests/test_credentials.py index b0b79d305f..95ec5577cc 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -252,7 +252,7 @@ def teardown(): redis.Redis, request, flushdb=False, username=username, password=password ) assert r2.ping() is True - conn = r2.connection_pool.get_connection("_") + conn = r2.connection_pool.get_connection() conn.send_command("PING") assert str_if_bytes(conn.read_response()) == "PONG" assert conn.username == username diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 116d20dab0..0e8e8958c5 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -95,7 +95,7 @@ def test_pool(self, max_connections, master_host): max_connections=max_connections, ) - conn = pool.get_connection("ping") + conn = pool.get_connection() main_conn_pid = conn.pid with exit_callback(pool.release, conn): conn.send_command("ping") @@ -103,7 +103,7 @@ def test_pool(self, max_connections, master_host): def target(pool): with exit_callback(pool.disconnect): - conn = pool.get_connection("ping") + conn = pool.get_connection() assert conn.pid != main_conn_pid with exit_callback(pool.release, conn): assert conn.send_command("ping") is None @@ -116,7 +116,7 @@ def target(pool): # Check that connection is still alive after fork process has exited # and disconnected the connections in its pool - conn = pool.get_connection("ping") + conn = pool.get_connection() with exit_callback(pool.release, conn): assert conn.send_command("ping") is None assert conn.read_response() 
== b"PONG" @@ -132,12 +132,12 @@ def test_close_pool_in_main(self, max_connections, master_host): max_connections=max_connections, ) - conn = pool.get_connection("ping") + conn = pool.get_connection() assert conn.send_command("ping") is None assert conn.read_response() == b"PONG" def target(pool, disconnect_event): - conn = pool.get_connection("ping") + conn = pool.get_connection() with exit_callback(pool.release, conn): assert conn.send_command("ping") is None assert conn.read_response() == b"PONG" diff --git a/tests/test_retry.py b/tests/test_retry.py index 183807386d..e1e4c414a4 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -206,7 +206,7 @@ def test_client_retry_on_timeout(self, request): def test_get_set_retry_object(self, request): retry = Retry(NoBackoff(), 2) r = _get_client(Redis, request, retry_on_timeout=True, retry=retry) - exist_conn = r.connection_pool.get_connection("_") + exist_conn = r.connection_pool.get_connection() assert r.get_retry()._retries == retry._retries assert isinstance(r.get_retry()._backoff, NoBackoff) new_retry_policy = Retry(ExponentialBackoff(), 3) @@ -214,5 +214,5 @@ def test_get_set_retry_object(self, request): assert r.get_retry()._retries == new_retry_policy._retries assert isinstance(r.get_retry()._backoff, ExponentialBackoff) assert exist_conn.retry._retries == new_retry_policy._retries - new_conn = r.connection_pool.get_connection("_") + new_conn = r.connection_pool.get_connection() assert new_conn.retry._retries == new_retry_policy._retries diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 54b9647098..93455f3290 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -101,7 +101,7 @@ def test_discover_master_error(sentinel): @pytest.mark.onlynoncluster def test_dead_pool(sentinel): master = sentinel.master_for("mymaster", db=9) - conn = master.connection_pool.get_connection("_") + conn = master.connection_pool.get_connection() conn.disconnect() del master conn.connect() From eb91d4fa6659f2fe67b2232da9dc800b7b4f21ad Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 19 Mar 2025 09:58:11 +0200 Subject: [PATCH 10/17] Adding load balancing strategy configuration to cluster clients(replacement for 'read_from_replicas' config) (#3563) * Adding laod balancing strategy configuration to cluster clients(replacement for 'read_from_replicas' config) * Fixing linter errors * Changing the LoadBalancingStrategy type hints to be defined as optional. 
Fixed wording in pydocs * Adding integration tests with the different load balancing strategies for read operation * Fixing linters --- redis/asyncio/cluster.py | 46 +++++++-- redis/cluster.py | 100 +++++++++++++++--- tests/test_asyncio/test_cluster.py | 139 +++++++++++++++++++++++-- tests/test_cluster.py | 160 +++++++++++++++++++++++++++-- tests/test_multiprocessing.py | 44 ++++++++ 5 files changed, 444 insertions(+), 45 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index e1d4651a08..aa472ed5b8 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -39,6 +39,7 @@ SLOT_ID, AbstractRedisCluster, LoadBalancer, + LoadBalancingStrategy, block_pipeline_command, get_node_name, parse_cluster_slots, @@ -67,6 +68,7 @@ ) from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( + deprecated_args, deprecated_function, dict_merge, get_lib_version, @@ -133,9 +135,15 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand | See: https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters :param read_from_replicas: - | Enable read from replicas in READONLY mode. You can read possibly stale data. + | @deprecated - please use load_balancing_strategy instead + | Enable read from replicas in READONLY mode. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. + The data read from replicas is eventually consistent with the data in primary nodes. + :param load_balancing_strategy: + | Enable read from replicas in READONLY mode and defines the load balancing + strategy that will be used for cluster node selection. + The data read from replicas is eventually consistent with the data in primary nodes. :param reinitialize_steps: | Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. 
If a MOVED error occurs and the cluster does not @@ -228,6 +236,11 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": "result_callbacks", ) + @deprecated_args( + args_to_warn=["read_from_replicas"], + reason="Please configure the 'load_balancing_strategy' instead", + version="5.0.3", + ) def __init__( self, host: Optional[str] = None, @@ -236,6 +249,7 @@ def __init__( startup_nodes: Optional[List["ClusterNode"]] = None, require_full_coverage: bool = True, read_from_replicas: bool = False, + load_balancing_strategy: Optional[LoadBalancingStrategy] = None, reinitialize_steps: int = 5, cluster_error_retry_attempts: int = 3, connection_error_retry_attempts: int = 3, @@ -335,7 +349,7 @@ def __init__( } ) - if read_from_replicas: + if read_from_replicas or load_balancing_strategy: # Call our on_connect function to configure READONLY mode kwargs["redis_connect_func"] = self.on_connect @@ -384,6 +398,7 @@ def __init__( ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.reinitialize_steps = reinitialize_steps self.cluster_error_retry_attempts = cluster_error_retry_attempts self.connection_error_retry_attempts = connection_error_retry_attempts @@ -602,6 +617,7 @@ async def _determine_nodes( self.nodes_manager.get_node_from_slot( await self._determine_slot(command, *args), self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, ) ] @@ -782,7 +798,11 @@ async def _execute_command( # refresh the target node slot = await self._determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and args[0] in READ_COMMANDS + slot, + self.read_from_replicas and args[0] in READ_COMMANDS, + self.load_balancing_strategy + if args[0] in READ_COMMANDS + else None, ) moved = False @@ -1183,9 +1203,7 @@ def get_node( return self.nodes_cache.get(node_name) else: raise DataError( - "get_node requires one of the following: " - "1. node name " - "2. host and port" + "get_node requires one of the following: 1. node name 2. 
host and port" ) def set_nodes( @@ -1245,17 +1263,23 @@ def _update_moved_slots(self) -> None: self._moved_exception = None def get_node_from_slot( - self, slot: int, read_from_replicas: bool = False + self, + slot: int, + read_from_replicas: bool = False, + load_balancing_strategy=None, ) -> "ClusterNode": if self._moved_exception: self._update_moved_slots() + if read_from_replicas is True and load_balancing_strategy is None: + load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN + try: - if read_from_replicas: - # get the server index in a Round-Robin manner + if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: + # get the server index using the strategy defined in load_balancing_strategy primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) + primary_name, len(self.slots_cache[slot]), load_balancing_strategy ) return self.slots_cache[slot][node_idx] return self.slots_cache[slot][0] @@ -1367,7 +1391,7 @@ async def initialize(self) -> None: if len(disagreements) > 5: raise RedisClusterException( f"startup_nodes could not agree on a valid " - f'slots cache: {", ".join(disagreements)}' + f"slots cache: {', '.join(disagreements)}" ) # Validate if all slots are covered or if we should try next startup node diff --git a/redis/cluster.py b/redis/cluster.py index a54a55f5ec..7e7f590b82 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -4,6 +4,7 @@ import threading import time from collections import OrderedDict +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union from redis._parsers import CommandsParser, Encoder @@ -505,6 +506,11 @@ class initializer. In the case of conflicting arguments, querystring """ return cls(url=url, **kwargs) + @deprecated_args( + args_to_warn=["read_from_replicas"], + reason="Please configure the 'load_balancing_strategy' instead", + version="5.0.3", + ) def __init__( self, host: Optional[str] = None, @@ -515,6 +521,7 @@ def __init__( require_full_coverage: bool = False, reinitialize_steps: int = 5, read_from_replicas: bool = False, + load_balancing_strategy: Optional["LoadBalancingStrategy"] = None, dynamic_startup_nodes: bool = True, url: Optional[str] = None, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, @@ -543,11 +550,16 @@ def __init__( cluster client. If not all slots are covered, RedisClusterException will be thrown. :param read_from_replicas: + @deprecated - please use load_balancing_strategy instead Enable read from replicas in READONLY mode. You can read possibly stale data. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. - :param dynamic_startup_nodes: + :param load_balancing_strategy: + Enable read from replicas in READONLY mode and defines the load balancing + strategy that will be used for cluster node selection. + The data read from replicas is eventually consistent with the data in primary nodes. + :param dynamic_startup_nodes: Set the RedisCluster's startup nodes to all of the discovered nodes. If true (default value), the cluster's discovered nodes will be used to determine the cluster nodes-slots mapping in the next topology refresh. 
@@ -652,6 +664,7 @@ def __init__( self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps if event_dispatcher is None: @@ -704,7 +717,7 @@ def on_connect(self, connection): connection.set_parser(ClusterParser) connection.on_connect() - if self.read_from_replicas: + if self.read_from_replicas or self.load_balancing_strategy: # Sending READONLY command to server to configure connection as # readonly. Since each cluster node may change its server type due # to a failover, we should establish a READONLY connection @@ -831,6 +844,7 @@ def pipeline(self, transaction=None, shard_hint=None): cluster_response_callbacks=self.cluster_response_callbacks, cluster_error_retry_attempts=self.cluster_error_retry_attempts, read_from_replicas=self.read_from_replicas, + load_balancing_strategy=self.load_balancing_strategy, reinitialize_steps=self.reinitialize_steps, lock=self._lock, ) @@ -948,7 +962,9 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: # get the node that holds the key's slot slot = self.determine_slot(*args) node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS + slot, + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, ) return [node] @@ -1172,7 +1188,11 @@ def _execute_command(self, target_node, *args, **kwargs): # refresh the target node slot = self.determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS + slot, + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy + if command in READ_COMMANDS + else None, ) moved = False @@ -1327,6 +1347,12 @@ def __del__(self): self.redis_connection.close() +class LoadBalancingStrategy(Enum): + ROUND_ROBIN = "round_robin" + ROUND_ROBIN_REPLICAS = "round_robin_replicas" + RANDOM_REPLICA = "random_replica" + + class LoadBalancer: """ Round-Robin Load Balancing @@ -1336,15 +1362,38 @@ def __init__(self, start_index: int = 0) -> None: self.primary_to_idx = {} self.start_index = start_index - def get_server_index(self, primary: str, list_size: int) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - # Update the index - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index + def get_server_index( + self, + primary: str, + list_size: int, + load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN, + ) -> int: + if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA: + return self._get_random_replica_index(list_size) + else: + return self._get_round_robin_index( + primary, + list_size, + load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) def reset(self) -> None: self.primary_to_idx.clear() + def _get_random_replica_index(self, list_size: int) -> int: + return random.randint(1, list_size - 1) + + def _get_round_robin_index( + self, primary: str, list_size: int, replicas_only: bool + ) -> int: + server_index = self.primary_to_idx.setdefault(primary, self.start_index) + if replicas_only and server_index == 0: + # skip the primary node index + server_index = 1 + # Update the index for the next round + self.primary_to_idx[primary] = (server_index + 1) % list_size 
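+        # note: index 0 in the slot's node list is always the primary, so a
+        # replicas-only rotation keeps the returned index in 1..list_size-1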
+ return server_index + class NodesManager: def __init__( @@ -1448,7 +1497,21 @@ def _update_moved_slots(self): # Reset moved_exception self._moved_exception = None - def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): + @deprecated_args( + args_to_warn=["server_type"], + reason=( + "In case you need select some load balancing strategy " + "that will use replicas, please set it through 'load_balancing_strategy'" + ), + version="5.0.3", + ) + def get_node_from_slot( + self, + slot, + read_from_replicas=False, + load_balancing_strategy=None, + server_type=None, + ): """ Gets a node that servers this hash slot """ @@ -1463,11 +1526,14 @@ def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): f'"require_full_coverage={self._require_full_coverage}"' ) - if read_from_replicas is True: - # get the server index in a Round-Robin manner + if read_from_replicas is True and load_balancing_strategy is None: + load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN + + if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: + # get the server index using the strategy defined in load_balancing_strategy primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) + primary_name, len(self.slots_cache[slot]), load_balancing_strategy ) elif ( server_type is None @@ -1750,7 +1816,7 @@ def __init__( first command execution. The node will be determined by: 1. Hashing the channel name in the request to find its keyslot 2. Selecting a node that handles the keyslot: If read_from_replicas is - set to true, a replica can be selected. + set to true or load_balancing_strategy is set, a replica can be selected. :type redis_cluster: RedisCluster :type node: ClusterNode @@ -1846,7 +1912,9 @@ def execute_command(self, *args): channel = args[1] slot = self.cluster.keyslot(channel) node = self.cluster.nodes_manager.get_node_from_slot( - slot, self.cluster.read_from_replicas + slot, + self.cluster.read_from_replicas, + self.cluster.load_balancing_strategy, ) else: # Get a random node @@ -1989,6 +2057,7 @@ def __init__( cluster_response_callbacks: Optional[Dict[str, Callable]] = None, startup_nodes: Optional[List["ClusterNode"]] = None, read_from_replicas: bool = False, + load_balancing_strategy: Optional[LoadBalancingStrategy] = None, cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 5, lock=None, @@ -2004,6 +2073,7 @@ def __init__( ) self.startup_nodes = startup_nodes if startup_nodes else [] self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.cluster_response_callbacks = cluster_response_callbacks self.cluster_error_retry_attempts = cluster_error_retry_attempts diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index c95babf687..7e7ad23d15 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -14,7 +14,13 @@ from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff -from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name +from redis.cluster import ( + PIPELINE_BLOCKED_COMMANDS, + PRIMARY, + REPLICA, + LoadBalancingStrategy, + get_node_name, +) from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions 
import ( AskError, @@ -182,7 +188,18 @@ def cmd_init_mock(self, r: ClusterNode) -> None: cmd_parser_initialize.side_effect = cmd_init_mock - return await RedisCluster(*args, **kwargs) + # Create a subclass of RedisCluster that overrides __del__ + class MockedRedisCluster(RedisCluster): + def __del__(self): + # Override to prevent connection cleanup attempts + pass + + @property + def connection_pool(self): + # Required abstract property implementation + return self.nodes_manager.get_default_node().redis_connection.connection_pool + + return await MockedRedisCluster(*args, **kwargs) def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: @@ -678,7 +695,24 @@ def cmd_init_mock(self, r: ClusterNode) -> None: assert execute_command.failed_calls == 1 assert execute_command.successful_calls == 1 - async def test_reading_from_replicas_in_round_robin(self) -> None: + @pytest.mark.parametrize( + "read_from_replicas,load_balancing_strategy,mocks_srv_ports", + [ + (True, None, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (True, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (False, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + ], + ) + async def test_reading_with_load_balancing_strategies( + self, + read_from_replicas: bool, + load_balancing_strategy: LoadBalancingStrategy, + mocks_srv_ports: List[int], + ) -> None: with mock.patch.multiple( Connection, send_command=mock.DEFAULT, @@ -694,19 +728,19 @@ async def test_reading_from_replicas_in_round_robin(self) -> None: async def execute_command_mock_first(self, *args, **options): await self.connection_class(**self.connection_kwargs).connect() # Primary - assert self.port == 7001 + assert self.port == mocks_srv_ports[0] execute_command.side_effect = execute_command_mock_second return "MOCK_OK" def execute_command_mock_second(self, *args, **options): # Replica - assert self.port == 7002 + assert self.port == mocks_srv_ports[1] execute_command.side_effect = execute_command_mock_third return "MOCK_OK" def execute_command_mock_third(self, *args, **options): # Primary - assert self.port == 7001 + assert self.port == mocks_srv_ports[2] return "MOCK_OK" # We don't need to create a real cluster connection but we @@ -721,9 +755,13 @@ def execute_command_mock_third(self, *args, **options): # Create a cluster with reading from replications read_cluster = await get_mocked_redis_client( - host=default_host, port=default_port, read_from_replicas=True + host=default_host, + port=default_port, + read_from_replicas=read_from_replicas, + load_balancing_strategy=load_balancing_strategy, ) - assert read_cluster.read_from_replicas is True + assert read_cluster.read_from_replicas is read_from_replicas + assert read_cluster.load_balancing_strategy is load_balancing_strategy # Check that we read from the slot's nodes in a round robin # matter. 
# 'foo' belongs to slot 12182 and the slot's nodes are: @@ -971,6 +1009,34 @@ async def test_get_and_set(self, r: RedisCluster) -> None: assert await r.get("integer") == str(integer).encode() assert (await r.get("unicode_string")).decode("utf-8") == unicode_string + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN, + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + async def test_get_and_set_with_load_balanced_client( + self, create_redis, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + r = await create_redis( + cls=RedisCluster, + load_balancing_strategy=load_balancing_strategy, + ) + + # get and set can't be tested independently of each other + assert await r.get("a") is None + + byte_string = b"value" + assert await r.set("byte_string", byte_string) + + # run the get command for the same key several times + # to iterate over the read nodes + assert await r.get("byte_string") == byte_string + assert await r.get("byte_string") == byte_string + assert await r.get("byte_string") == byte_string + async def test_mget_nonatomic(self, r: RedisCluster) -> None: assert await r.mget_nonatomic([]) == [] assert await r.mget_nonatomic(["a", "b"]) == [None, None] @@ -2371,11 +2437,14 @@ async def test_load_balancer(self, r: RedisCluster) -> None: primary2_name = n_manager.slots_cache[slot_2][0].name list1_size = len(n_manager.slots_cache[slot_1]) list2_size = len(n_manager.slots_cache[slot_2]) + + # default load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN # slot 1 assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary1_name, list1_size) == 1 assert lb.get_server_index(primary1_name, list1_size) == 2 assert lb.get_server_index(primary1_name, list1_size) == 0 + # slot 2 assert lb.get_server_index(primary2_name, list2_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 1 @@ -2385,6 +2454,29 @@ async def test_load_balancer(self, r: RedisCluster) -> None: assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 0 + # reset the indexes before load balancing strategy test + lb.reset() + # load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN_REPLICAS + for i in [1, 2, 1]: + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) + assert srv_index == i + + # reset the indexes before load balancing strategy test + lb.reset() + # load balancer strategy: LoadBalancerStrategy.RANDOM_REPLICA + for i in range(5): + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA, + ) + + assert srv_index > 0 and srv_index <= 2 + async def test_init_slots_cache_not_all_slots_covered(self) -> None: """ Test that if not all slots are covered it should raise an exception @@ -2856,6 +2948,37 @@ async def test_readonly_pipeline_from_readonly_client( break assert executed_on_replica + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + async def test_readonly_pipeline_with_reading_from_replicas_strategies( + self, r: RedisCluster, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + """ + Test that the pipeline uses replicas for different replica-based + load balancing strategies. 
+ """ + # Set the load balancing strategy + r.load_balancing_strategy = load_balancing_strategy + key = "bar" + await r.set(key, "foo") + + async with r.pipeline() as pipe: + mock_all_nodes_resp(r, "MOCK_OK") + assert await pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] + slot_nodes = r.nodes_manager.slots_cache[r.keyslot(key)] + executed_on_replicas_only = True + for node in slot_nodes: + if node.server_type == PRIMARY: + if node._free.pop().read_response.await_count > 0: + executed_on_replicas_only = False + break + assert executed_on_replicas_only + async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None: """Test that the pipeline can be used concurrently.""" await asyncio.gather( diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 908ac26211..118c355b97 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -4,6 +4,7 @@ import socket import socketserver import threading +from typing import List import warnings from queue import LifoQueue, Queue from time import sleep @@ -19,6 +20,7 @@ REDIS_CLUSTER_HASH_SLOTS, REPLICA, ClusterNode, + LoadBalancingStrategy, NodesManager, RedisCluster, get_node_name, @@ -202,7 +204,18 @@ def cmd_init_mock(self, r): cmd_parser_initialize.side_effect = cmd_init_mock - return RedisCluster(*args, **kwargs) + # Create a subclass of RedisCluster that overrides __del__ + class MockedRedisCluster(RedisCluster): + def __del__(self): + # Override to prevent connection cleanup attempts + pass + + @property + def connection_pool(self): + # Required abstract property implementation + return self.nodes_manager.get_default_node().redis_connection.connection_pool + + return MockedRedisCluster(*args, **kwargs) def mock_node_resp(node, response): @@ -590,7 +603,24 @@ def cmd_init_mock(self, r): assert parse_response.failed_calls == 1 assert parse_response.successful_calls == 1 - def test_reading_from_replicas_in_round_robin(self): + @pytest.mark.parametrize( + "read_from_replicas,load_balancing_strategy,mocks_srv_ports", + [ + (True, None, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (True, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (False, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + ], + ) + def test_reading_with_load_balancing_strategies( + self, + read_from_replicas: bool, + load_balancing_strategy: LoadBalancingStrategy, + mocks_srv_ports: List[int], + ): with patch.multiple( Connection, send_command=DEFAULT, @@ -603,19 +633,19 @@ def test_reading_from_replicas_in_round_robin(self): def parse_response_mock_first(connection, *args, **options): # Primary - assert connection.port == 7001 + assert connection.port == mocks_srv_ports[0] parse_response.side_effect = parse_response_mock_second return "MOCK_OK" def parse_response_mock_second(connection, *args, **options): # Replica - assert connection.port == 7002 + assert connection.port == mocks_srv_ports[1] parse_response.side_effect = parse_response_mock_third return "MOCK_OK" def parse_response_mock_third(connection, *args, **options): # Primary - assert connection.port == 7001 + assert connection.port == mocks_srv_ports[2] return "MOCK_OK" # We don't need to create a real cluster connection but we @@ -630,9 +660,13 @@ def parse_response_mock_third(connection, *args, **options): # 
         read_cluster = get_mocked_redis_client(
-            host=default_host, port=default_port, read_from_replicas=True
+            host=default_host,
+            port=default_port,
+            read_from_replicas=read_from_replicas,
+            load_balancing_strategy=load_balancing_strategy,
         )
-        assert read_cluster.read_from_replicas is True
+        assert read_cluster.read_from_replicas is read_from_replicas
+        assert read_cluster.load_balancing_strategy is load_balancing_strategy

         # Check that we read from the slot's nodes in a round robin
         # manner.
         # 'foo' belongs to slot 12182 and the slot's nodes are:
@@ -640,16 +674,27 @@ def parse_response_mock_third(connection, *args, **options):
         read_cluster.get("foo")
         read_cluster.get("foo")
         read_cluster.get("foo")
-        mocks["send_command"].assert_has_calls(
+        expected_calls_list = []
+        expected_calls_list.append(call("READONLY"))
+        expected_calls_list.append(call("GET", "foo", keys=["foo"]))
+
+        if (
+            load_balancing_strategy is None
+            or load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN
+        ):
+            # in the round robin strategy the primary node can also receive
+            # read requests, which means a second node will be connected
+            expected_calls_list.append(call("READONLY"))
+
+        expected_calls_list.extend(
             [
-                call("READONLY"),
-                call("GET", "foo", keys=["foo"]),
-                call("READONLY"),
                 call("GET", "foo", keys=["foo"]),
                 call("GET", "foo", keys=["foo"]),
             ]
         )

+        mocks["send_command"].assert_has_calls(expected_calls_list)
+
     def test_keyslot(self, r):
         """
         Test that method will compute correct key in all supported cases
@@ -975,6 +1020,35 @@ def test_get_and_set(self, r):
         assert r.get("integer") == str(integer).encode()
         assert r.get("unicode_string").decode("utf-8") == unicode_string

+    @pytest.mark.parametrize(
+        "load_balancing_strategy",
+        [
+            LoadBalancingStrategy.ROUND_ROBIN,
+            LoadBalancingStrategy.ROUND_ROBIN_REPLICAS,
+            LoadBalancingStrategy.RANDOM_REPLICA,
+        ],
+    )
+    def test_get_and_set_with_load_balanced_client(
+        self, request, load_balancing_strategy: LoadBalancingStrategy
+    ) -> None:
+        r = _get_client(
+            cls=RedisCluster,
+            request=request,
+            load_balancing_strategy=load_balancing_strategy,
+        )
+
+        # get and set can't be tested independently of each other
+        assert r.get("a") is None
+
+        byte_string = b"value"
+        assert r.set("byte_string", byte_string)
+
+        # run the get command for the same key several times
+        # to iterate over the read nodes
+        assert r.get("byte_string") == byte_string
+        assert r.get("byte_string") == byte_string
+        assert r.get("byte_string") == byte_string
+
     def test_mget_nonatomic(self, r):
         assert r.mget_nonatomic([]) == []
         assert r.mget_nonatomic(["a", "b"]) == [None, None]
@@ -2517,6 +2591,8 @@ def test_load_balancer(self, r):
         primary2_name = n_manager.slots_cache[slot_2][0].name
         list1_size = len(n_manager.slots_cache[slot_1])
         list2_size = len(n_manager.slots_cache[slot_2])
+
+        # default load balancing strategy: LoadBalancingStrategy.ROUND_ROBIN
         # slot 1
         assert lb.get_server_index(primary1_name, list1_size) == 0
         assert lb.get_server_index(primary1_name, list1_size) == 1
@@ -2531,6 +2607,29 @@ def test_load_balancer(self, r):
         assert lb.get_server_index(primary1_name, list1_size) == 0
         assert lb.get_server_index(primary2_name, list2_size) == 0

+        # reset the indexes before the load balancing strategy test
+        lb.reset()
+        # load balancing strategy: LoadBalancingStrategy.ROUND_ROBIN_REPLICAS
+        for i in [1, 2, 1]:
+            srv_index = lb.get_server_index(
+                primary1_name,
+                list1_size,
+                load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN_REPLICAS,
+            )
+            assert srv_index == i
+
+        # reset the indexes before the load balancing strategy test
+        lb.reset()
+        # load balancing strategy: LoadBalancingStrategy.RANDOM_REPLICA
+        for i in range(5):
+            srv_index = lb.get_server_index(
+                primary1_name,
+                list1_size,
+                load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA,
+            )
+
+            assert srv_index > 0 and srv_index <= 2
+
     def test_init_slots_cache_not_all_slots_covered(self):
         """
         Test that if not all slots are covered it should raise an exception
@@ -3377,6 +3476,45 @@ def test_readonly_pipeline_from_readonly_client(self, request):
                     break
         assert executed_on_replica is True

+    @pytest.mark.parametrize(
+        "load_balancing_strategy",
+        [
+            LoadBalancingStrategy.ROUND_ROBIN_REPLICAS,
+            LoadBalancingStrategy.RANDOM_REPLICA,
+        ],
+    )
+    def test_readonly_pipeline_with_reading_from_replicas_strategies(
+        self, request, load_balancing_strategy: LoadBalancingStrategy
+    ) -> None:
+        """
+        Test that the pipeline uses replicas for different replica-based
+        load balancing strategies.
+        """
+        ro = _get_client(
+            RedisCluster,
+            request,
+            load_balancing_strategy=load_balancing_strategy,
+        )
+        key = "bar"
+        ro.set(key, "foo")
+        import time
+
+        time.sleep(0.2)
+
+        with ro.pipeline() as readonly_pipe:
+            mock_all_nodes_resp(ro, "MOCK_OK")
+            assert readonly_pipe.load_balancing_strategy == load_balancing_strategy
+            assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"]
+            slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)]
+            executed_on_replicas_only = True
+            for node in slot_nodes:
+                if node.server_type == PRIMARY:
+                    conn = node.redis_connection.connection
+                    if conn.read_response.called:
+                        executed_on_replicas_only = False
+                        break
+            assert executed_on_replicas_only
+

 @pytest.mark.onlycluster
 class TestClusterMonitor:
diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py
index 0e8e8958c5..b4d2630b2b 100644
--- a/tests/test_multiprocessing.py
+++ b/tests/test_multiprocessing.py
@@ -22,6 +22,16 @@ def exit_callback(callback, *args):


 class TestMultiprocessing:
+    # On macOS and newer non-macOS POSIX systems (since Python 3.14),
+    # the default method has been changed to forkserver.
+    # The code in this module does not work with it,
+    # hence the explicit change to 'fork'
+    # See https://github.com/python/cpython/issues/125714
+    if multiprocessing.get_start_method() in ["forkserver", "spawn"]:
+        _mp_context = multiprocessing.get_context(method="fork")
+    else:
+        _mp_context = multiprocessing.get_context()
+
     # Test connection sharing between forks.
     # See issue #1085 for details.

@@ -84,6 +94,40 @@ def target(conn, ev):
         proc.join(3)
         assert proc.exitcode == 0

+    @pytest.mark.parametrize("max_connections", [2, None])
+    def test_release_parent_connection_from_pool_in_child_process(
+        self, max_connections, master_host
+    ):
+        """
+        A connection owned by a parent should not decrease the _created_connections
+        counter in the child when released - when the child process starts to use
+        the pool it resets all the counters that have been set in the parent process.
+ """ + + pool = ConnectionPool.from_url( + f"redis://{master_host[0]}:{master_host[1]}", + max_connections=max_connections, + ) + + parent_conn = pool.get_connection() + + def target(pool, parent_conn): + with exit_callback(pool.disconnect): + child_conn = pool.get_connection() + assert child_conn.pid != parent_conn.pid + pool.release(child_conn) + assert pool._created_connections == 1 + assert child_conn in pool._available_connections + pool.release(parent_conn) + assert pool._created_connections == 1 + assert child_conn in pool._available_connections + assert parent_conn not in pool._available_connections + + proc = self._mp_context.Process(target=target, args=(pool, parent_conn)) + proc.start() + proc.join(3) + assert proc.exitcode == 0 + @pytest.mark.parametrize("max_connections", [1, 2, None]) def test_pool(self, max_connections, master_host): """ From 53dba14ac9f68db0b5ce34076bb1f8b9e2c6fcbe Mon Sep 17 00:00:00 2001 From: Jim Cameron-Burn Date: Mon, 24 Mar 2025 05:31:04 +0000 Subject: [PATCH 11/17] Exponential with jitter backoff (#3550) --- redis/backoff.py | 15 +++++++++++++++ tests/test_backoff.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 tests/test_backoff.py diff --git a/redis/backoff.py b/redis/backoff.py index f612d60704..e236764d71 100644 --- a/redis/backoff.py +++ b/redis/backoff.py @@ -110,5 +110,20 @@ def compute(self, failures: int) -> float: return self._previous_backoff +class ExponentialWithJitterBackoff(AbstractBackoff): + """Exponential backoff upon failure, with jitter""" + + def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None: + """ + `cap`: maximum backoff time in seconds + `base`: base backoff time in seconds + """ + self._cap = cap + self._base = base + + def compute(self, failures: int) -> float: + return min(self._cap, random.random() * self._base * 2**failures) + + def default_backoff(): return EqualJitterBackoff() diff --git a/tests/test_backoff.py b/tests/test_backoff.py new file mode 100644 index 0000000000..0a491276ff --- /dev/null +++ b/tests/test_backoff.py @@ -0,0 +1,18 @@ +from unittest.mock import Mock + +import pytest + +from redis.backoff import ExponentialWithJitterBackoff + + +def test_exponential_with_jitter_backoff(monkeypatch: pytest.MonkeyPatch) -> None: + mock_random = Mock(side_effect=[0.25, 0.5, 0.75, 1.0, 0.9]) + monkeypatch.setattr("random.random", mock_random) + + bo = ExponentialWithJitterBackoff(cap=5, base=1) + + assert bo.compute(0) == 0.25 # min(5, 0.25*2^0) + assert bo.compute(1) == 1.0 # min(5, 0.5*2^1) + assert bo.compute(2) == 3.0 # min(5, 0.75*2^2) + assert bo.compute(3) == 5.0 # min(5, 1*2^3) + assert bo.compute(4) == 5.0 # min(5, 0.9*2^4) From 35ca1025f2961eeb498335f8066d6596c87421c3 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 30 Apr 2025 13:02:33 +0300 Subject: [PATCH 12/17] Fixing the versions of some deprecations that wrongly added as 5.0.3 - the correct version is 5.3.0 (#3625) --- redis/asyncio/cluster.py | 2 +- redis/asyncio/connection.py | 4 ++-- redis/cluster.py | 6 +++--- redis/connection.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index aa472ed5b8..1a22636967 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -239,7 +239,7 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": @deprecated_args( args_to_warn=["read_from_replicas"], reason="Please configure the 'load_balancing_strategy' instead", - version="5.0.3", + 
version="5.3.0", ) def __init__( self, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index e67dc5b207..8c0dbdd32a 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1090,7 +1090,7 @@ def can_get_connection(self) -> bool: @deprecated_args( args_to_warn=["*"], reason="Use get_connection() without args instead", - version="5.0.3", + version="5.3.0", ) async def get_connection(self, command_name=None, *keys, **options): async with self._lock: @@ -1263,7 +1263,7 @@ def __init__( @deprecated_args( args_to_warn=["*"], reason="Use get_connection() without args instead", - version="5.0.3", + version="5.3.0", ) async def get_connection(self, command_name=None, *keys, **options): """Gets a connection from the pool, blocking until one is available""" diff --git a/redis/cluster.py b/redis/cluster.py index 7e7f590b82..37e810644b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -59,7 +59,7 @@ def get_node_name(host: str, port: Union[str, int]) -> str: @deprecated_args( allowed_args=["redis_node"], reason="Use get_connection(redis_node) instead", - version="5.0.3", + version="5.3.0", ) def get_connection(redis_node, *args, **options): return redis_node.connection or redis_node.connection_pool.get_connection() @@ -509,7 +509,7 @@ class initializer. In the case of conflicting arguments, querystring @deprecated_args( args_to_warn=["read_from_replicas"], reason="Please configure the 'load_balancing_strategy' instead", - version="5.0.3", + version="5.3.0", ) def __init__( self, @@ -1503,7 +1503,7 @@ def _update_moved_slots(self): "In case you need select some load balancing strategy " "that will use replicas, please set it through 'load_balancing_strategy'" ), - version="5.0.3", + version="5.3.0", ) def get_node_from_slot( self, diff --git a/redis/connection.py b/redis/connection.py index 3189690802..ec0314ad15 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1465,7 +1465,7 @@ def _checkpid(self) -> None: @deprecated_args( args_to_warn=["*"], reason="Use get_connection() without args instead", - version="5.0.3", + version="5.3.0", ) def get_connection(self, command_name=None, *keys, **options) -> "Connection": "Get a connection from the pool" @@ -1693,7 +1693,7 @@ def make_connection(self): @deprecated_args( args_to_warn=["*"], reason="Use get_connection() without args instead", - version="5.0.3", + version="5.3.0", ) def get_connection(self, command_name=None, *keys, **options): """ From b928f971b4936eb6dfca5a32c3085a419a3a50a6 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 24 Feb 2025 17:33:56 +0200 Subject: [PATCH 13/17] Remove decreasing of created connections count when releasing not owned by connection pool connection(fixes issue #2832). (#3514) * Removing decreasing of created connections count when releasing not owned by connection pool connection(#2832). * Fixed another issue that was allowing adding connections to a pool owned by other pools. Adding unit tests. 
* Fixing a typo in a comment
---
 redis/connection.py           | 10 +++++-----
 tests/test_connection_pool.py | 15 +++++++++++++++
 tests/test_multiprocessing.py | 12 +-----------
 3 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/redis/connection.py b/redis/connection.py
index ec0314ad15..ec377c5f44 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -1532,7 +1532,7 @@ def release(self, connection: "Connection") -> None:
         except KeyError:
             # Gracefully fail when a connection is returned to this pool
             # that the pool doesn't actually own
-            pass
+            return

         if self.owns_connection(connection):
             self._available_connections.append(connection)
@@ -1540,10 +1540,10 @@ def release(self, connection: "Connection") -> None:
                 AfterConnectionReleasedEvent(connection)
             )
         else:
-            # pool doesn't own this connection. do not add it back
-            # to the pool and decrement the count so that another
-            # connection can take its place if needed
-            self._created_connections -= 1
+            # Pool doesn't own this connection, do not add it back
+            # to the pool.
+            # The created connections count should not be changed,
+            # because the connection was not created by the pool.
             connection.disconnect()
             return

diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index 387a0f4565..65f42923fe 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -91,6 +91,21 @@ def test_reuse_previously_released_connection(self, master_host):
         c2 = pool.get_connection()
         assert c1 == c2

+    def test_release_not_owned_connection(self, master_host):
+        connection_kwargs = {"host": master_host[0], "port": master_host[1]}
+        pool1 = self.get_pool(connection_kwargs=connection_kwargs)
+        c1 = pool1.get_connection("_")
+        pool2 = self.get_pool(
+            connection_kwargs={"host": master_host[0], "port": master_host[1]}
+        )
+        c2 = pool2.get_connection("_")
+        pool2.release(c2)
+
+        assert len(pool2._available_connections) == 1
+
+        pool2.release(c1)
+        assert len(pool2._available_connections) == 1
+
     def test_repr_contains_db_info_tcp(self):
         connection_kwargs = {
             "host": "localhost",
diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py
index b4d2630b2b..f42f8c7919 100644
--- a/tests/test_multiprocessing.py
+++ b/tests/test_multiprocessing.py
@@ -22,16 +22,6 @@ def exit_callback(callback, *args):


 class TestMultiprocessing:
-    # On macOS and newer non-macOS POSIX systems (since Python 3.14),
-    # the default method has been changed to forkserver.
-    # The code in this module does not work with it,
-    # hence the explicit change to 'fork'
-    # See https://github.com/python/cpython/issues/125714
-    if multiprocessing.get_start_method() in ["forkserver", "spawn"]:
-        _mp_context = multiprocessing.get_context(method="fork")
-    else:
-        _mp_context = multiprocessing.get_context()
-
     # Test connection sharing between forks.
     # See issue #1085 for details.
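
To make the ownership rule in the ``release()`` change above concrete, a
hypothetical session mirroring ``test_release_not_owned_connection`` (endpoint
is a placeholder, not part of the patch):

.. code:: python

    from redis import ConnectionPool

    pool1 = ConnectionPool(host="localhost", port=6379)
    pool2 = ConnectionPool(host="localhost", port=6379)

    c1 = pool1.get_connection()
    c2 = pool2.get_connection()

    pool2.release(c2)  # owned by pool2: returned to its free list
    pool2.release(c1)  # never checked out from pool2: release() returns early

    assert len(pool2._available_connections) == 1
    assert pool2._created_connections == 1  # counter is no longer decremented
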
@@ -123,7 +113,7 @@ def target(pool, parent_conn):
             assert child_conn in pool._available_connections
             assert parent_conn not in pool._available_connections

-        proc = self._mp_context.Process(target=target, args=(pool, parent_conn))
+        proc = multiprocessing.Process(target=target, args=(pool, parent_conn))
         proc.start()
         proc.join(3)
         assert proc.exitcode == 0

From a33663958d42e698069962e260260c80fd093162 Mon Sep 17 00:00:00 2001
From: Petya Slavova
Date: Wed, 30 Apr 2025 14:42:52 +0300
Subject: [PATCH 14/17] Fixing linter errors - due to backporting from master
 where formatter is updated

---
 redis/asyncio/cluster.py           | 17 +++++++++++------
 redis/cluster.py                   | 13 ++++++++-----
 tests/test_asyncio/test_cluster.py | 14 +++++++-------
 tests/test_backoff.py              |  1 -
 tests/test_cluster.py              | 15 ++++++++-------
 5 files changed, 34 insertions(+), 26 deletions(-)

diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 1a22636967..b32e6ff23b 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -139,11 +139,13 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
     | Enable read from replicas in READONLY mode. When set to true, read commands
       will be assigned between the primary and its replications in a Round-Robin
       manner.
-      The data read from replicas is eventually consistent with the data in primary nodes.
+      The data read from replicas is eventually consistent
+      with the data in primary nodes.
     :param load_balancing_strategy:
     | Enable read from replicas in READONLY mode and defines the load balancing
       strategy that will be used for cluster node selection.
-      The data read from replicas is eventually consistent with the data in primary nodes.
+      The data read from replicas is eventually consistent
+      with the data in primary nodes.
     :param reinitialize_steps:
     | Specifies the number of MOVED errors that need to occur before reinitializing
       the whole cluster topology. If a MOVED error occurs and the cluster does not
@@ -800,9 +802,11 @@ async def _execute_command(
                     target_node = self.nodes_manager.get_node_from_slot(
                         slot,
                         self.read_from_replicas and args[0] in READ_COMMANDS,
-                        self.load_balancing_strategy
-                        if args[0] in READ_COMMANDS
-                        else None,
+                        (
+                            self.load_balancing_strategy
+                            if args[0] in READ_COMMANDS
+                            else None
+                        ),
                     )

                 moved = False
@@ -1276,7 +1280,8 @@ def get_node_from_slot(
         try:
             if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
-                # get the server index using the strategy defined in load_balancing_strategy
+                # get the server index using the strategy defined
+                # in load_balancing_strategy
                 primary_name = self.slots_cache[slot][0].name
                 node_idx = self.read_load_balancer.get_server_index(
                     primary_name, len(self.slots_cache[slot]), load_balancing_strategy
diff --git a/redis/cluster.py b/redis/cluster.py
index 37e810644b..7549295ea6 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -558,7 +558,8 @@ def __init__(
         :param load_balancing_strategy:
             Enable read from replicas in READONLY mode and defines the load
             balancing strategy that will be used for cluster node selection.
-            The data read from replicas is eventually consistent with the data in primary nodes.
+            The data read from replicas is eventually consistent
+            with the data in primary nodes.
         :param dynamic_startup_nodes:
             Set the RedisCluster's startup nodes to all of the discovered nodes.
             If true (default value), the cluster's discovered nodes will be used to
@@ -1190,9 +1191,11 @@ def _execute_command(self, target_node, *args, **kwargs):
                     target_node = self.nodes_manager.get_node_from_slot(
                         slot,
                         self.read_from_replicas and command in READ_COMMANDS,
-                        self.load_balancing_strategy
-                        if command in READ_COMMANDS
-                        else None,
+                        (
+                            self.load_balancing_strategy
+                            if command in READ_COMMANDS
+                            else None
+                        ),
                     )

                 moved = False
@@ -1366,7 +1369,7 @@ def get_server_index(
         self,
         primary: str,
         list_size: int,
-        load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN,
+        load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN,  # noqa: line too long ignored
     ) -> int:
         if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA:
             return self._get_random_replica_index(list_size)
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 7e7ad23d15..a087134c95 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -152,7 +152,6 @@ async def get_mocked_redis_client(
     with mock.patch.object(ClusterNode, "execute_command") as execute_command_mock:

         async def execute_command(*_args, **_kwargs):
-
             if _args[0] == "CLUSTER SLOTS":
                 if cluster_slots_raise_error:
                     raise ResponseError()
@@ -197,7 +196,8 @@ def __del__(self):
                 @property
                 def connection_pool(self):
                     # Required abstract property implementation
-                    return self.nodes_manager.get_default_node().redis_connection.connection_pool
+                    default_node = self.nodes_manager.get_default_node()
+                    return default_node.redis_connection.connection_pool

             return await MockedRedisCluster(*args, **kwargs)

@@ -1643,7 +1643,7 @@ async def test_cluster_bitop_not_empty_string(self, r: RedisCluster) -> None:

     @skip_if_server_version_lt("2.6.0")
     async def test_cluster_bitop_not(self, r: RedisCluster) -> None:
-        test_str = b"\xAA\x00\xFF\x55"
+        test_str = b"\xaa\x00\xff\x55"
         correct = ~0xAA00FF55 & 0xFFFFFFFF
         await r.set("{foo}a", test_str)
         await r.bitop("not", "{foo}r", "{foo}a")
@@ -1651,7 +1651,7 @@ async def test_cluster_bitop_not(self, r: RedisCluster) -> None:

     @skip_if_server_version_lt("2.6.0")
     async def test_cluster_bitop_not_in_place(self, r: RedisCluster) -> None:
-        test_str = b"\xAA\x00\xFF\x55"
+        test_str = b"\xaa\x00\xff\x55"
         correct = ~0xAA00FF55 & 0xFFFFFFFF
         await r.set("{foo}a", test_str)
         await r.bitop("not", "{foo}a", "{foo}a")
@@ -1659,7 +1659,7 @@ async def test_cluster_bitop_not_in_place(self, r: RedisCluster) -> None:

     @skip_if_server_version_lt("2.6.0")
     async def test_cluster_bitop_single_string(self, r: RedisCluster) -> None:
-        test_str = b"\x01\x02\xFF"
+        test_str = b"\x01\x02\xff"
         await r.set("{foo}a", test_str)
         await r.bitop("and", "{foo}res1", "{foo}a")
         await r.bitop("or", "{foo}res2", "{foo}a")
@@ -1670,8 +1670,8 @@ async def test_cluster_bitop_single_string(self, r: RedisCluster) -> None:

     @skip_if_server_version_lt("2.6.0")
     async def test_cluster_bitop_string_operands(self, r: RedisCluster) -> None:
-        await r.set("{foo}a", b"\x01\x02\xFF\xFF")
-        await r.set("{foo}b", b"\x01\x02\xFF")
+        await r.set("{foo}a", b"\x01\x02\xff\xff")
+        await r.set("{foo}b", b"\x01\x02\xff")
         await r.bitop("and", "{foo}res1", "{foo}a", "{foo}b")
         await r.bitop("or", "{foo}res2", "{foo}a", "{foo}b")
         await r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b")
diff --git a/tests/test_backoff.py b/tests/test_backoff.py
index 0a491276ff..234796dde0 100644
--- a/tests/test_backoff.py
+++ b/tests/test_backoff.py
@@ -1,7 +1,6 @@
 from unittest.mock import Mock

 import pytest
-
 from redis.backoff import ExponentialWithJitterBackoff


diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index 118c355b97..3757d43ae0 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -4,10 +4,10 @@
 import socket
 import socketserver
 import threading
-from typing import List
 import warnings
 from queue import LifoQueue, Queue
 from time import sleep
+from typing import List
 from unittest.mock import DEFAULT, Mock, call, patch

 import pytest
@@ -213,7 +214,8 @@ def __del__(self):
                 @property
                 def connection_pool(self):
                     # Required abstract property implementation
-                    return self.nodes_manager.get_default_node().redis_connection.connection_pool
+                    default_node = self.nodes_manager.get_default_node()
+                    return default_node.redis_connection.connection_pool

             return MockedRedisCluster(*args, **kwargs)

@@ -1766,7 +1767,7 @@ def test_cluster_bitop_not_empty_string(self, r):

     @skip_if_server_version_lt("2.6.0")
     def test_cluster_bitop_not(self, r):
-        test_str = b"\xAA\x00\xFF\x55"
+        test_str = b"\xaa\x00\xff\x55"
         correct = ~0xAA00FF55 & 0xFFFFFFFF
         r["{foo}a"] = test_str
         r.bitop("not", "{foo}r", "{foo}a")
@@ -1774,7 +1775,7 @@ def test_cluster_bitop_not(self, r):

     @skip_if_server_version_lt("2.6.0")
     def test_cluster_bitop_not_in_place(self, r):
-        test_str = b"\xAA\x00\xFF\x55"
+        test_str = b"\xaa\x00\xff\x55"
         correct = ~0xAA00FF55 & 0xFFFFFFFF
         r["{foo}a"] = test_str
         r.bitop("not", "{foo}a", "{foo}a")
@@ -1782,7 +1783,7 @@ def test_cluster_bitop_not_in_place(self, r):

     @skip_if_server_version_lt("2.6.0")
     def test_cluster_bitop_single_string(self, r):
-        test_str = b"\x01\x02\xFF"
+        test_str = b"\x01\x02\xff"
         r["{foo}a"] = test_str
         r.bitop("and", "{foo}res1", "{foo}a")
         r.bitop("or", "{foo}res2", "{foo}a")
@@ -1793,8 +1794,8 @@ def test_cluster_bitop_single_string(self, r):

     @skip_if_server_version_lt("2.6.0")
     def test_cluster_bitop_string_operands(self, r):
-        r["{foo}a"] = b"\x01\x02\xFF\xFF"
-        r["{foo}b"] = b"\x01\x02\xFF"
+        r["{foo}a"] = b"\x01\x02\xff\xff"
+        r["{foo}b"] = b"\x01\x02\xff"
         r.bitop("and", "{foo}res1", "{foo}a", "{foo}b")
         r.bitop("or", "{foo}res2", "{foo}a", "{foo}b")
         r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b")

From 1c0a6f21250f44326dff217b960169b9d1fd0e71 Mon Sep 17 00:00:00 2001
From: Petya Slavova
Date: Wed, 30 Apr 2025 15:08:02 +0300
Subject: [PATCH 15/17] Fix flake8 version to the last one known to work with
 our code. Since flake8 is no longer used in other versions - we don't need to
 use the latest

---
 dev_requirements.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dev_requirements.txt b/dev_requirements.txt
index 728536d6fb..db7fba2c98 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,7 +1,7 @@
 black==24.3.0
 click==8.0.4
-flake8-isort
-flake8
+flake8-isort==6.1.2
+flake8==7.1.1
 flynt~=0.69.0
 invoke==2.2.0
 mock

From 46740c88e1f95a33d900844737bee4e7df5432dd Mon Sep 17 00:00:00 2001
From: petyaslavova
Date: Thu, 13 Feb 2025 14:45:40 +0200
Subject: [PATCH 16/17] Replacing the redis and redis-stack-server images with
 redis-libs-tests image in test infrastructure (#3505)

* Replacing the redis image with redis-libs-tests image in test infrastructure

* Replacing redis-stack-server image usage with client-libs-test.
  Fixing lib version in setup.py

* Defining stack tag variable for the build and test github action

* Removing unused env var from build and test github actions
---
 .github/actions/run-tests/action.yml | 49 ++++++++++---------
 .github/workflows/integration.yaml   |  8 ++--
 .gitignore                           |  3 ++
 docker-compose.yml                   | 70 ++++++++++++----------------
 4 files changed, 61 insertions(+), 69 deletions(-)

diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml
index e5dcef03ff..7765a15648 100644
--- a/.github/actions/run-tests/action.yml
+++ b/.github/actions/run-tests/action.yml
@@ -31,57 +31,56 @@ runs:
     - name: Setup Test environment
       env:
         REDIS_VERSION: ${{ inputs.redis-version }}
-        REDIS_IMAGE: "redis:${{ inputs.redis-version }}"
-        CLIENT_LIBS_TEST_IMAGE: "redislabs/client-libs-test:${{ inputs.redis-version }}"
+        CLIENT_LIBS_TEST_IMAGE_TAG: ${{ inputs.redis-version }}
       run: |
         set -e
-        
+
         echo "::group::Installing dependencies"
         pip install -U setuptools wheel
         pip install -r requirements.txt
         pip install -r dev_requirements.txt
         if [ "${{inputs.parser-backend}}" == "hiredis" ]; then
           pip install "hiredis${{inputs.hiredis-version}}"
-          echo "PARSER_BACKEND=$(echo "${{inputs.parser-backend}}_${{inputs.hiredis-version}}" | sed 's/[^a-zA-Z0-9]/_/g')" >> $GITHUB_ENV 
+          echo "PARSER_BACKEND=$(echo "${{inputs.parser-backend}}_${{inputs.hiredis-version}}" | sed 's/[^a-zA-Z0-9]/_/g')" >> $GITHUB_ENV
         else
           echo "PARSER_BACKEND=${{inputs.parser-backend}}" >> $GITHUB_ENV
         fi
         echo "::endgroup::"
-        
+
        echo "::group::Starting Redis servers"
        redis_major_version=$(echo "$REDIS_VERSION" | grep -oP '^\d+')
-        
+
        if (( redis_major_version < 8 )); then
          echo "Using redis-stack for module tests"
-          
-          # Mapping of redis version to stack version 
+
+          # Mapping of redis version to stack version
          declare -A redis_stack_version_mapping=(
-            ["7.4.2"]="7.4.0-v3"
-            ["7.2.7"]="7.2.0-v15"
-            ["6.2.17"]="6.2.6-v19"
+            ["7.4.2"]="rs-7.4.0-v2"
+            ["7.2.7"]="rs-7.2.0-v14"
+            ["6.2.17"]="rs-6.2.6-v18"
          )
-          
+
          if [[ -v redis_stack_version_mapping[$REDIS_VERSION] ]]; then
-            export REDIS_STACK_IMAGE="redis/redis-stack-server:${redis_stack_version_mapping[$REDIS_VERSION]}"
+            export CLIENT_LIBS_TEST_STACK_IMAGE_TAG=${redis_stack_version_mapping[$REDIS_VERSION]}
            echo "REDIS_MOD_URL=redis://127.0.0.1:6479/0" >> $GITHUB_ENV
          else
            echo "Version not found in the mapping."
            exit 1
          fi
-          
+
          if (( redis_major_version < 7 )); then
            export REDIS_STACK_EXTRA_ARGS="--tls-auth-clients optional --save ''"
            export REDIS_EXTRA_ARGS="--tls-auth-clients optional --save ''"
            echo "REDIS_MAJOR_VERSION=${redis_major_version}" >> $GITHUB_ENV
          fi
-          
+
          invoke devenv --endpoints=all-stack
        else
          echo "Using redis CE for module tests"
          echo "REDIS_MOD_URL=redis://127.0.0.1:6379" >> $GITHUB_ENV
          invoke devenv --endpoints all
-        fi
-        
+        fi
+
         sleep 10 # time to settle
         echo "::endgroup::"
       shell: bash

     - name: Run tests
       run: |
         set -e
-        
+
         run_tests() {
           local protocol=$1
           local eventloop=""
-          
+
           if [ "${{inputs.event-loop}}" == "uvloop" ]; then
             eventloop="--uvloop"
           fi
-          
+
           echo "::group::RESP${protocol} standalone tests"
           echo "REDIS_MOD_URL=${REDIS_MOD_URL}"
-          
+
           if (( $REDIS_MAJOR_VERSION < 7 )) && [ "$protocol" == "3" ]; then
             echo "Skipping module tests: Modules doesn't support RESP3 for Redis versions < 7"
             invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod and not cp_integration"
-          else 
+          else
             invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}"
           fi
-          
+
           echo "::endgroup::"
-          
+
           if [ "$protocol" == "2" ] || [ "${{inputs.parser-backend}}" != 'hiredis' ]; then
             echo "::group::RESP${protocol} cluster tests"
             invoke cluster-tests $eventloop --protocol=${protocol}
             echo "::endgroup::"
           fi
         }
-        
+
         run_tests 2 "${{inputs.event-loop}}"
         run_tests 3 "${{inputs.event-loop}}"
       shell: bash
diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml
index 7e92cfb92d..22b31e180e 100644
--- a/.github/workflows/integration.yaml
+++ b/.github/workflows/integration.yaml
@@ -27,8 +27,7 @@ env:
   CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
   # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665
   COVERAGE_CORE: sysmon
-  REDIS_IMAGE: redis:latest
-  REDIS_STACK_IMAGE: redis/redis-stack-server:latest
+  CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG: 'rs-7.4.0-v2'
   CURRENT_REDIS_VERSION: '7.4.2'

 jobs:
@@ -180,9 +179,8 @@ jobs:
           python-version: 3.9
       - name: Run installed unit tests
         env:
-          REDIS_VERSION: ${{ env.CURRENT_REDIS_VERSION }}
-          REDIS_IMAGE: "redis:${{ env.CURRENT_REDIS_VERSION }}"
-          CLIENT_LIBS_TEST_IMAGE: "redislabs/client-libs-test:${{ env.CURRENT_REDIS_VERSION }}"
+          CLIENT_LIBS_TEST_IMAGE_TAG: ${{ env.CURRENT_REDIS_VERSION }}
+          CLIENT_LIBS_TEST_STACK_IMAGE_TAG: ${{ env.CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG }}
         run: |
           bash .github/workflows/install_and_test.sh ${{ matrix.extension }}
diff --git a/.gitignore b/.gitignore
index ee1bda0fa5..5f77dcfde4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,6 @@ docker/stunnel/keys
 /dockers/*/tls/*
 /dockers/standalone/
 /dockers/cluster/
+/dockers/replica/
+/dockers/sentinel/
+/dockers/redis-stack/
diff --git a/docker-compose.yml b/docker-compose.yml
index 60657d5653..8ca3471311 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,9 +1,14 @@
 ---
+x-client-libs-stack-image: &client-libs-stack-image
+  image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-rs-7.4.0-v2}"
+
+x-client-libs-image: &client-libs-image
+  image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-7.4.2}"
+
 services:

   redis:
-    image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:7.4.1}
+    <<: *client-libs-image
     container_name: redis-standalone
     environment:
       - TLS_ENABLED=yes
@@ -24,20 +29,26 @@ services:
       - all

   replica:
-    image: ${REDIS_IMAGE:-redis:7.4.1}
+    <<: *client-libs-image
     container_name: redis-replica
     depends_on:
       - redis
-    command: redis-server --replicaof redis 6379 --protected-mode no --save ""
+    environment:
+      - TLS_ENABLED=no
+      - REDIS_CLUSTER=no
+      - PORT=6380
+    command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --replicaof redis 6379 --protected-mode no --save ""}
     ports:
-      - 6380:6379
+      - 6380:6380
+    volumes:
+      - "./dockers/replica:/redis/work"
     profiles:
       - replica
       - all-stack
       - all

   cluster:
-    image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:7.4.1}
+    <<: *client-libs-image
     container_name: redis-cluster
     environment:
       - REDIS_CLUSTER=yes
@@ -58,57 +69,38 @@
       - all

   sentinel:
-    image: ${REDIS_IMAGE:-redis:7.4.1}
+    <<: *client-libs-image
     container_name: redis-sentinel
     depends_on:
       - redis
-    entrypoint: "redis-sentinel /redis.conf --port 26379"
+    environment:
+      - REDIS_CLUSTER=no
+      - NODES=3
+      - PORT=26379
+    command: ${REDIS_EXTRA_ARGS:---sentinel}
     ports:
       - 26379:26379
-    volumes:
-      - "./dockers/sentinel.conf:/redis.conf"
-    profiles:
-      - sentinel
-      - all-stack
-      - all
-
-  sentinel2:
-    image: ${REDIS_IMAGE:-redis:7.4.1}
-    container_name: redis-sentinel2
-    depends_on:
-      - redis
-    entrypoint: "redis-sentinel /redis.conf --port 26380"
-    ports:
-      - 26380:26380
-    volumes:
-      - "./dockers/sentinel.conf:/redis.conf"
-    profiles:
-      - sentinel
-      - all-stack
-      - all
-
-  sentinel3:
-    image: ${REDIS_IMAGE:-redis:7.4.1}
-    container_name: redis-sentinel3
-    depends_on:
-      - redis
-    entrypoint: "redis-sentinel /redis.conf --port 26381"
-    ports:
-      - 26381:26381
     volumes:
-      - "./dockers/sentinel.conf:/redis.conf"
+      - "./dockers/sentinel.conf:/redis/config-default/redis.conf"
+      - "./dockers/sentinel:/redis/work"
     profiles:
       - sentinel
       - all-stack
       - all

   redis-stack:
-    image: ${REDIS_STACK_IMAGE:-redis/redis-stack-server:latest}
+    <<: *client-libs-stack-image
     container_name: redis-stack
+    environment:
+      - REDIS_CLUSTER=no
+      - PORT=6379
+    command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --save ""}
     ports:
       - 6479:6379
-    environment:
-      - "REDIS_ARGS=${REDIS_STACK_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --save ''}"
+    volumes:
+      - "./dockers/redis-stack:/redis/work"
     profiles:
       - standalone
       - all-stack

From 653d9ef00fc687e72c0bb065d2c61c664b13b63b Mon Sep 17 00:00:00 2001
From: Petya Slavova
Date: Wed, 30 Apr 2025 16:19:16 +0300
Subject: [PATCH 17/17] flake8-isort version is set to 6.1.1 - to be
 compatible with python 3.8

---
 dev_requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev_requirements.txt b/dev_requirements.txt
index db7fba2c98..f9b6efdb11 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,6 +1,6 @@
 black==24.3.0
 click==8.0.4
-flake8-isort==6.1.2
+flake8-isort==6.1.1
 flake8==7.1.1
 flynt~=0.69.0
 invoke==2.2.0
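
As a closing sketch — not part of any patch above, endpoint hypothetical — the
readonly-pipeline behaviour asserted by the cluster pipeline tests earlier in
this series, seen from the user's side:

.. code:: python

    from redis.cluster import LoadBalancingStrategy, RedisCluster

    rc = RedisCluster(
        host="localhost",
        port=16379,
        load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA,
    )
    rc.set("bar", "foo")

    # Pipelines inherit the client's load_balancing_strategy, so these
    # reads are served by replicas of the slot that owns "bar".
    with rc.pipeline() as pipe:
        print(pipe.get("bar").get("bar").execute())
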