diff --git a/docs/reference/esql-query-builder.md b/docs/reference/esql-query-builder.md
index 1cdc0c5b3..8390ea983 100644
--- a/docs/reference/esql-query-builder.md
+++ b/docs/reference/esql-query-builder.md
@@ -203,6 +203,26 @@ query = (
)
```
+### Preventing injection attacks
+
+ES|QL, like most query languages, is vulnerable to [code injection attacks](https://en.wikipedia.org/wiki/Code_injection) if untrusted data provided by users is added to a query. To eliminate this risk, ES|QL allows untrusted data to be given separately from the query as parameters.
+
+Continuing with the example above, let's assume that the application needs a `find_employee_by_name()` function that searches for the name given as an argument. If this argument is received by the application from users, then it is considered untrusted and should not be added to the query directly. Here is how to code the function in a secure manner:
+
+```python
+def find_employee_by_name(name):
+ query = (
+ ESQL.from_("employees")
+ .keep("first_name", "last_name", "height")
+ .where(E("first_name") == E("?"))
+ )
+ return client.esql.query(query=str(query), params=[name])
+```
+
+Here the part of the query in which the untrusted data needs to be inserted is replaced with a parameter, which in ES|QL is denoted by a question mark (`?`). When using Python expressions, the parameter must be given as `E("?")` so that it is treated as an expression and not as a literal string.
+
+The values given in the `params` argument to the query endpoint are assigned, in order, to the parameters defined in the query.
+
## Using ES|QL functions
The ES|QL language includes a rich set of functions that can be used in expressions and conditionals. These can be included in expressions given as strings, as shown in the example below:
@@ -235,6 +255,6 @@ query = (
)
```
-Note that arguments passed to functions are assumed to be literals. When passing field names, it is necessary to wrap them with the `E()` helper function so that they are interpreted correctly.
+Note that arguments passed to functions are assumed to be literals. When passing field names, parameters or other ES|QL expressions, it is necessary to wrap them with the `E()` helper function so that they are interpreted correctly.
You can find the complete list of available functions in the Python client's [ES|QL API reference documentation](https://elasticsearch-py.readthedocs.io/en/stable/esql.html#module-elasticsearch.esql.functions).
diff --git a/docs/release-notes/breaking-changes.md b/docs/release-notes/breaking-changes.md
index 640a57036..0a354b9ce 100644
--- a/docs/release-notes/breaking-changes.md
+++ b/docs/release-notes/breaking-changes.md
@@ -28,7 +28,7 @@ For more information, check [PR #2840](https://github.com/elastic/elasticsearch-
* `host_info_callback` is now `sniffed_node_callback`
* `sniffer_timeout` is now `min_delay_between_sniffing`
* `sniff_on_connection_fail` is now `sniff_on_node_failure`
- * `maxsize` is now `connection_per_node`
+ * `maxsize` is now `connections_per_node`
::::
::::{dropdown} Remove deprecated url_prefix and use_ssl host keys
@@ -50,4 +50,4 @@ Elasticsearch 9 removed the kNN search and Unfreeze index APIs.
**Action**
* The kNN search API has been replaced by the `knn` option in the search API since Elasticsearch 8.4.
* The Unfreeze index API was deprecated in Elasticsearch 7.14 and has been removed in Elasticsearch 9.
- ::::
\ No newline at end of file
+ ::::
diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md
index 314030cdd..76a1c610b 100644
--- a/docs/release-notes/index.md
+++ b/docs/release-notes/index.md
@@ -18,6 +18,56 @@ To check for security updates, go to [Security announcements for the Elastic sta
% *
% ### Fixes [elasticsearch-python-client-next-fixes]
+## 9.1.0 (2025-07-30) [elasticsearch-python-client-910-release-notes]
+
+Enhancements
+
+* ES|QL query builder (technical preview) ([#2997](https://github.com/elastic/elasticsearch-py/pull/2997))
+* Update OpenTelemetry conventions ([#2999](https://github.com/elastic/elasticsearch-py/pull/2999))
+* Add option to disable accurate reporting of file and line location in warnings (Fixes #3003) ([#3006](https://github.com/elastic/elasticsearch-py/pull/3006))
+
+APIs
+
+* Remove `if_primary_term`, `if_seq_no` and `op_type` from Create API
+* Remove `master_timeout` from Ingest Get Ip Location Database API
+* Remove `application`, `priviledge` and `username` from the Security Get User API
+* Rename `type_query_string` to `type` in License Post Start Trial API
+* Add `require_data_stream` to Index API
+* Add `settings_filter` to Cluster Get Component Template API
+* Add `cause` to Cluster Put Component Template API
+* Add `master_timeout` to Cluster State API
+* Add `ccs_minimize_roundtrips` to EQL Search API
+* Add `keep_alive` and `keep_on_completion` to ES|QL Async Query API
+* Add `format` to ES|QL Async Query Get API
+* Add ES|QL Get Query and List Queries APIs
+* Add Indices Delete Data Stream Options API
+* Add Indices Get Data Stream Options and Put Data Stream Options APIs
+* Add Indices Get Data Stream Settings and Put Data Stream Settings APIs
+* Add `allow_no_indices`, `expand_wildcards` and `ignore_unavailable` to Indices Recovery API
+* Add Indices Remove Block API
+* Add Amazon SageMaker to Inference API
+* Add `input_type` to Inference API
+* Add `timeout` to all Inference Put APIs
+* Add Inference Put Custom API
+* Add Inference Put DeepSeek API
+* Add `task_settings` to Put HuggingFace API
+* Add `refresh` to Security Grant API Key API
+* Add `wait_for_completion` to the Snapshot Delete API
+* Add `state` to Snapshot Get API
+* Add `refresh` to Synonyms Put Synonym, Put Synonym Rule and Delete Synonym Rule APIs
+
+DSL
+
+* Handle lists in `copy_to` option in DSL field declarations correctly (Fixes #2992) ([#2993](https://github.com/elastic/elasticsearch-py/pull/2993))
+* Add `index_options` to SparseVector type
+* Add SparseVectorIndexOptions type
+* Add `key` to FiltersBucket type
+
+Other changes
+
+* Drop support for Python 3.8 ([#3001](https://github.com/elastic/elasticsearch-py/pull/3001))
+
+
## 9.0.2 (2025-06-05) [elasticsearch-python-client-902-release-notes]
diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py
index 0874e120f..902834328 100644
--- a/elasticsearch/_async/client/__init__.py
+++ b/elasticsearch/_async/client/__init__.py
@@ -2234,7 +2234,6 @@ async def field_caps(
@_rewrite_parameters(
parameter_aliases={
"_source": "source",
- "_source_exclude_vectors": "source_exclude_vectors",
"_source_excludes": "source_excludes",
"_source_includes": "source_includes",
},
@@ -2254,7 +2253,6 @@ async def get(
refresh: t.Optional[bool] = None,
routing: t.Optional[str] = None,
source: t.Optional[t.Union[bool, t.Union[str, t.Sequence[str]]]] = None,
- source_exclude_vectors: t.Optional[bool] = None,
source_excludes: t.Optional[t.Union[str, t.Sequence[str]]] = None,
source_includes: t.Optional[t.Union[str, t.Sequence[str]]] = None,
stored_fields: t.Optional[t.Union[str, t.Sequence[str]]] = None,
@@ -2328,7 +2326,6 @@ async def get(
:param routing: A custom value used to route operations to a specific shard.
:param source: Indicates whether to return the `_source` field (`true` or `false`)
or lists the fields to return.
- :param source_exclude_vectors: Whether vectors should be excluded from _source
:param source_excludes: A comma-separated list of source fields to exclude from
the response. You can also use this parameter to exclude fields from the
subset specified in `_source_includes` query parameter. If the `_source`
@@ -2374,8 +2371,6 @@ async def get(
__query["routing"] = routing
if source is not None:
__query["_source"] = source
- if source_exclude_vectors is not None:
- __query["_source_exclude_vectors"] = source_exclude_vectors
if source_excludes is not None:
__query["_source_excludes"] = source_excludes
if source_includes is not None:
@@ -4309,7 +4304,6 @@ async def scroll(
),
parameter_aliases={
"_source": "source",
- "_source_exclude_vectors": "source_exclude_vectors",
"_source_excludes": "source_excludes",
"_source_includes": "source_includes",
"from": "from_",
@@ -4393,7 +4387,6 @@ async def search(
]
] = None,
source: t.Optional[t.Union[bool, t.Mapping[str, t.Any]]] = None,
- source_exclude_vectors: t.Optional[bool] = None,
source_excludes: t.Optional[t.Union[str, t.Sequence[str]]] = None,
source_includes: t.Optional[t.Union[str, t.Sequence[str]]] = None,
stats: t.Optional[t.Sequence[str]] = None,
@@ -4588,7 +4581,6 @@ async def search(
fields are returned in the `hits._source` property of the search response.
If the `stored_fields` property is specified, the `_source` property defaults
to `false`. Otherwise, it defaults to `true`.
- :param source_exclude_vectors: Whether vectors should be excluded from _source
:param source_excludes: A comma-separated list of source fields to exclude from
the response. You can also use this parameter to exclude fields from the
subset specified in `_source_includes` query parameter. If the `_source`
@@ -4713,8 +4705,6 @@ async def search(
__query["scroll"] = scroll
if search_type is not None:
__query["search_type"] = search_type
- if source_exclude_vectors is not None:
- __query["_source_exclude_vectors"] = source_exclude_vectors
if source_excludes is not None:
__query["_source_excludes"] = source_excludes
if source_includes is not None:
diff --git a/elasticsearch/_async/client/cluster.py b/elasticsearch/_async/client/cluster.py
index 91956f7c4..9ae420766 100644
--- a/elasticsearch/_async/client/cluster.py
+++ b/elasticsearch/_async/client/cluster.py
@@ -49,7 +49,6 @@ async def allocation_explain(
Explain the shard allocations.
Get explanations for shard allocations in the cluster.
- This API accepts the current_node, index, primary and shard parameters in the request body or in query parameters, but not in both at the same time.
For unassigned shards, it provides an explanation for why the shard is unassigned.
For assigned shards, it provides an explanation for why the shard is remaining on its current node and has not moved or rebalanced to another node.
This API can be very useful when attempting to diagnose why a shard is unassigned or why a shard continues to remain on its current node when you might expect otherwise.
@@ -58,16 +57,17 @@ async def allocation_explain(
`
Get data stream mappings.
-Get mapping information for one or more data streams.
- - - `Get the data stream options configuration of one or more data streams.
- `Update data stream mappings.
-This API can be used to override mappings on specific data streams. These overrides will take precedence over what - is specified in the template that the data stream matches. The mapping change is only applied to new write indices - that are created during rollover after this API is called. No indices are changed by this API.
- - - `completion
, rerank
, sparse_embedding
, text_embedding
)completion
, text_embedding
)chat_completion
, completion
, rerank
, sparse_embedding
, text_embedding
)completion
)completion
, 'rerank', text_embedding
)completion
, text_embedding
)completion
, text_embedding
)completion
, rerank
, text_embedding
)completion
, chat_completion
)Create an Amazon SageMaker inference endpoint.
+Create an inference endpoint to perform an inference task with the amazon_sagemaker
service.
Validate an anomaly detection job.
- `If you omit the <snapshot>
request path parameter, the request retrieves information only for currently running snapshots.
This usage is preferred.
If needed, you can specify <repository>
and <snapshot>
to retrieve information for specific snapshots, even if they're not currently running.
Note that the stats will not be available for any shard snapshots in an ongoing snapshot completed by a node that (even momentarily) left the cluster. - Loading the stats from the repository is an expensive operation (see the WARNING below). - Therefore the stats values for such shards will be -1 even though the "stage" value will be "DONE", in order to minimize latency. - A "description" field will be present for a shard snapshot completed by a departed node explaining why the shard snapshot's stats results are invalid. - Consequently, the total stats for the index will be less than expected due to the missing values from these shards.
WARNING: Using the API to return the status of any snapshots other than currently running snapshots can be expensive. The API requires a read from the repository for each shard in each snapshot. For example, if you have 100 snapshots with 1,000 shards each, an API request that includes all snapshots will require 100,000 reads (100 snapshots x 1,000 shards).
diff --git a/elasticsearch/_sync/client/__init__.py b/elasticsearch/_sync/client/__init__.py index 5f7a4313d..40f4cbed6 100644 --- a/elasticsearch/_sync/client/__init__.py +++ b/elasticsearch/_sync/client/__init__.py @@ -2232,7 +2232,6 @@ def field_caps( @_rewrite_parameters( parameter_aliases={ "_source": "source", - "_source_exclude_vectors": "source_exclude_vectors", "_source_excludes": "source_excludes", "_source_includes": "source_includes", }, @@ -2252,7 +2251,6 @@ def get( refresh: t.Optional[bool] = None, routing: t.Optional[str] = None, source: t.Optional[t.Union[bool, t.Union[str, t.Sequence[str]]]] = None, - source_exclude_vectors: t.Optional[bool] = None, source_excludes: t.Optional[t.Union[str, t.Sequence[str]]] = None, source_includes: t.Optional[t.Union[str, t.Sequence[str]]] = None, stored_fields: t.Optional[t.Union[str, t.Sequence[str]]] = None, @@ -2326,7 +2324,6 @@ def get( :param routing: A custom value used to route operations to a specific shard. :param source: Indicates whether to return the `_source` field (`true` or `false`) or lists the fields to return. - :param source_exclude_vectors: Whether vectors should be excluded from _source :param source_excludes: A comma-separated list of source fields to exclude from the response. You can also use this parameter to exclude fields from the subset specified in `_source_includes` query parameter. 
If the `_source` @@ -2372,8 +2369,6 @@ def get( __query["routing"] = routing if source is not None: __query["_source"] = source - if source_exclude_vectors is not None: - __query["_source_exclude_vectors"] = source_exclude_vectors if source_excludes is not None: __query["_source_excludes"] = source_excludes if source_includes is not None: @@ -4307,7 +4302,6 @@ def scroll( ), parameter_aliases={ "_source": "source", - "_source_exclude_vectors": "source_exclude_vectors", "_source_excludes": "source_excludes", "_source_includes": "source_includes", "from": "from_", @@ -4391,7 +4385,6 @@ def search( ] ] = None, source: t.Optional[t.Union[bool, t.Mapping[str, t.Any]]] = None, - source_exclude_vectors: t.Optional[bool] = None, source_excludes: t.Optional[t.Union[str, t.Sequence[str]]] = None, source_includes: t.Optional[t.Union[str, t.Sequence[str]]] = None, stats: t.Optional[t.Sequence[str]] = None, @@ -4586,7 +4579,6 @@ def search( fields are returned in the `hits._source` property of the search response. If the `stored_fields` property is specified, the `_source` property defaults to `false`. Otherwise, it defaults to `true`. - :param source_exclude_vectors: Whether vectors should be excluded from _source :param source_excludes: A comma-separated list of source fields to exclude from the response. You can also use this parameter to exclude fields from the subset specified in `_source_includes` query parameter. 
If the `_source` @@ -4711,8 +4703,6 @@ def search( __query["scroll"] = scroll if search_type is not None: __query["search_type"] = search_type - if source_exclude_vectors is not None: - __query["_source_exclude_vectors"] = source_exclude_vectors if source_excludes is not None: __query["_source_excludes"] = source_excludes if source_includes is not None: diff --git a/elasticsearch/_sync/client/cluster.py b/elasticsearch/_sync/client/cluster.py index a56892d54..2d4eebc54 100644 --- a/elasticsearch/_sync/client/cluster.py +++ b/elasticsearch/_sync/client/cluster.py @@ -49,7 +49,6 @@ def allocation_explain(Explain the shard allocations.
Get explanations for shard allocations in the cluster.
- This API accepts the current_node, index, primary and shard parameters in the request body or in query parameters, but not in both at the same time.
For unassigned shards, it provides an explanation for why the shard is unassigned.
For assigned shards, it provides an explanation for why the shard is remaining on its current node and has not moved or rebalanced to another node.
This API can be very useful when attempting to diagnose why a shard is unassigned or why a shard continues to remain on its current node when you might expect otherwise.
@@ -58,16 +57,17 @@ def allocation_explain(
`
Get data stream mappings.
-Get mapping information for one or more data streams.
- - - `Get the data stream options configuration of one or more data streams.
- `Update data stream mappings.
-This API can be used to override mappings on specific data streams. These overrides will take precedence over what - is specified in the template that the data stream matches. The mapping change is only applied to new write indices - that are created during rollover after this API is called. No indices are changed by this API.
- - - `completion
, rerank
, sparse_embedding
, text_embedding
)completion
, text_embedding
)chat_completion
, completion
, rerank
, sparse_embedding
, text_embedding
)completion
)completion
, 'rerank', text_embedding
)completion
, text_embedding
)completion
, text_embedding
)completion
, rerank
, text_embedding
)completion
, chat_completion
)Create an Amazon SageMaker inference endpoint.
+Create an inference endpoint to perform an inference task with the amazon_sagemaker
service.
Validate an anomaly detection job.
- `If you omit the <snapshot>
request path parameter, the request retrieves information only for currently running snapshots.
This usage is preferred.
If needed, you can specify <repository>
and <snapshot>
to retrieve information for specific snapshots, even if they're not currently running.
Note that the stats will not be available for any shard snapshots in an ongoing snapshot completed by a node that (even momentarily) left the cluster. - Loading the stats from the repository is an expensive operation (see the WARNING below). - Therefore the stats values for such shards will be -1 even though the "stage" value will be "DONE", in order to minimize latency. - A "description" field will be present for a shard snapshot completed by a departed node explaining why the shard snapshot's stats results are invalid. - Consequently, the total stats for the index will be less than expected due to the missing values from these shards.
WARNING: Using the API to return the status of any snapshots other than currently running snapshots can be expensive. The API requires a read from the repository for each shard in each snapshot. For example, if you have 100 snapshots with 1,000 shards each, an API request that includes all snapshots will require 100,000 reads (100 snapshots x 1,000 shards).
diff --git a/elasticsearch/_version.py b/elasticsearch/_version.py index 0624a7ff1..7b6c8994d 100644 --- a/elasticsearch/_version.py +++ b/elasticsearch/_version.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. -__versionstr__ = "9.0.2" +__versionstr__ = "9.1.0" diff --git a/elasticsearch/dsl/field.py b/elasticsearch/dsl/field.py index c33261458..895765e66 100644 --- a/elasticsearch/dsl/field.py +++ b/elasticsearch/dsl/field.py @@ -119,9 +119,16 @@ def __init__( def __getitem__(self, subfield: str) -> "Field": return cast(Field, self._params.get("fields", {})[subfield]) - def _serialize(self, data: Any) -> Any: + def _serialize(self, data: Any, skip_empty: bool) -> Any: return data + def _safe_serialize(self, data: Any, skip_empty: bool) -> Any: + try: + return self._serialize(data, skip_empty) + except TypeError: + # older method signature, without skip_empty + return self._serialize(data) # type: ignore[call-arg] + def _deserialize(self, data: Any) -> Any: return data @@ -133,10 +140,16 @@ def empty(self) -> Optional[Any]: return AttrList([]) return self._empty() - def serialize(self, data: Any) -> Any: + def serialize(self, data: Any, skip_empty: bool = True) -> Any: if isinstance(data, (list, AttrList, tuple)): - return list(map(self._serialize, cast(Iterable[Any], data))) - return self._serialize(data) + return list( + map( + self._safe_serialize, + cast(Iterable[Any], data), + [skip_empty] * len(data), + ) + ) + return self._safe_serialize(data, skip_empty) def deserialize(self, data: Any) -> Any: if isinstance(data, (list, AttrList, tuple)): @@ -186,7 +199,7 @@ def _deserialize(self, data: Any) -> Range["_SupportsComparison"]: data = {k: self._core_field.deserialize(v) for k, v in data.items()} # type: ignore[union-attr] return Range(data) - def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]: if data is None: return 
None if not isinstance(data, collections.abc.Mapping): @@ -550,7 +563,7 @@ def _deserialize(self, data: Any) -> "InnerDoc": return self._wrap(data) def _serialize( - self, data: Optional[Union[Dict[str, Any], "InnerDoc"]] + self, data: Optional[Union[Dict[str, Any], "InnerDoc"]], skip_empty: bool ) -> Optional[Dict[str, Any]]: if data is None: return None @@ -559,7 +572,7 @@ def _serialize( if isinstance(data, collections.abc.Mapping): return data - return data.to_dict() + return data.to_dict(skip_empty=skip_empty) def clean(self, data: Any) -> Any: data = super().clean(data) @@ -768,7 +781,7 @@ def clean(self, data: str) -> str: def _deserialize(self, data: Any) -> bytes: return base64.b64decode(data) - def _serialize(self, data: Any) -> Optional[str]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None return base64.b64encode(data).decode() @@ -2619,7 +2632,7 @@ def _deserialize(self, data: Any) -> Union["IPv4Address", "IPv6Address"]: # the ipaddress library for pypy only accepts unicode. return ipaddress.ip_address(unicode(data)) - def _serialize(self, data: Any) -> Optional[str]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None return str(data) @@ -3367,7 +3380,7 @@ def __init__( def _deserialize(self, data: Any) -> "Query": return Q(data) # type: ignore[no-any-return] - def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]: if data is None: return None return data.to_dict() # type: ignore[no-any-return] @@ -3849,9 +3862,6 @@ class SemanticText(Field): by using the Update mapping API. Use the Create inference API to create the endpoint. If not specified, the inference endpoint defined by inference_id will be used at both index and query time. 
- :arg index_options: Settings for index_options that override any - defaults used by semantic_text, for example specific quantization - settings. :arg chunking_settings: Settings for chunking text into smaller passages. If specified, these will override the chunking settings sent in the inference endpoint associated with inference_id. If @@ -3867,9 +3877,6 @@ def __init__( meta: Union[Mapping[str, str], "DefaultType"] = DEFAULT, inference_id: Union[str, "DefaultType"] = DEFAULT, search_inference_id: Union[str, "DefaultType"] = DEFAULT, - index_options: Union[ - "types.SemanticTextIndexOptions", Dict[str, Any], "DefaultType" - ] = DEFAULT, chunking_settings: Union[ "types.ChunkingSettings", Dict[str, Any], "DefaultType" ] = DEFAULT, @@ -3881,8 +3888,6 @@ def __init__( kwargs["inference_id"] = inference_id if search_inference_id is not DEFAULT: kwargs["search_inference_id"] = search_inference_id - if index_options is not DEFAULT: - kwargs["index_options"] = index_options if chunking_settings is not DEFAULT: kwargs["chunking_settings"] = chunking_settings super().__init__(*args, **kwargs) diff --git a/elasticsearch/dsl/types.py b/elasticsearch/dsl/types.py index 383a69d83..452a945dd 100644 --- a/elasticsearch/dsl/types.py +++ b/elasticsearch/dsl/types.py @@ -144,26 +144,8 @@ def __init__( class ChunkingSettings(AttrDict[Any]): """ - :arg strategy: (required) The chunking strategy: `sentence`, `word`, - `none` or `recursive`. * If `strategy` is set to `recursive`, - you must also specify: - `max_chunk_size` - either `separators` - or`separator_group` Learn more about different chunking - strategies in the linked documentation. Defaults to `sentence` if - omitted. - :arg separator_group: (required) This parameter is only applicable - when using the `recursive` chunking strategy. Sets a predefined - list of separators in the saved chunking settings based on the - selected text type. Values can be `markdown` or `plaintext`. 
- Using this parameter is an alternative to manually specifying a - custom `separators` list. - :arg separators: (required) A list of strings used as possible split - points when chunking text with the `recursive` strategy. Each - string can be a plain string or a regular expression (regex) - pattern. The system tries each separator in order to split the - text, starting from the first item in the list. After splitting, - it attempts to recombine smaller pieces into larger chunks that - stay within the `max_chunk_size` limit, to reduce the total number - of chunks generated. + :arg strategy: (required) The chunking strategy: `sentence` or `word`. + Defaults to `sentence` if omitted. :arg max_chunk_size: (required) The maximum size of a chunk in words. This value cannot be higher than `300` or lower than `20` (for `sentence` strategy) or `10` (for `word` strategy). Defaults to @@ -178,8 +160,6 @@ class ChunkingSettings(AttrDict[Any]): """ strategy: Union[str, DefaultType] - separator_group: Union[str, DefaultType] - separators: Union[Sequence[str], DefaultType] max_chunk_size: Union[int, DefaultType] overlap: Union[int, DefaultType] sentence_overlap: Union[int, DefaultType] @@ -188,8 +168,6 @@ def __init__( self, *, strategy: Union[str, DefaultType] = DEFAULT, - separator_group: Union[str, DefaultType] = DEFAULT, - separators: Union[Sequence[str], DefaultType] = DEFAULT, max_chunk_size: Union[int, DefaultType] = DEFAULT, overlap: Union[int, DefaultType] = DEFAULT, sentence_overlap: Union[int, DefaultType] = DEFAULT, @@ -197,10 +175,6 @@ def __init__( ): if strategy is not DEFAULT: kwargs["strategy"] = strategy - if separator_group is not DEFAULT: - kwargs["separator_group"] = separator_group - if separators is not DEFAULT: - kwargs["separators"] = separators if max_chunk_size is not DEFAULT: kwargs["max_chunk_size"] = max_chunk_size if overlap is not DEFAULT: @@ -3165,26 +3139,6 @@ def __init__( super().__init__(kwargs) -class 
SemanticTextIndexOptions(AttrDict[Any]): - """ - :arg dense_vector: - """ - - dense_vector: Union["DenseVectorIndexOptions", Dict[str, Any], DefaultType] - - def __init__( - self, - *, - dense_vector: Union[ - "DenseVectorIndexOptions", Dict[str, Any], DefaultType - ] = DEFAULT, - **kwargs: Any, - ): - if dense_vector is not DEFAULT: - kwargs["dense_vector"] = dense_vector - super().__init__(kwargs) - - class ShapeFieldQuery(AttrDict[Any]): """ :arg indexed_shape: Queries using a pre-indexed shape. diff --git a/elasticsearch/dsl/utils.py b/elasticsearch/dsl/utils.py index 127a48cc2..cce3c052c 100644 --- a/elasticsearch/dsl/utils.py +++ b/elasticsearch/dsl/utils.py @@ -603,7 +603,7 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]: # if this is a mapped field, f = self.__get_field(k) if f and f._coerce: - v = f.serialize(v) + v = f.serialize(v, skip_empty=skip_empty) # if someone assigned AttrList, unwrap it if isinstance(v, AttrList): diff --git a/elasticsearch/esql/__init__.py b/elasticsearch/esql/__init__.py index d872c329a..8da8f852a 100644 --- a/elasticsearch/esql/__init__.py +++ b/elasticsearch/esql/__init__.py @@ -15,4 +15,5 @@ # specific language governing permissions and limitations # under the License. +from ..dsl import E # noqa: F401 from .esql import ESQL, and_, not_, or_ # noqa: F401 diff --git a/elasticsearch/esql/esql.py b/elasticsearch/esql/esql.py index 07ccdf839..05f4e3e3e 100644 --- a/elasticsearch/esql/esql.py +++ b/elasticsearch/esql/esql.py @@ -16,6 +16,7 @@ # under the License. 
import json +import re from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Tuple, Type, Union @@ -111,6 +112,29 @@ def render(self) -> str: def _render_internal(self) -> str: pass + @staticmethod + def _format_index(index: IndexType) -> str: + return index._index._name if hasattr(index, "_index") else str(index) + + @staticmethod + def _format_id(id: FieldType, allow_patterns: bool = False) -> str: + s = str(id) # in case it is an InstrumentedField + if allow_patterns and "*" in s: + return s # patterns cannot be escaped + if re.fullmatch(r"[a-zA-Z_@][a-zA-Z0-9_\.]*", s): + return s + # this identifier needs to be escaped + s = s.replace("`", "``") + return f"`{s}`" + + @staticmethod + def _format_expr(expr: ExpressionType) -> str: + return ( + json.dumps(expr) + if not isinstance(expr, (str, InstrumentedExpression)) + else str(expr) + ) + def _is_forked(self) -> bool: if self.__class__.__name__ == "Fork": return True @@ -427,7 +451,7 @@ def sample(self, probability: float) -> "Sample": """ return Sample(self, probability) - def sort(self, *columns: FieldType) -> "Sort": + def sort(self, *columns: ExpressionType) -> "Sort": """The ``SORT`` processing command sorts a table on one or more columns. :param columns: The columns to sort on. 
@@ -570,15 +594,12 @@ def metadata(self, *fields: FieldType) -> "From": return self def _render_internal(self) -> str: - indices = [ - index if isinstance(index, str) else index._index._name - for index in self._indices - ] + indices = [self._format_index(index) for index in self._indices] s = f'{self.__class__.__name__.upper()} {", ".join(indices)}' if self._metadata_fields: s = ( s - + f' METADATA {", ".join([str(field) for field in self._metadata_fields])}' + + f' METADATA {", ".join([self._format_id(field) for field in self._metadata_fields])}' ) return s @@ -594,7 +615,11 @@ class Row(ESQLBase): def __init__(self, **params: ExpressionType): super().__init__() self._params = { - k: json.dumps(v) if not isinstance(v, InstrumentedExpression) else v + self._format_id(k): ( + json.dumps(v) + if not isinstance(v, InstrumentedExpression) + else self._format_expr(v) + ) for k, v in params.items() } @@ -615,7 +640,7 @@ def __init__(self, item: str): self._item = item def _render_internal(self) -> str: - return f"SHOW {self._item}" + return f"SHOW {self._format_id(self._item)}" class Branch(ESQLBase): @@ -667,11 +692,11 @@ def as_(self, type_name: str, pvalue_name: str) -> "ChangePoint": return self def _render_internal(self) -> str: - key = "" if not self._key else f" ON {self._key}" + key = "" if not self._key else f" ON {self._format_id(self._key)}" names = ( "" if not self._type_name and not self._pvalue_name - else f' AS {self._type_name or "type"}, {self._pvalue_name or "pvalue"}' + else f' AS {self._format_id(self._type_name or "type")}, {self._format_id(self._pvalue_name or "pvalue")}' ) return f"CHANGE_POINT {self._value}{key}{names}" @@ -709,12 +734,13 @@ def with_(self, inference_id: str) -> "Completion": def _render_internal(self) -> str: if self._inference_id is None: raise ValueError("The completion command requires an inference ID") + with_ = {"inference_id": self._inference_id} if self._named_prompt: column = list(self._named_prompt.keys())[0] prompt = 
list(self._named_prompt.values())[0] - return f"COMPLETION {column} = {prompt} WITH {self._inference_id}" + return f"COMPLETION {self._format_id(column)} = {self._format_id(prompt)} WITH {json.dumps(with_)}" else: - return f"COMPLETION {self._prompt[0]} WITH {self._inference_id}" + return f"COMPLETION {self._format_id(self._prompt[0])} WITH {json.dumps(with_)}" class Dissect(ESQLBase): @@ -742,9 +768,13 @@ def append_separator(self, separator: str) -> "Dissect": def _render_internal(self) -> str: sep = ( - "" if self._separator is None else f' APPEND_SEPARATOR="{self._separator}"' + "" + if self._separator is None + else f" APPEND_SEPARATOR={json.dumps(self._separator)}" + ) + return ( + f"DISSECT {self._format_id(self._input)} {json.dumps(self._pattern)}{sep}" ) - return f"DISSECT {self._input} {json.dumps(self._pattern)}{sep}" class Drop(ESQLBase): @@ -760,7 +790,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType): self._columns = columns def _render_internal(self) -> str: - return f'DROP {", ".join([str(col) for col in self._columns])}' + return f'DROP {", ".join([self._format_id(col, allow_patterns=True) for col in self._columns])}' class Enrich(ESQLBase): @@ -814,12 +844,18 @@ def with_(self, *fields: FieldType, **named_fields: FieldType) -> "Enrich": return self def _render_internal(self) -> str: - on = "" if self._match_field is None else f" ON {self._match_field}" + on = ( + "" + if self._match_field is None + else f" ON {self._format_id(self._match_field)}" + ) with_ = "" if self._named_fields: - with_ = f' WITH {", ".join([f"{name} = {field}" for name, field in self._named_fields.items()])}' + with_ = f' WITH {", ".join([f"{self._format_id(name)} = {self._format_id(field)}" for name, field in self._named_fields.items()])}' elif self._fields is not None: - with_ = f' WITH {", ".join([str(field) for field in self._fields])}' + with_ = ( + f' WITH {", ".join([self._format_id(field) for field in self._fields])}' + ) return f"ENRICH 
{self._policy}{on}{with_}" @@ -832,7 +868,10 @@ class Eval(ESQLBase): """ def __init__( - self, parent: ESQLBase, *columns: FieldType, **named_columns: FieldType + self, + parent: ESQLBase, + *columns: ExpressionType, + **named_columns: ExpressionType, ): if columns and named_columns: raise ValueError( @@ -844,10 +883,13 @@ def __init__( def _render_internal(self) -> str: if isinstance(self._columns, dict): cols = ", ".join( - [f"{name} = {value}" for name, value in self._columns.items()] + [ + f"{self._format_id(name)} = {self._format_expr(value)}" + for name, value in self._columns.items() + ] ) else: - cols = ", ".join([f"{col}" for col in self._columns]) + cols = ", ".join([f"{self._format_expr(col)}" for col in self._columns]) return f"EVAL {cols}" @@ -900,7 +942,7 @@ def __init__(self, parent: ESQLBase, input: FieldType, pattern: str): self._pattern = pattern def _render_internal(self) -> str: - return f"GROK {self._input} {json.dumps(self._pattern)}" + return f"GROK {self._format_id(self._input)} {json.dumps(self._pattern)}" class Keep(ESQLBase): @@ -916,7 +958,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType): self._columns = columns def _render_internal(self) -> str: - return f'KEEP {", ".join([f"{col}" for col in self._columns])}' + return f'KEEP {", ".join([f"{self._format_id(col, allow_patterns=True)}" for col in self._columns])}' class Limit(ESQLBase): @@ -932,7 +974,7 @@ def __init__(self, parent: ESQLBase, max_number_of_rows: int): self._max_number_of_rows = max_number_of_rows def _render_internal(self) -> str: - return f"LIMIT {self._max_number_of_rows}" + return f"LIMIT {json.dumps(self._max_number_of_rows)}" class LookupJoin(ESQLBase): @@ -967,7 +1009,9 @@ def _render_internal(self) -> str: if isinstance(self._lookup_index, str) else self._lookup_index._index._name ) - return f"LOOKUP JOIN {index} ON {self._field}" + return ( + f"LOOKUP JOIN {self._format_index(index)} ON {self._format_id(self._field)}" + ) class MvExpand(ESQLBase): 
@@ -983,7 +1027,7 @@ def __init__(self, parent: ESQLBase, column: FieldType): self._column = column def _render_internal(self) -> str: - return f"MV_EXPAND {self._column}" + return f"MV_EXPAND {self._format_id(self._column)}" class Rename(ESQLBase): @@ -999,7 +1043,7 @@ def __init__(self, parent: ESQLBase, **columns: FieldType): self._columns = columns def _render_internal(self) -> str: - return f'RENAME {", ".join([f"{old_name} AS {new_name}" for old_name, new_name in self._columns.items()])}' + return f'RENAME {", ".join([f"{self._format_id(old_name)} AS {self._format_id(new_name)}" for old_name, new_name in self._columns.items()])}' class Sample(ESQLBase): @@ -1015,7 +1059,7 @@ def __init__(self, parent: ESQLBase, probability: float): self._probability = probability def _render_internal(self) -> str: - return f"SAMPLE {self._probability}" + return f"SAMPLE {json.dumps(self._probability)}" class Sort(ESQLBase): @@ -1026,12 +1070,16 @@ class Sort(ESQLBase): in a single expression. """ - def __init__(self, parent: ESQLBase, *columns: FieldType): + def __init__(self, parent: ESQLBase, *columns: ExpressionType): super().__init__(parent) self._columns = columns def _render_internal(self) -> str: - return f'SORT {", ".join([f"{col}" for col in self._columns])}' + sorts = [ + " ".join([self._format_id(term) for term in str(col).split(" ")]) + for col in self._columns + ] + return f'SORT {", ".join([f"{sort}" for sort in sorts])}' class Stats(ESQLBase): @@ -1062,14 +1110,17 @@ def by(self, *grouping_expressions: ExpressionType) -> "Stats": def _render_internal(self) -> str: if isinstance(self._expressions, dict): - exprs = [f"{key} = {value}" for key, value in self._expressions.items()] + exprs = [ + f"{self._format_id(key)} = {self._format_expr(value)}" + for key, value in self._expressions.items() + ] else: - exprs = [f"{expr}" for expr in self._expressions] + exprs = [f"{self._format_expr(expr)}" for expr in self._expressions] expression_separator = ",\n " by = ( "" 
if self._grouping_expressions is None - else f'\n BY {", ".join([f"{expr}" for expr in self._grouping_expressions])}' + else f'\n BY {", ".join([f"{self._format_expr(expr)}" for expr in self._grouping_expressions])}' ) return f'STATS {expression_separator.join([f"{expr}" for expr in exprs])}{by}' @@ -1087,7 +1138,7 @@ def __init__(self, parent: ESQLBase, *expressions: ExpressionType): self._expressions = expressions def _render_internal(self) -> str: - return f'WHERE {" AND ".join([f"{expr}" for expr in self._expressions])}' + return f'WHERE {" AND ".join([f"{self._format_expr(expr)}" for expr in self._expressions])}' def and_(*expressions: InstrumentedExpression) -> "InstrumentedExpression": diff --git a/elasticsearch/esql/functions.py b/elasticsearch/esql/functions.py index 515e3ddfc..91f18d2d8 100644 --- a/elasticsearch/esql/functions.py +++ b/elasticsearch/esql/functions.py @@ -19,11 +19,15 @@ from typing import Any from elasticsearch.dsl.document_base import InstrumentedExpression -from elasticsearch.esql.esql import ExpressionType +from elasticsearch.esql.esql import ESQLBase, ExpressionType def _render(v: Any) -> str: - return json.dumps(v) if not isinstance(v, InstrumentedExpression) else str(v) + return ( + json.dumps(v) + if not isinstance(v, InstrumentedExpression) + else ESQLBase._format_expr(v) + ) def abs(number: ExpressionType) -> InstrumentedExpression: @@ -69,7 +73,9 @@ def atan2( :param y_coordinate: y coordinate. If `null`, the function returns `null`. :param x_coordinate: x coordinate. If `null`, the function returns `null`. """ - return InstrumentedExpression(f"ATAN2({y_coordinate}, {x_coordinate})") + return InstrumentedExpression( + f"ATAN2({_render(y_coordinate)}, {_render(x_coordinate)})" + ) def avg(number: ExpressionType) -> InstrumentedExpression: @@ -114,7 +120,7 @@ def bucket( :param to: End of the range. Can be a number, a date or a date expressed as a string. 
""" return InstrumentedExpression( - f"BUCKET({_render(field)}, {_render(buckets)}, {from_}, {_render(to)})" + f"BUCKET({_render(field)}, {_render(buckets)}, {_render(from_)}, {_render(to)})" ) @@ -169,7 +175,7 @@ def cidr_match(ip: ExpressionType, block_x: ExpressionType) -> InstrumentedExpre :param ip: IP address of type `ip` (both IPv4 and IPv6 are supported). :param block_x: CIDR block to test the IP against. """ - return InstrumentedExpression(f"CIDR_MATCH({_render(ip)}, {block_x})") + return InstrumentedExpression(f"CIDR_MATCH({_render(ip)}, {_render(block_x)})") def coalesce(first: ExpressionType, rest: ExpressionType) -> InstrumentedExpression: @@ -264,7 +270,7 @@ def date_diff( :param end_timestamp: A string representing an end timestamp """ return InstrumentedExpression( - f"DATE_DIFF({_render(unit)}, {start_timestamp}, {end_timestamp})" + f"DATE_DIFF({_render(unit)}, {_render(start_timestamp)}, {_render(end_timestamp)})" ) @@ -285,7 +291,9 @@ def date_extract( the function returns `null`. :param date: Date expression. If `null`, the function returns `null`. """ - return InstrumentedExpression(f"DATE_EXTRACT({date_part}, {_render(date)})") + return InstrumentedExpression( + f"DATE_EXTRACT({_render(date_part)}, {_render(date)})" + ) def date_format( @@ -301,7 +309,7 @@ def date_format( """ if date_format is not None: return InstrumentedExpression( - f"DATE_FORMAT({json.dumps(date_format)}, {_render(date)})" + f"DATE_FORMAT({_render(date_format)}, {_render(date)})" ) else: return InstrumentedExpression(f"DATE_FORMAT({_render(date)})") @@ -317,7 +325,9 @@ def date_parse( :param date_string: Date expression as a string. If `null` or an empty string, the function returns `null`. """ - return InstrumentedExpression(f"DATE_PARSE({date_pattern}, {date_string})") + return InstrumentedExpression( + f"DATE_PARSE({_render(date_pattern)}, {_render(date_string)})" + ) def date_trunc( @@ -929,7 +939,7 @@ def replace( :param new_string: Replacement string. 
""" return InstrumentedExpression( - f"REPLACE({_render(string)}, {_render(regex)}, {new_string})" + f"REPLACE({_render(string)}, {_render(regex)}, {_render(new_string)})" ) @@ -1004,7 +1014,7 @@ def scalb(d: ExpressionType, scale_factor: ExpressionType) -> InstrumentedExpres :param scale_factor: Numeric expression for the scale factor. If `null`, the function returns `null`. """ - return InstrumentedExpression(f"SCALB({_render(d)}, {scale_factor})") + return InstrumentedExpression(f"SCALB({_render(d)}, {_render(scale_factor)})") def sha1(input: ExpressionType) -> InstrumentedExpression: @@ -1116,7 +1126,7 @@ def st_contains( first. This means it is not possible to combine `geo_*` and `cartesian_*` parameters. """ - return InstrumentedExpression(f"ST_CONTAINS({geom_a}, {geom_b})") + return InstrumentedExpression(f"ST_CONTAINS({_render(geom_a)}, {_render(geom_b)})") def st_disjoint( @@ -1135,7 +1145,7 @@ def st_disjoint( first. This means it is not possible to combine `geo_*` and `cartesian_*` parameters. """ - return InstrumentedExpression(f"ST_DISJOINT({geom_a}, {geom_b})") + return InstrumentedExpression(f"ST_DISJOINT({_render(geom_a)}, {_render(geom_b)})") def st_distance( @@ -1153,7 +1163,7 @@ def st_distance( also have the same coordinate system as the first. This means it is not possible to combine `geo_point` and `cartesian_point` parameters. """ - return InstrumentedExpression(f"ST_DISTANCE({geom_a}, {geom_b})") + return InstrumentedExpression(f"ST_DISTANCE({_render(geom_a)}, {_render(geom_b)})") def st_envelope(geometry: ExpressionType) -> InstrumentedExpression: @@ -1208,7 +1218,7 @@ def st_geohash_to_long(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geohash grid-id. The input can be a single- or multi-valued column or an expression. 
""" - return InstrumentedExpression(f"ST_GEOHASH_TO_LONG({grid_id})") + return InstrumentedExpression(f"ST_GEOHASH_TO_LONG({_render(grid_id)})") def st_geohash_to_string(grid_id: ExpressionType) -> InstrumentedExpression: @@ -1218,7 +1228,7 @@ def st_geohash_to_string(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geohash grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOHASH_TO_STRING({grid_id})") + return InstrumentedExpression(f"ST_GEOHASH_TO_STRING({_render(grid_id)})") def st_geohex( @@ -1254,7 +1264,7 @@ def st_geohex_to_long(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geohex grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOHEX_TO_LONG({grid_id})") + return InstrumentedExpression(f"ST_GEOHEX_TO_LONG({_render(grid_id)})") def st_geohex_to_string(grid_id: ExpressionType) -> InstrumentedExpression: @@ -1264,7 +1274,7 @@ def st_geohex_to_string(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input Geohex grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOHEX_TO_STRING({grid_id})") + return InstrumentedExpression(f"ST_GEOHEX_TO_STRING({_render(grid_id)})") def st_geotile( @@ -1300,7 +1310,7 @@ def st_geotile_to_long(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geotile grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOTILE_TO_LONG({grid_id})") + return InstrumentedExpression(f"ST_GEOTILE_TO_LONG({_render(grid_id)})") def st_geotile_to_string(grid_id: ExpressionType) -> InstrumentedExpression: @@ -1310,7 +1320,7 @@ def st_geotile_to_string(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geotile grid-id. 
The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOTILE_TO_STRING({grid_id})") + return InstrumentedExpression(f"ST_GEOTILE_TO_STRING({_render(grid_id)})") def st_intersects( @@ -1330,7 +1340,9 @@ def st_intersects( first. This means it is not possible to combine `geo_*` and `cartesian_*` parameters. """ - return InstrumentedExpression(f"ST_INTERSECTS({geom_a}, {geom_b})") + return InstrumentedExpression( + f"ST_INTERSECTS({_render(geom_a)}, {_render(geom_b)})" + ) def st_within(geom_a: ExpressionType, geom_b: ExpressionType) -> InstrumentedExpression: @@ -1346,7 +1358,7 @@ def st_within(geom_a: ExpressionType, geom_b: ExpressionType) -> InstrumentedExp first. This means it is not possible to combine `geo_*` and `cartesian_*` parameters. """ - return InstrumentedExpression(f"ST_WITHIN({geom_a}, {geom_b})") + return InstrumentedExpression(f"ST_WITHIN({_render(geom_a)}, {_render(geom_b)})") def st_x(point: ExpressionType) -> InstrumentedExpression: diff --git a/test_elasticsearch/test_dsl/test_integration/_async/test_document.py b/test_elasticsearch/test_dsl/test_integration/_async/test_document.py index 99f475cf1..3d769c606 100644 --- a/test_elasticsearch/test_dsl/test_integration/_async/test_document.py +++ b/test_elasticsearch/test_dsl/test_integration/_async/test_document.py @@ -630,7 +630,9 @@ async def test_can_save_to_different_index( async def test_save_without_skip_empty_will_include_empty_fields( async_write_client: AsyncElasticsearch, ) -> None: - test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) + test_repo = Repository( + field_1=[], field_2=None, field_3={}, owner={"name": None}, meta={"id": 42} + ) assert await test_repo.save(index="test-document", skip_empty=False) assert_doc_equals( @@ -638,7 +640,12 @@ async def test_save_without_skip_empty_will_include_empty_fields( "found": True, "_index": "test-document", "_id": "42", - "_source": {"field_1": [], "field_2": 
None, "field_3": {}}, + "_source": { + "field_1": [], + "field_2": None, + "field_3": {}, + "owner": {"name": None}, + }, }, await async_write_client.get(index="test-document", id=42), ) diff --git a/test_elasticsearch/test_dsl/_async/test_esql.py b/test_elasticsearch/test_dsl/test_integration/_async/test_esql.py similarity index 88% rename from test_elasticsearch/test_dsl/_async/test_esql.py rename to test_elasticsearch/test_dsl/test_integration/_async/test_esql.py index 7aacb833c..27d26ca99 100644 --- a/test_elasticsearch/test_dsl/_async/test_esql.py +++ b/test_elasticsearch/test_dsl/test_integration/_async/test_esql.py @@ -17,7 +17,7 @@ import pytest -from elasticsearch.dsl import AsyncDocument, M +from elasticsearch.dsl import AsyncDocument, E, M from elasticsearch.esql import ESQL, functions @@ -91,3 +91,13 @@ async def test_esql(async_client): ) r = await async_client.esql.query(query=str(query)) assert r.body["values"] == [[1.95]] + + # find employees by name using a parameter + query = ( + ESQL.from_(Employee) + .where(Employee.first_name == E("?")) + .keep(Employee.last_name) + .sort(Employee.last_name.desc()) + ) + r = await async_client.esql.query(query=str(query), params=["Maria"]) + assert r.body["values"] == [["Luna"], ["Cannon"]] diff --git a/test_elasticsearch/test_dsl/test_integration/_sync/test_document.py b/test_elasticsearch/test_dsl/test_integration/_sync/test_document.py index 05dd05fd9..a005d45bf 100644 --- a/test_elasticsearch/test_dsl/test_integration/_sync/test_document.py +++ b/test_elasticsearch/test_dsl/test_integration/_sync/test_document.py @@ -624,7 +624,9 @@ def test_can_save_to_different_index( def test_save_without_skip_empty_will_include_empty_fields( write_client: Elasticsearch, ) -> None: - test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) + test_repo = Repository( + field_1=[], field_2=None, field_3={}, owner={"name": None}, meta={"id": 42} + ) assert test_repo.save(index="test-document", 
skip_empty=False) assert_doc_equals( @@ -632,7 +634,12 @@ def test_save_without_skip_empty_will_include_empty_fields( "found": True, "_index": "test-document", "_id": "42", - "_source": {"field_1": [], "field_2": None, "field_3": {}}, + "_source": { + "field_1": [], + "field_2": None, + "field_3": {}, + "owner": {"name": None}, + }, }, write_client.get(index="test-document", id=42), ) diff --git a/test_elasticsearch/test_dsl/_sync/test_esql.py b/test_elasticsearch/test_dsl/test_integration/_sync/test_esql.py similarity index 88% rename from test_elasticsearch/test_dsl/_sync/test_esql.py rename to test_elasticsearch/test_dsl/test_integration/_sync/test_esql.py index 1c4084fc7..85ceee5ae 100644 --- a/test_elasticsearch/test_dsl/_sync/test_esql.py +++ b/test_elasticsearch/test_dsl/test_integration/_sync/test_esql.py @@ -17,7 +17,7 @@ import pytest -from elasticsearch.dsl import Document, M +from elasticsearch.dsl import Document, E, M from elasticsearch.esql import ESQL, functions @@ -91,3 +91,13 @@ def test_esql(client): ) r = client.esql.query(query=str(query)) assert r.body["values"] == [[1.95]] + + # find employees by name using a parameter + query = ( + ESQL.from_(Employee) + .where(Employee.first_name == E("?")) + .keep(Employee.last_name) + .sort(Employee.last_name.desc()) + ) + r = client.esql.query(query=str(query), params=["Maria"]) + assert r.body["values"] == [["Luna"], ["Cannon"]] diff --git a/test_elasticsearch/test_esql.py b/test_elasticsearch/test_esql.py index 70c9ec679..35b026fb5 100644 --- a/test_elasticsearch/test_esql.py +++ b/test_elasticsearch/test_esql.py @@ -84,7 +84,7 @@ def test_completion(): assert ( query.render() == """ROW question = "What is Elasticsearch?" -| COMPLETION question WITH test_completion_model +| COMPLETION question WITH {"inference_id": "test_completion_model"} | KEEP question, completion""" ) @@ -97,7 +97,7 @@ def test_completion(): assert ( query.render() == """ROW question = "What is Elasticsearch?" 
-| COMPLETION answer = question WITH test_completion_model +| COMPLETION answer = question WITH {"inference_id": "test_completion_model"} | KEEP question, answer""" ) @@ -128,7 +128,7 @@ def test_completion(): "Synopsis: ", synopsis, "\\n", "Actors: ", MV_CONCAT(actors, ", "), "\\n", ) -| COMPLETION summary = prompt WITH test_completion_model +| COMPLETION summary = prompt WITH {"inference_id": "test_completion_model"} | KEEP title, summary, rating""" ) @@ -160,7 +160,7 @@ def test_completion(): | SORT rating DESC | LIMIT 10 | EVAL prompt = CONCAT("Summarize this movie using the following information: \\n", "Title: ", title, "\\n", "Synopsis: ", synopsis, "\\n", "Actors: ", MV_CONCAT(actors, ", "), "\\n") -| COMPLETION summary = prompt WITH test_completion_model +| COMPLETION summary = prompt WITH {"inference_id": "test_completion_model"} | KEEP title, summary, rating""" ) @@ -713,3 +713,11 @@ def test_match_operator(): == """FROM books | WHERE author:"Faulkner\"""" ) + + +def test_parameters(): + query = ESQL.from_("employees").where("name == ?") + assert query.render() == "FROM employees\n| WHERE name == ?" + + query = ESQL.from_("employees").where(E("name") == E("?")) + assert query.render() == "FROM employees\n| WHERE name == ?" 
diff --git a/test_elasticsearch/test_server/test_rest_api_spec.py b/test_elasticsearch/test_server/test_rest_api_spec.py index a84f0822a..768453c10 100644 --- a/test_elasticsearch/test_server/test_rest_api_spec.py +++ b/test_elasticsearch/test_server/test_rest_api_spec.py @@ -78,6 +78,7 @@ "cluster/voting_config_exclusions", "entsearch/10_basic", "indices/clone", + "indices/data_stream_mappings[0]", "indices/resolve_cluster", "indices/settings", "indices/split", @@ -494,7 +495,7 @@ def remove_implicit_resolver(cls, tag_to_remove): # Try loading the REST API test specs from the Elastic Artifacts API try: # Construct the HTTP and Elasticsearch client - http = urllib3.PoolManager(retries=10) + http = urllib3.PoolManager(retries=urllib3.Retry(total=10)) yaml_tests_url = ( "https://api.github.com/repos/elastic/elasticsearch-clients-tests/zipball/main" diff --git a/test_elasticsearch/utils.py b/test_elasticsearch/utils.py index 021deb76e..cfcb5259c 100644 --- a/test_elasticsearch/utils.py +++ b/test_elasticsearch/utils.py @@ -179,7 +179,7 @@ def wipe_data_streams(client): def wipe_indices(client): indices = client.cat.indices().strip().splitlines() if len(indices) > 0: - index_names = [i.split(" ")[2] for i in indices] + index_names = [i.split()[2] for i in indices] client.options(ignore_status=404).indices.delete( index=",".join(index_names), expand_wildcards="all", diff --git a/utils/templates/field.py.tpl b/utils/templates/field.py.tpl index 8a4c73f33..8699d852e 100644 --- a/utils/templates/field.py.tpl +++ b/utils/templates/field.py.tpl @@ -119,9 +119,16 @@ class Field(DslBase): def __getitem__(self, subfield: str) -> "Field": return cast(Field, self._params.get("fields", {})[subfield]) - def _serialize(self, data: Any) -> Any: + def _serialize(self, data: Any, skip_empty: bool) -> Any: return data + def _safe_serialize(self, data: Any, skip_empty: bool) -> Any: + try: + return self._serialize(data, skip_empty) + except TypeError: + # older method signature, without 
skip_empty + return self._serialize(data) # type: ignore[call-arg] + def _deserialize(self, data: Any) -> Any: return data @@ -133,10 +140,10 @@ class Field(DslBase): return AttrList([]) return self._empty() - def serialize(self, data: Any) -> Any: + def serialize(self, data: Any, skip_empty: bool = True) -> Any: if isinstance(data, (list, AttrList, tuple)): - return list(map(self._serialize, cast(Iterable[Any], data))) - return self._serialize(data) + return list(map(self._safe_serialize, cast(Iterable[Any], data), [skip_empty] * len(data))) + return self._safe_serialize(data, skip_empty) def deserialize(self, data: Any) -> Any: if isinstance(data, (list, AttrList, tuple)): @@ -186,7 +193,7 @@ class RangeField(Field): data = {k: self._core_field.deserialize(v) for k, v in data.items()} # type: ignore[union-attr] return Range(data) - def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]: if data is None: return None if not isinstance(data, collections.abc.Mapping): @@ -318,7 +325,7 @@ class {{ k.name }}({{ k.parent }}): return self._wrap(data) def _serialize( - self, data: Optional[Union[Dict[str, Any], "InnerDoc"]] + self, data: Optional[Union[Dict[str, Any], "InnerDoc"]], skip_empty: bool ) -> Optional[Dict[str, Any]]: if data is None: return None @@ -327,7 +334,7 @@ class {{ k.name }}({{ k.parent }}): if isinstance(data, collections.abc.Mapping): return data - return data.to_dict() + return data.to_dict(skip_empty=skip_empty) def clean(self, data: Any) -> Any: data = super().clean(data) @@ -433,7 +440,7 @@ class {{ k.name }}({{ k.parent }}): # the ipaddress library for pypy only accepts unicode. 
return ipaddress.ip_address(unicode(data)) - def _serialize(self, data: Any) -> Optional[str]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None return str(data) @@ -448,7 +455,7 @@ class {{ k.name }}({{ k.parent }}): def _deserialize(self, data: Any) -> bytes: return base64.b64decode(data) - def _serialize(self, data: Any) -> Optional[str]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None return base64.b64encode(data).decode() @@ -458,7 +465,7 @@ class {{ k.name }}({{ k.parent }}): def _deserialize(self, data: Any) -> "Query": return Q(data) # type: ignore[no-any-return] - def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]: if data is None: return None return data.to_dict() # type: ignore[no-any-return]