Skip to content

Commit

Permalink
narrow: Add backend support for "channels" operator.
Browse files Browse the repository at this point in the history
Adds backend support for "channels" operator.

This will deprecate/replace the "streams" operator eventually, but
we will keep support of the operator for backwards compatibility
for a while.

Part of renaming stream to channel project.
  • Loading branch information
laurynmm authored and timabbott committed Apr 12, 2024
1 parent 0e972e2 commit 608b305
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 41 deletions.
23 changes: 14 additions & 9 deletions zerver/lib/narrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def read_stop_words() -> List[str]:

# "stream" is a legacy alias for "channel"
channel_operators: List[str] = ["channel", "stream"]
# "streams" is a legacy alias for "channels"
channels_operators: List[str] = ["channels", "streams"]


def check_narrow_for_events(narrow: Collection[NarrowTerm]) -> None:
Expand All @@ -113,7 +115,7 @@ def is_spectator_compatible(narrow: Iterable[Dict[str, Any]]) -> bool:
# This implementation should agree with is_spectator_compatible in hash_parser.ts.
supported_operators = [
*channel_operators,
"streams",
*channels_operators,
"topic",
"sender",
"has",
Expand All @@ -136,8 +138,9 @@ def is_web_public_narrow(narrow: Optional[Iterable[Dict[str, Any]]]) -> bool:

return any(
# Web-public queries are only allowed for limited types of narrows.
# term == {'operator': 'streams', 'operand': 'web-public', 'negated': False}
term["operator"] == "streams"
# term == {'operator': 'channels', 'operand': 'web-public', 'negated': False}
# or term == {'operator': 'streams', 'operand': 'web-public', 'negated': False}
term["operator"] in channels_operators
and term["operand"] == "web-public"
and term["negated"] is False
for term in narrow
Expand Down Expand Up @@ -286,7 +289,9 @@ def __init__(
"channel": self.by_channel,
# "stream" is a legacy alias for "channel"
"stream": self.by_channel,
"streams": self.by_streams,
"channels": self.by_channels,
# "streams" is a legacy alias for "channels"
"streams": self.by_channels,
"topic": self.by_topic,
"sender": self.by_sender,
"near": self.by_near,
Expand Down Expand Up @@ -488,17 +493,17 @@ def by_channel(
cond = column("recipient_id", Integer) == recipient_id
return query.where(maybe_negate(cond))

def by_streams(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
def by_channels(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
self.check_not_both_channel_and_dm_narrow(is_channel_narrow=True)

if operand == "public":
# Get all both subscribed and non-subscribed public streams
# but exclude any private subscribed streams.
# Get all both subscribed and non-subscribed public channels
# but exclude any private subscribed channels.
recipient_queryset = get_public_streams_queryset(self.realm)
elif operand == "web-public":
recipient_queryset = get_web_public_streams_queryset(self.realm)
else:
raise BadNarrowOperatorError("unknown streams operand " + operand)
raise BadNarrowOperatorError("unknown channels operand " + operand)

recipient_ids = recipient_queryset.values_list("recipient_id", flat=True).order_by("id")
cond = column("recipient_id", Integer).in_(recipient_ids)
Expand Down Expand Up @@ -911,7 +916,7 @@ def ok_to_include_history(
else:
include_history = can_access_stream_history_by_id(user_profile, operand)
elif (
term["operator"] == "streams"
term["operator"] in channels_operators
and term["operand"] == "public"
and not term.get("negated", False)
and user_profile.can_access_public_streams()
Expand Down
88 changes: 56 additions & 32 deletions zerver/tests/test_message_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ def test_add_term_using_channel_operator_and_non_existing_operand_should_raise_e
term = dict(operator="channel", operand="non-existing-channel")
self.assertRaises(BadNarrowOperatorError, self._build_query, term)

def test_add_term_using_streams_operator_and_invalid_operand_should_raise_error(
def test_add_term_using_channels_operator_and_invalid_operand_should_raise_error(
self,
) -> None: # NEGATED
term = dict(operator="streams", operand="invalid_operands")
term = dict(operator="channels", operand="invalid_operands")
self.assertRaises(BadNarrowOperatorError, self._build_query, term)

def test_add_term_using_streams_operator_and_public_operand(self) -> None:
term = dict(operator="streams", operand="public")
def test_add_term_using_channels_operator_and_public_operand(self) -> None:
term = dict(operator="channels", operand="public")
self._do_add_term_test(
term,
"WHERE recipient_id IN (__[POSTCOMPILE_recipient_id_1])",
Expand Down Expand Up @@ -173,8 +173,8 @@ def test_add_term_using_streams_operator_and_public_operand(self) -> None:
"WHERE recipient_id IN (__[POSTCOMPILE_recipient_id_1])",
)

def test_add_term_using_streams_operator_and_public_operand_negated(self) -> None:
term = dict(operator="streams", operand="public", negated=True)
def test_add_term_using_channels_operator_and_public_operand_negated(self) -> None:
term = dict(operator="channels", operand="public", negated=True)
self._do_add_term_test(
term,
"WHERE (recipient_id NOT IN (__[POSTCOMPILE_recipient_id_1]))",
Expand Down Expand Up @@ -322,7 +322,7 @@ def test_combined_channel_dm(self) -> None:
self._build_query(topic_term)
self.assertEqual(expected_error_message, str(error.exception))

channels_term = dict(operator="streams", operand="public")
channels_term = dict(operator="channels", operand="public")
with self.assertRaises(BadNarrowOperatorError) as error:
self._build_query(channels_term)
self.assertEqual(expected_error_message, str(error.exception))
Expand Down Expand Up @@ -633,6 +633,27 @@ def test_add_term_using_stream_operator_and_non_existing_operand_should_raise_er
term = dict(operator="stream", operand="non-existing-channel")
self.assertRaises(BadNarrowOperatorError, self._build_query, term)

# Test that "streams" (legacy alias for "channels" operator) works.
def test_add_term_using_streams_operator_and_invalid_operand_should_raise_error(
self,
) -> None: # NEGATED
term = dict(operator="streams", operand="invalid_operands")
self.assertRaises(BadNarrowOperatorError, self._build_query, term)

def test_add_term_using_streams_operator_and_public_operand(self) -> None:
term = dict(operator="streams", operand="public")
self._do_add_term_test(
term,
"WHERE recipient_id IN (__[POSTCOMPILE_recipient_id_1])",
)

def test_add_term_using_streams_operator_and_public_operand_negated(self) -> None:
term = dict(operator="streams", operand="public", negated=True)
self._do_add_term_test(
term,
"WHERE (recipient_id NOT IN (__[POSTCOMPILE_recipient_id_1]))",
)

def _do_add_term_test(
self, term: Dict[str, Any], where_clause: str, params: Optional[Dict[str, Any]] = None
) -> None:
Expand Down Expand Up @@ -928,7 +949,7 @@ def test_is_spectator_compatible(self) -> None:
)
self.assertFalse(is_spectator_compatible([{"operator": "is", "operand": "starred"}]))
self.assertFalse(is_spectator_compatible([{"operator": "is", "operand": "dm"}]))
self.assertTrue(is_spectator_compatible([{"operator": "streams", "operand": "public"}]))
self.assertTrue(is_spectator_compatible([{"operator": "channels", "operand": "public"}]))

# Malformed input not allowed
self.assertFalse(is_spectator_compatible([{"operator": "has"}]))
Expand All @@ -953,6 +974,8 @@ def test_is_spectator_compatible(self) -> None:
]
)
)
# "streams" is a legacy alias for "channels" operator
self.assertTrue(is_spectator_compatible([{"operator": "streams", "operand": "public"}]))


class IncludeHistoryTest(ZulipTestCase):
Expand All @@ -966,15 +989,15 @@ def test_ok_to_include_history(self) -> None:
]
self.assertFalse(ok_to_include_history(narrow, user_profile, False))

# streams:public searches should include history for non-guest members.
# channels:public searches should include history for non-guest members.
narrow = [
dict(operator="streams", operand="public"),
dict(operator="channels", operand="public"),
]
self.assertTrue(ok_to_include_history(narrow, user_profile, False))

# Negated -streams:public searches should not include history.
# Negated -channels:public searches should not include history.
narrow = [
dict(operator="streams", operand="public", negated=True),
dict(operator="channels", operand="public", negated=True),
]
self.assertFalse(ok_to_include_history(narrow, user_profile, False))

Expand Down Expand Up @@ -1029,24 +1052,24 @@ def test_ok_to_include_history(self) -> None:
self.assertFalse(ok_to_include_history(narrow, user_profile, False))

# No point in searching history for is operator even if included with
# streams:public
# channels:public
narrow = [
dict(operator="streams", operand="public"),
dict(operator="channels", operand="public"),
dict(operator="is", operand="mentioned"),
]
self.assertFalse(ok_to_include_history(narrow, user_profile, False))
narrow = [
dict(operator="streams", operand="public"),
dict(operator="channels", operand="public"),
dict(operator="is", operand="unread"),
]
self.assertFalse(ok_to_include_history(narrow, user_profile, False))
narrow = [
dict(operator="streams", operand="public"),
dict(operator="channels", operand="public"),
dict(operator="is", operand="alerted"),
]
self.assertFalse(ok_to_include_history(narrow, user_profile, False))
narrow = [
dict(operator="streams", operand="public"),
dict(operator="channels", operand="public"),
dict(operator="is", operand="resolved"),
]
self.assertFalse(ok_to_include_history(narrow, user_profile, False))
Expand All @@ -1069,9 +1092,9 @@ def test_ok_to_include_history(self) -> None:
# Using 'Cordelia' to compare between a guest and a normal user
subscribed_user_profile = self.example_user("cordelia")

# streams:public searches should not include history for guest members.
# channels:public searches should not include history for guest members.
narrow = [
dict(operator="streams", operand="public"),
dict(operator="channels", operand="public"),
]
self.assertFalse(ok_to_include_history(narrow, guest_user_profile, False))

Expand Down Expand Up @@ -1869,7 +1892,7 @@ def test_successful_get_messages(self) -> None:
)

def test_unauthenticated_get_messages(self) -> None:
# Require `streams:web-public` as narrow to get web-public messages.
# Require channels:web-public as narrow to get web-public messages.
get_params = {
"anchor": 10000000000000000,
"num_before": 5,
Expand All @@ -1886,7 +1909,7 @@ def test_unauthenticated_get_messages(self) -> None:
# Successful access to web-public channel messages.
web_public_channel_get_params: Dict[str, Union[int, str, bool]] = {
**get_params,
"narrow": orjson.dumps([dict(operator="streams", operand="web-public")]).decode(),
"narrow": orjson.dumps([dict(operator="channels", operand="web-public")]).decode(),
}
result = self.client_get("/json/messages", dict(web_public_channel_get_params))
# More detailed check of message parameters is done in `test_get_messages_with_web_public`.
Expand Down Expand Up @@ -1918,7 +1941,7 @@ def test_unauthenticated_get_messages(self) -> None:
# "is:dm" is not a is_spectator_compatible narrow.
"narrow": orjson.dumps(
[
dict(operator="streams", operand="web-public"),
dict(operator="channels", operand="web-public"),
dict(operator="is", operand="dm"),
]
).decode(),
Expand All @@ -1937,20 +1960,20 @@ def test_unauthenticated_get_messages(self) -> None:
result = self.client_get("/json/messages", dict(web_public_channel_get_params))
self.assert_json_success(result)

# Cannot access even web-public channels without `streams:web-public` narrow.
# Cannot access even web-public channels without channels:web-public narrow.
non_web_public_channel_get_params: Dict[str, Union[int, str, bool]] = {
**get_params,
"narrow": orjson.dumps([dict(operator="channel", operand="Rome")]).decode(),
}
result = self.client_get("/json/messages", dict(non_web_public_channel_get_params))
self.check_unauthenticated_response(result)

# Verify that same request would work with `streams:web-public` added.
# Verify that same request would work with channels:web-public added.
rome_web_public_get_params: Dict[str, Union[int, str, bool]] = {
**get_params,
"narrow": orjson.dumps(
[
dict(operator="streams", operand="web-public"),
dict(operator="channels", operand="web-public"),
# Rome is a web-channel channel.
dict(operator="channel", operand="Rome"),
]
Expand All @@ -1959,12 +1982,12 @@ def test_unauthenticated_get_messages(self) -> None:
result = self.client_get("/json/messages", dict(rome_web_public_get_params))
self.assert_json_success(result)

# Cannot access non-web-public channel even with `streams:web-public` narrow.
# Cannot access non-web-public channel even with channels:web-public narrow.
scotland_web_public_get_params: Dict[str, Union[int, str, bool]] = {
**get_params,
"narrow": orjson.dumps(
[
dict(operator="streams", operand="web-public"),
dict(operator="channels", operand="web-public"),
# Scotland is not a web-public channel.
dict(operator="channel", operand="Scotland"),
]
Expand Down Expand Up @@ -2024,7 +2047,7 @@ def test_unauthenticated_narrow_to_web_public_channels(self) -> None:
"num_after": 1,
"narrow": orjson.dumps(
[
dict(operator="streams", operand="web-public"),
dict(operator="channels", operand="web-public"),
dict(operator="channel", operand="web-public-channel"),
]
).decode(),
Expand All @@ -2035,7 +2058,7 @@ def test_unauthenticated_narrow_to_web_public_channels(self) -> None:
def test_get_messages_with_web_public(self) -> None:
"""
An unauthenticated call to GET /json/messages with valid parameters
including `streams:web-public` narrow returns list of messages in the
including channels:web-public narrow returns list of messages in the
web-public channels.
"""
self.setup_web_public_test(num_web_public_message=8)
Expand All @@ -2044,7 +2067,7 @@ def test_get_messages_with_web_public(self) -> None:
"anchor": "first_unread",
"num_before": 5,
"num_after": 1,
"narrow": orjson.dumps([dict(operator="streams", operand="web-public")]).decode(),
"narrow": orjson.dumps([dict(operator="channels", operand="web-public")]).decode(),
}
result = self.client_get("/json/messages", dict(post_params))
# Of the last 7 (num_before + num_after + 1) messages, only 5
Expand Down Expand Up @@ -4087,7 +4110,8 @@ def test_get_messages_with_narrow_queries(self) -> None:
sql_template = "SELECT anon_1.message_id \nFROM (SELECT id AS message_id \nFROM zerver_message \nWHERE realm_id = 2 AND recipient_id IN ({public_channels_recipients}) ORDER BY zerver_message.id ASC \n LIMIT 10) AS anon_1 ORDER BY message_id ASC"
sql = sql_template.format(**query_ids)
self.common_check_get_messages_query(
{"anchor": 0, "num_before": 0, "num_after": 9, "narrow": '[["streams", "public"]]'}, sql
{"anchor": 0, "num_before": 0, "num_after": 9, "narrow": '[["channels", "public"]]'},
sql,
)

sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \nWHERE user_profile_id = {hamlet_id} AND (recipient_id NOT IN ({public_channels_recipients})) ORDER BY message_id ASC \n LIMIT 10) AS anon_1 ORDER BY message_id ASC"
Expand All @@ -4097,7 +4121,7 @@ def test_get_messages_with_narrow_queries(self) -> None:
"anchor": 0,
"num_before": 0,
"num_after": 9,
"narrow": '[{"operator":"streams", "operand":"public", "negated": true}]',
"narrow": '[{"operator":"channels", "operand":"public", "negated": true}]',
},
sql,
)
Expand Down

0 comments on commit 608b305

Please sign in to comment.